diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index 28a53f9a..689eaa1e 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -4,6 +4,7 @@ use futures::channel::mpsc; use futures::task::{AtomicWaker, Context, Poll}; use futures::{Stream, StreamExt}; use slab::Slab; +use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; @@ -44,7 +45,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, - ) -> std::result::Result, Self::Error> + ) -> std::result::Result>, Self::Error> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, @@ -78,6 +79,7 @@ where rx_bytes, handle_request_fut: None, waker: AtomicWaker::new(), + send_buf: VecDeque::new(), }; while let Some(data) = inner_stream.next().await { yield data; @@ -90,7 +92,7 @@ type HandleRequestBoxFut<'a, T> = Pin< Box< dyn Future< Output = ( - std::result::Result, ::Error>, + std::result::Result>, ::Error>, &'a mut T, &'a mut SubscriptionStreams, ), @@ -108,6 +110,7 @@ struct SubscriptionStream<'a, Query, Mutation, Subscription, T: SubscriptionTran rx_bytes: mpsc::UnboundedReceiver, handle_request_fut: Option>, waker: AtomicWaker, + send_buf: VecDeque, } impl<'a, Query, Mutation, Subscription, T> Stream @@ -125,14 +128,21 @@ where loop { // receive bytes + if let Some(bytes) = this.send_buf.pop_front() { + return Poll::Ready(Some(bytes)); + } + if let Some(handle_request_fut) = &mut this.handle_request_fut { match handle_request_fut.as_mut().poll(cx) { Poll::Ready((Ok(bytes), transport, streams)) => { this.transport = Some(transport); this.streams = Some(streams); this.handle_request_fut = None; - if let Some(bytes) = bytes { - return Poll::Ready(Some(bytes)); + if let Some(mut msgs) = bytes { + if !msgs.is_empty() { + this.send_buf.extend(msgs.drain(1..)); + return Poll::Ready(Some(msgs.remove(0))); + } } continue; } diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index 0da774ea..30086a33 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -50,7 +50,7 @@ impl SubscriptionTransport for WebSocketTransport { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, - ) -> std::result::Result, Self::Error> + ) -> std::result::Result>, Self::Error> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, @@ -64,15 +64,13 @@ impl SubscriptionTransport for WebSocketTransport { self.data = Arc::new(init_context_data(payload)?); } } - Ok(Some( - serde_json::to_vec(&OperationMessage { - ty: "connection_ack".to_string(), - id: None, - payload: None, - }) - .unwrap() - .into(), - )) + Ok(Some(vec![serde_json::to_vec(&OperationMessage { + ty: "connection_ack".to_string(), + id: None, + payload: None, + }) + .unwrap() + .into()])) } "start" => { if let (Some(id), Some(payload)) = (msg.id, msg.payload) { @@ -108,10 +106,10 @@ impl SubscriptionTransport for WebSocketTransport { } match builder.execute(schema).await { - Ok(resp) => Ok(Some( + Ok(resp) => Ok(Some(vec![ serde_json::to_vec(&OperationMessage { - ty: "complete".to_string(), - id: Some(id), + ty: "data".to_string(), + id: Some(id.clone()), payload: Some( serde_json::to_value(&GQLResponse(Ok(resp))) .unwrap(), @@ -119,9 +117,16 @@ impl SubscriptionTransport for WebSocketTransport { }) .unwrap() .into(), - )), - Err(err) => Ok(Some( serde_json::to_vec(&OperationMessage { + ty: "complete".to_string(), + id: Some(id), + payload: None, + }) + .unwrap() + .into(), + ])), + Err(err) => { + Ok(Some(vec![serde_json::to_vec(&OperationMessage { ty: "error".to_string(), id: Some(id), payload: Some( @@ -129,21 +134,17 @@ impl SubscriptionTransport for WebSocketTransport { ), }) .unwrap() - .into(), - )), + .into()])) + } } } - Err(err) => Ok(Some( - serde_json::to_vec(&OperationMessage { - ty: "error".to_string(), - id: Some(id), - payload: Some( - serde_json::to_value(GQLError(&err)).unwrap(), - ), - }) - .unwrap() - .into(), - )), + Err(err) => Ok(Some(vec![serde_json::to_vec(&OperationMessage { + ty: "error".to_string(), + id: Some(id), + payload: Some(serde_json::to_value(GQLError(&err)).unwrap()), + }) + .unwrap() + .into()])), } } else { Ok(None) @@ -157,15 +158,13 @@ impl SubscriptionTransport for WebSocketTransport { if let Some(sid) = self.id_to_sid.remove(&id) { self.sid_to_id.remove(&sid); streams.remove(sid); - return Ok(Some( - serde_json::to_vec(&OperationMessage { - ty: "complete".to_string(), - id: Some(id), - payload: None, - }) - .unwrap() - .into(), - )); + return Ok(Some(vec![serde_json::to_vec(&OperationMessage { + ty: "complete".to_string(), + id: Some(id), + payload: None, + }) + .unwrap() + .into()])); } } Ok(None) diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket.rs index 4f9ca0a5..8c0838b7 100644 --- a/tests/subscription_websocket.rs +++ b/tests/subscription_websocket.rs @@ -282,10 +282,18 @@ pub async fn test_query_over_websocket() { assert_eq!( Some(serde_json::json!({ - "type": "complete", + "type": "data", "id": "1", "payload": { "data": { "value": 999 } }, })), serde_json::from_slice(&stream.next().await.unwrap()).unwrap() - ) + ); + + assert_eq!( + Some(serde_json::json!({ + "type": "complete", + "id": "1", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); }