diff --git a/integrations/actix-web/Cargo.toml b/integrations/actix-web/Cargo.toml index f01ed276..51a56f87 100644 --- a/integrations/actix-web/Cargo.toml +++ b/integrations/actix-web/Cargo.toml @@ -16,8 +16,8 @@ categories = ["network-programming", "asynchronous"] async-graphql = { path = "../..", version = "2.0.0-alpha.8" } actix-web = "3.0.0" actix-web-actors = "3.0.0" +actix-http = "2.0.0" actix = "0.10.0" actix-rt = "1.1.0" -futures = "0.3.0" -bytes = "0.5.4" +futures = "0.3.5" serde_json = "1.0.48" diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 93469f02..13f2261e 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -1,11 +1,13 @@ use actix::{ - Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, + Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner, + StreamHandler, WrapFuture, WrapStream, }; +use actix_http::ws; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::http::WebSocketStream; +use async_graphql::http::WebSocket; use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType}; -use futures::stream::SplitSink; -use futures::{SinkExt, StreamExt}; +use futures::channel::mpsc; +use futures::SinkExt; use std::time::{Duration, Instant}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -13,10 +15,11 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); /// Actor for subscription via websocket pub struct WSSubscription { - schema: Schema, - hb: Instant, - sink: Option>, - initializer: Option FieldResult + Send + Sync>>, + schema: Option>, + last_heartbeat: Instant, + messages: Option>>, + initializer: Option FieldResult + Send + Sync>>, + continuation: Vec, } impl WSSubscription @@ -26,19 +29,20 @@ where Subscription: SubscriptionType + Send + Sync + 'static, { /// Create an actor for subscription connection via websocket. - pub fn new(schema: &Schema) -> Self { + pub fn new(schema: Schema) -> Self { Self { - schema: schema.clone(), - hb: Instant::now(), - sink: None, + schema: Some(schema), + last_heartbeat: Instant::now(), + messages: None, initializer: None, + continuation: Vec::new(), } } /// Set a context data initialization function. pub fn initializer(self, f: F) -> Self where - F: Fn(serde_json::Value) -> FieldResult + Send + Sync + 'static, + F: FnOnce(serde_json::Value) -> FieldResult + Send + Sync + 'static, { Self { initializer: Some(Box::new(f)), @@ -46,9 +50,9 @@ where } } - fn hb(&self, ctx: &mut WebsocketContext) { + fn send_heartbeats(&self, ctx: &mut WebsocketContext) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { - if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { + if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT { ctx.stop(); } ctx.ping(b""); @@ -65,20 +69,19 @@ where type Context = WebsocketContext; fn started(&mut self, ctx: &mut Self::Context) { - self.hb(ctx); - if let Some(initializer) = self.initializer.take() { - let (sink, stream) = async_graphql::http::WebSocketStream::new_with_initializer( - &self.schema, - initializer, - ) - .split(); - ctx.add_stream(stream); - self.sink = Some(sink); - } else { - let (sink, stream) = async_graphql::http::WebSocketStream::new(&self.schema).split(); - ctx.add_stream(stream); - self.sink = Some(sink); - }; + self.send_heartbeats(ctx); + + let (tx, rx) = mpsc::unbounded(); + + WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take()) + .into_actor(self) + .map(|response, _act, ctx| { + ctx.text(response); + }) + .finish() + .spawn(ctx); + + self.messages = Some(tx); } } @@ -98,47 +101,49 @@ where Ok(msg) => msg, }; - match msg { + let message = match msg { Message::Ping(msg) => { - self.hb = Instant::now(); + self.last_heartbeat = Instant::now(); ctx.pong(&msg); + None } Message::Pong(_) => { - self.hb = Instant::now(); + self.last_heartbeat = Instant::now(); + None } - Message::Text(s) => { - if let Some(mut sink) = self.sink.take() { - async move { - let res = sink.send(s).await; - res.map(|_| sink) - } - .into_actor(self) - .then(|res, actor, ctx| { - match res { - Ok(sink) => actor.sink = Some(sink), - Err(_) => ctx.stop(), - } - async {}.into_actor(actor) - }) - .wait(ctx); + Message::Continuation(item) => match item { + ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => { + self.continuation = bytes.to_vec(); + None } - } - Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => { + ws::Item::Continue(bytes) => { + self.continuation.extend_from_slice(&bytes); + None + } + ws::Item::Last(bytes) => { + self.continuation.extend_from_slice(&bytes); + Some(std::mem::take(&mut self.continuation)) + } + }, + Message::Text(s) => Some(s.into_bytes()), + Message::Binary(bytes) => Some(bytes.to_vec()), + Message::Close(_) => { ctx.stop(); + None } - Message::Nop => {} + Message::Nop => None, + }; + + if let Some(message) = message { + let mut sender = self.messages.as_ref().unwrap().clone(); + + async move { sender.send(message).await } + .into_actor(self) + .map(|res, _actor, ctx| match res { + Ok(()) => {} + Err(_) => ctx.stop(), + }) + .spawn(ctx) } } } - -impl StreamHandler - for WSSubscription -where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, -{ - fn handle(&mut self, data: String, ctx: &mut Self::Context) { - ctx.text(data); - } -} diff --git a/integrations/warp/src/lib.rs b/integrations/warp/src/lib.rs index ef542dd8..97ce08c3 100644 --- a/integrations/warp/src/lib.rs +++ b/integrations/warp/src/lib.rs @@ -1,4 +1,4 @@ -//! Async-graphql integration with Wrap +//! Async-graphql integration with Warp #![warn(missing_docs)] #![allow(clippy::type_complexity)] @@ -9,13 +9,11 @@ use async_graphql::http::MultipartOptions; use async_graphql::{ resolver_utils::ObjectType, Data, FieldResult, Request, Schema, SubscriptionType, }; -use futures::io::ErrorKind; -use futures::{select, TryStreamExt}; -use futures::{SinkExt, StreamExt}; +use futures::{future, StreamExt, TryStreamExt}; use hyper::Method; -use std::io; +use std::io::{self, ErrorKind}; use std::sync::Arc; -use warp::filters::ws::Message; +use warp::filters::ws; use warp::filters::BoxedFilter; use warp::reject::Reject; use warp::reply::Response; @@ -169,81 +167,57 @@ where /// ``` pub fn graphql_subscription( schema: Schema, -) -> BoxedFilter<(impl Reply,)> +) -> impl Filter + Clone where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - graphql_subscription_with_initializer(schema, |_| Ok(Default::default())) + graphql_subscription_with_data::<_, _, _, fn(serde_json::Value) -> FieldResult>( + schema, None, + ) } /// GraphQL subscription filter /// /// Specifies that a function converts the init payload to data. -pub fn graphql_subscription_with_initializer( +pub fn graphql_subscription_with_data( schema: Schema, - initializer: F, -) -> BoxedFilter<(impl Reply,)> + initializer: Option, +) -> impl Filter + Clone where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Send + Sync + 'static, - F: Fn(serde_json::Value) -> FieldResult + Send + Sync + Clone + 'static, + F: FnOnce(serde_json::Value) -> FieldResult + Send + Sync + Clone + 'static, { warp::any() .and(warp::ws()) .and(warp::any().map(move || schema.clone())) .and(warp::any().map(move || initializer.clone())) .map( - |ws: warp::ws::Ws, schema: Schema, initializer: F| { + |ws: ws::Ws, schema: Schema, initializer: Option| { ws.on_upgrade(move |websocket| { - let (mut tx, rx) = websocket.split(); - let (mut stx, srx) = - async_graphql::http::WebSocketStream::new_with_initializer( - &schema, - initializer, - ) - .split(); - - let mut rx = rx.fuse(); - let srx = srx.fuse(); + let (ws_sender, ws_receiver) = websocket.split(); async move { - futures::pin_mut!(srx); - - loop { - select! { - bytes = srx.next() => { - if let Some(text) = bytes { - if tx.send(Message::text(text)).await.is_err() { - return; - } - } else { - return; - } - } - msg = rx.next() => { - if let Some(Ok(msg)) = msg { - if msg.is_text() { - if let Ok(text) = String::from_utf8(msg.into_bytes()) { - if stx.send(text).await.is_err() { - return; - } - } - } - } else { - return; - } - } - } - } + let _ = async_graphql::http::WebSocket::with_data( + schema, + ws_receiver + .take_while(|msg| future::ready(msg.is_ok())) + .map(Result::unwrap) + .map(ws::Message::into_bytes), + initializer, + ) + .map(ws::Message::text) + .map(Ok) + .forward(ws_sender) + .await; } }) }, ) .map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")) - .boxed() } /// GraphQL reply diff --git a/src/http/mod.rs b/src/http/mod.rs index 79a5c921..85c7b2dd 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -10,7 +10,7 @@ pub use graphiql_source::graphiql_source; #[cfg(feature = "multipart")] pub use multipart::{receive_multipart, MultipartOptions}; pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; -pub use websocket::WebSocketStream; +pub use websocket::WebSocket; use crate::{ParseRequestError, Request}; use futures::io::AsyncRead; diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 84f43139..7183f2df 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -2,291 +2,207 @@ use crate::resolver_utils::ObjectType; use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType}; -use futures::channel::mpsc; -use futures::task::{Context, Poll}; -use futures::{Future, Sink, SinkExt, Stream, StreamExt}; +use futures::Stream; +use pin_project_lite::pin_project; use serde::{Deserialize, Serialize}; -use std::collections::{HashMap, VecDeque}; +use std::collections::HashMap; use std::pin::Pin; use std::sync::Arc; +use std::task::{Context, Poll}; -#[derive(Serialize, Deserialize)] -struct OperationMessage<'a, T> { - #[serde(rename = "type")] - ty: &'a str, - - #[serde(skip_serializing_if = "Option::is_none")] - id: Option, - - #[serde(skip_serializing_if = "Option::is_none")] - payload: Option, -} - -type SubscriptionStreams = HashMap + Send>>>; - -type HandleRequestBoxFut = Pin> + Send>>; - -type InitializerFn = Arc FieldResult + Send + Sync>; - -/// A wrapper around an underlying raw stream which implements the WebSocket protocol. -/// -/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function -/// to splits this object into separate Sink and Stream objects. -pub struct WebSocketStream { - tx: mpsc::UnboundedSender, - rx: Pin + Send>>, -} - -impl Sink for WebSocketStream { - type Error = mpsc::SendError; - - fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_ready_unpin(cx) - } - - fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> { - self.tx.start_send(item) - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_flush_unpin(cx) - } - - fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.tx.poll_close_unpin(cx) +pin_project! { + /// A GraphQL connection over websocket. + /// + /// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). + pub struct WebSocket { + data_initializer: Option, + data: Arc, + schema: Schema, + streams: HashMap + Send>>>, + #[pin] + stream: S, } } -impl Stream for WebSocketStream { - type Item = String; - - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - self.rx.poll_next_unpin(cx) - } -} - -impl WebSocketStream { - /// Create a websocket transport. - pub fn new( - schema: &Schema, - ) -> Self - where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, - { - Self::new_with_initializer(schema, |_| Ok(Default::default())) - } - - /// Create a websocket transport and specify a context initialization function. - pub fn new_with_initializer( - schema: &Schema, - initializer: impl Fn(serde_json::Value) -> FieldResult + Send + Sync + 'static, - ) -> Self - where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, - { - let (tx, rx) = mpsc::unbounded(); - WebSocketStream { - tx, - rx: SubscriptionStream { - schema: schema.clone(), - initializer: Arc::new(initializer), - rx_bytes: rx, - handle_request_fut: None, - ctx: Some(WSContext { - streams: Default::default(), - send_buf: Default::default(), - ctx_data: Arc::new(Data::default()), - }), - } - .boxed(), +impl + WebSocket FieldResult, Query, Mutation, Subscription> +{ + /// Create a new websocket. + #[must_use] + pub fn new(schema: Schema, stream: S) -> Self { + Self { + data_initializer: None, + data: Arc::default(), + schema, + streams: HashMap::new(), + stream, } } } -struct WSContext { - streams: SubscriptionStreams, - send_buf: VecDeque, - ctx_data: Arc, -} - -fn send_message(send_buf: &mut VecDeque, msg: &T) { - if let Ok(data) = serde_json::to_string(msg) { - send_buf.push_back(data); +impl WebSocket { + /// Create a new websocket with a data initialization function. + /// + /// This function, if present, will be called with the data sent by the client in the + /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init). + /// From that point on the returned data will be accessible to all requests. + #[must_use] + pub fn with_data( + schema: Schema, + stream: S, + data_initializer: Option, + ) -> Self { + Self { + data_initializer, + data: Arc::default(), + schema, + streams: HashMap::new(), + stream, + } } } -#[allow(missing_docs)] -#[allow(clippy::type_complexity)] -struct SubscriptionStream { - schema: Schema, - initializer: InitializerFn, - rx_bytes: mpsc::UnboundedReceiver, - handle_request_fut: Option, - ctx: Option, -} - -impl<'a, Query, Mutation, Subscription> Stream for SubscriptionStream +impl Stream for WebSocket where + S: Stream, + S::Item: AsRef<[u8]>, + F: FnOnce(serde_json::Value) -> FieldResult, Query: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { type Item = String; - fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let this = &mut *self; + fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let this = self.project(); - loop { - // receive bytes - if let Some(ctx) = &mut this.ctx { - if let Some(bytes) = ctx.send_buf.pop_front() { - return Poll::Ready(Some(bytes)); - } - } + match this.stream.poll_next(cx) { + Poll::Ready(message) => { + let message = match message { + Some(message) => message, + None => return Poll::Ready(None), + }; - if let Some(handle_request_fut) = &mut this.handle_request_fut { - match handle_request_fut.as_mut().poll(cx) { - Poll::Ready(Ok(ctx)) => { - this.ctx = Some(ctx); - this.handle_request_fut = None; - continue; + let message: ClientMessage = match serde_json::from_slice(message.as_ref()) { + Ok(message) => message, + Err(e) => { + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::ConnectionError { + payload: ConnectionError { + message: e.to_string(), + extensions: None, + }, + }) + .unwrap(), + )) } - Poll::Ready(Err(_)) => return Poll::Ready(None), - Poll::Pending => {} - } - } else { - match Pin::new(&mut this.rx_bytes).poll_next(cx) { - Poll::Ready(Some(data)) => { - let ctx = this.ctx.take().unwrap(); - this.handle_request_fut = Some(Box::pin(handle_request( - this.schema.clone(), - this.initializer.clone(), - ctx, - data, - ))); - continue; - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => {} - } - } + }; - // receive msg - if let Some(ctx) = &mut this.ctx { - let mut closed = Vec::new(); - - for (id, incoming_stream) in ctx.streams.iter_mut() { - match incoming_stream.as_mut().poll_next(cx) { - Poll::Ready(Some(res)) => { - if let Some(err) = &res.error { - closed.push(id.to_string()); - send_message( - &mut ctx.send_buf, - &OperationMessage { - ty: "error", - id: Some(id.to_string()), - payload: Some(err), - }, - ); - } else { - send_message( - &mut ctx.send_buf, - &OperationMessage { - ty: "data", - id: Some(id.to_string()), - payload: Some(&res), - }, - ); + match message { + ClientMessage::ConnectionInit { payload } => { + if let Some(payload) = payload { + if let Some(data_initializer) = this.data_initializer.take() { + *this.data = Arc::new(match data_initializer(payload) { + Ok(data) => data, + Err(e) => { + return Poll::Ready(Some( + serde_json::to_string( + &ServerMessage::ConnectionError { + payload: ConnectionError { + message: e.0, + extensions: e.1, + }, + }, + ) + .unwrap(), + )) + } + }); } } - Poll::Ready(None) => { - closed.push(id.to_string()); - send_message( - &mut ctx.send_buf, - &OperationMessage { - ty: "complete", - id: Some(id.to_string()), - payload: Option::::None, - }, - ); - } - Poll::Pending => {} + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), + )); } - } - - for id in closed { - ctx.streams.remove(&id); - } - - if !ctx.send_buf.is_empty() { - continue; - } - } - - return Poll::Pending; - } - } -} - -async fn handle_request( - schema: Schema, - initializer: InitializerFn, - mut ctx: WSContext, - data: String, -) -> FieldResult -where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, -{ - match serde_json::from_str::>(&data) { - Ok(msg) => match msg.ty { - "connection_init" => { - if let Some(payload) = msg.payload { - ctx.ctx_data = Arc::new(initializer(payload)?); - } - send_message( - &mut ctx.send_buf, - &OperationMessage { - ty: "connection_ack", - id: None, - payload: Option::::None, - }, - ); - } - "start" => { - if let (Some(id), Some(payload)) = (msg.id, msg.payload) { - if let Ok(request) = serde_json::from_value::(payload) { - let stream = schema - .execute_stream_with_ctx_data(request, ctx.ctx_data.clone()) - .boxed(); - ctx.streams.insert(id, stream); - } - } - } - "stop" => { - if let Some(id) = msg.id { - if ctx.streams.remove(&id).is_some() { - send_message( - &mut ctx.send_buf, - &OperationMessage { - ty: "complete", - id: Some(id), - payload: Option::::None, - }, + ClientMessage::Start { + id, + payload: request, + } => { + this.streams.insert( + id, + Box::pin( + this.schema + .execute_stream_with_ctx_data(request, Arc::clone(this.data)), + ), ); } + ClientMessage::Stop { id } => { + if this.streams.remove(id).is_some() { + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::Complete { id }).unwrap(), + )); + } + } + ClientMessage::ConnectionTerminate => return Poll::Ready(None), } } - "connection_terminate" => return Err("connection_terminate".into()), - _ => return Err("Unknown op".into()), - }, - Err(err) => return Err(err.into()), - } + Poll::Pending => {} + } - Ok(ctx) + for (id, stream) in &mut *this.streams { + match Pin::new(stream).poll_next(cx) { + Poll::Ready(Some(payload)) => { + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::Data { + id, + payload: Box::new(payload), + }) + .unwrap(), + )); + } + Poll::Ready(None) => { + let id = id.clone(); + this.streams.remove(&id); + return Poll::Ready(Some( + serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(), + )); + } + Poll::Pending => {} + } + } + + Poll::Pending + } +} + +#[derive(Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ClientMessage<'a> { + ConnectionInit { payload: Option }, + Start { id: String, payload: Request }, + Stop { id: &'a str }, + ConnectionTerminate, +} + +#[derive(Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +enum ServerMessage<'a> { + ConnectionError { payload: ConnectionError }, + ConnectionAck, + Data { id: &'a str, payload: Box }, + // Not used by this library, as it's not necessary to send + // Error { + // id: &'a str, + // payload: serde_json::Value, + // }, + Complete { id: &'a str }, + // Not used by this library + // #[serde(rename = "ka")] + // KeepAlive +} + +#[derive(Serialize)] +struct ConnectionError { + message: String, + extensions: Option, } diff --git a/src/schema.rs b/src/schema.rs index 2019f0e0..a804e00a 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -436,9 +436,9 @@ where pub(crate) fn execute_stream_with_ctx_data( &self, - request: impl Into, + request: impl Into + Send, ctx_data: Arc, - ) -> impl Stream { + ) -> impl Stream + Send { let schema = self.clone(); async_stream::stream! { @@ -495,7 +495,10 @@ where } /// Execute an GraphQL subscription. - pub fn execute_stream(&self, request: impl Into) -> impl Stream { + pub fn execute_stream( + &self, + request: impl Into, + ) -> impl Stream + Send { let mut request = request.into(); let ctx_data = std::mem::take(&mut request.data); self.execute_stream_with_ctx_data(request, Arc::new(ctx_data)) diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket.rs index 7fcf6df0..7f685e70 100644 --- a/tests/subscription_websocket.rs +++ b/tests/subscription_websocket.rs @@ -1,4 +1,5 @@ use async_graphql::*; +use futures::channel::mpsc; use futures::{SinkExt, Stream, StreamExt}; #[async_std::test] @@ -18,56 +19,55 @@ pub async fn test_subscription_ws_transport() { } let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); - let mut stream = http::WebSocketStream::new(&schema); + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::new(schema, rx); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "connection_init", - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); assert_eq!( - Some(serde_json::json!({ - "type": "connection_ack", - })), - serde_json::from_str(&stream.next().await.unwrap()).unwrap() + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "connection_ack", + }), ); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "start", - "id": "1", - "payload": { - "query": "subscription { values }" - }, - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); for i in 0..10 { assert_eq!( - Some(serde_json::json!({ - "type": "data", - "id": "1", - "payload": { "data": { "values": i } }, - })), - serde_json::from_str(&stream.next().await.unwrap()).unwrap() + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "values": i } }, + }), ); } assert_eq!( - Some(serde_json::json!({ - "type": "complete", - "id": "1", - })), - serde_json::from_str(&stream.next().await.unwrap()).unwrap() + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "complete", + "id": "1", + }), ); } @@ -93,56 +93,59 @@ pub async fn test_subscription_ws_transport_with_token() { } let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); - let mut stream = http::WebSocketStream::new_with_initializer(&schema, |value| { - #[derive(serde::Deserialize)] - struct Payload { - token: String, - } + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::with_data( + schema, + rx, + Some(|value| { + #[derive(serde::Deserialize)] + struct Payload { + token: String, + } - let payload: Payload = serde_json::from_value(value).unwrap(); - let mut data = Data::default(); - data.insert(Token(payload.token)); - Ok(data) - }); + let payload: Payload = serde_json::from_value(value).unwrap(); + let mut data = Data::default(); + data.insert(Token(payload.token)); + Ok(data) + }), + ); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "connection_init", - "payload": { "token": "123456" } - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "connection_init", + "payload": { "token": "123456" } + })) + .unwrap(), + ) + .await + .unwrap(); assert_eq!( Some(serde_json::json!({ - "type": "connection_ack", + "type": "connection_ack", })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "start", - "id": "1", - "payload": { - "query": "subscription { values }" - }, - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); for i in 0..10 { assert_eq!( Some(serde_json::json!({ - "type": "data", - "id": "1", - "payload": { "data": { "values": i } }, + "type": "data", + "id": "1", + "payload": { "data": { "values": i } }, })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); @@ -150,8 +153,8 @@ pub async fn test_subscription_ws_transport_with_token() { assert_eq!( Some(serde_json::json!({ - "type": "complete", - "id": "1", + "type": "complete", + "id": "1", })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); @@ -189,45 +192,44 @@ pub async fn test_subscription_ws_transport_error() { } let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); - let mut stream = http::WebSocketStream::new(&schema); + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::new(schema, rx); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "connection_init" - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "connection_init" + })) + .unwrap(), + ) + .await + .unwrap(); assert_eq!( Some(serde_json::json!({ - "type": "connection_ack", + "type": "connection_ack", })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "start", - "id": "1", - "payload": { - "query": "subscription { events { value } }" - }, - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { events { value } }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); for i in 0i32..5 { assert_eq!( Some(serde_json::json!({ - "type": "data", - "id": "1", - "payload": { "data": { "events": { "value": i } } }, + "type": "data", + "id": "1", + "payload": { "data": { "events": { "value": i } } }, })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); @@ -235,13 +237,15 @@ pub async fn test_subscription_ws_transport_error() { assert_eq!( Some(serde_json::json!({ - "type": "error", - "id": "1", - "payload": [{ - "message": "TestError", - "locations": [{"line": 1, "column": 25}], - "path": ["events", "value"], - }], + "type": "data", + "id": "1", + "payload": { + "errors": [{ + "message": "TestError", + "locations": [{"line": 1, "column": 25}], + "path": ["events", "value"], + }], + }, })), serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); @@ -259,17 +263,17 @@ pub async fn test_query_over_websocket() { } let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); - let mut stream = http::WebSocketStream::new(&schema); + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::new(schema, rx); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "connection_init", - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); assert_eq!( Some(serde_json::json!({ @@ -278,19 +282,18 @@ pub async fn test_query_over_websocket() { serde_json::from_str(&stream.next().await.unwrap()).unwrap() ); - stream - .send( - serde_json::to_string(&serde_json::json!({ - "type": "start", - "id": "1", - "payload": { - "query": "query { value }" - }, - })) - .unwrap(), - ) - .await - .unwrap(); + tx.send( + serde_json::to_string(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "query { value }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); assert_eq!( Some(serde_json::json!({