diff --git a/Cargo.toml b/Cargo.toml index 3cb857c7..3b7ab515 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -50,7 +50,7 @@ bson = { version = "1.0.0", optional = true } uuid = { version = "0.8.1", features = ["v4", "serde"] } url = { version = "2.1.1", optional = true } chrono-tz = { version = "0.5.1", optional = true } -smallvec = "1.4.2" + [dev-dependencies] async-std = { version = "1.5.0", features = ["attributes"] } diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index e3d7aeee..54e0c6cd 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -4,7 +4,6 @@ use futures::channel::mpsc; use futures::task::{AtomicWaker, Context, Poll}; use futures::{Stream, StreamExt}; use slab::Slab; -use smallvec::SmallVec; use std::collections::VecDeque; use std::future::Future; use std::pin::Pin; @@ -45,8 +44,9 @@ pub trait ConnectionTransport: Send + Sync + Unpin + 'static { &mut self, schema: &Schema, streams: &mut SubscriptionStreams, - data: Bytes, - ) -> std::result::Result>, Self::Error> + request: Bytes, + send_buf: &mut VecDeque, + ) -> std::result::Result<(), Self::Error> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, @@ -73,6 +73,7 @@ where let mut streams = SubscriptionStreams { streams: Default::default(), }; + let mut send_buf = Default::default(); let mut inner_stream = SubscriptionStream { schema: &schema, transport: Some(&mut transport), @@ -80,7 +81,7 @@ where rx_bytes, handle_request_fut: None, waker: AtomicWaker::new(), - send_buf: VecDeque::new(), + send_buf: Some(&mut send_buf), }; while let Some(data) = inner_stream.next().await { yield data; @@ -93,12 +94,10 @@ type HandleRequestBoxFut<'a, T> = Pin< Box< dyn Future< Output = ( - std::result::Result< - Option>, - ::Error, - >, + std::result::Result<(), ::Error>, &'a mut T, &'a mut SubscriptionStreams, + &'a mut VecDeque, ), > + Send + 'a, @@ -114,7 +113,7 @@ struct SubscriptionStream<'a, Query, Mutation, Subscription, T: ConnectionTransp rx_bytes: mpsc::UnboundedReceiver, handle_request_fut: Option>, waker: AtomicWaker, - send_buf: VecDeque, + send_buf: Option<&'a mut VecDeque>, } impl<'a, Query, Mutation, Subscription, T> Stream @@ -132,25 +131,22 @@ where loop { // receive bytes - if let Some(bytes) = this.send_buf.pop_front() { - return Poll::Ready(Some(bytes)); + if let Some(send_buf) = &mut this.send_buf { + if let Some(bytes) = send_buf.pop_front() { + return Poll::Ready(Some(bytes)); + } } if let Some(handle_request_fut) = &mut this.handle_request_fut { match handle_request_fut.as_mut().poll(cx) { - Poll::Ready((Ok(bytes), transport, streams)) => { + Poll::Ready((Ok(()), transport, streams, send_buf)) => { this.transport = Some(transport); this.streams = Some(streams); + this.send_buf = Some(send_buf); this.handle_request_fut = None; - if let Some(mut msgs) = bytes { - if !msgs.is_empty() { - this.send_buf.extend(msgs.drain(1..)); - return Poll::Ready(Some(msgs.remove(0))); - } - } continue; } - Poll::Ready((Err(_), _, _)) => return Poll::Ready(None), + Poll::Ready((Err(_), _, _, _)) => return Poll::Ready(None), Poll::Pending => {} } } else { @@ -159,9 +155,12 @@ where let transport = this.transport.take().unwrap(); let schema = this.schema; let streams = this.streams.take().unwrap(); + let send_buf = this.send_buf.take().unwrap(); this.handle_request_fut = Some(Box::pin(async move { - let res = transport.handle_request(schema, streams, data).await; - (res, transport, streams) + let res = transport + .handle_request(schema, streams, data, send_buf) + .await; + (res, transport, streams, send_buf) })); this.waker.wake(); continue; diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index 8cba04b1..8bf8a482 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -6,8 +6,7 @@ use crate::{ }; use bytes::Bytes; use serde::{Deserialize, Serialize}; -use smallvec::{smallvec, SmallVec}; -use std::collections::HashMap; +use std::collections::{HashMap, VecDeque}; use std::sync::Arc; #[derive(Serialize, Deserialize)] @@ -43,6 +42,12 @@ impl WebSocketTransport { } } +fn send_message(send_buf: &mut VecDeque, msg: &T) { + if let Ok(data) = serde_json::to_vec(msg) { + send_buf.push_back(data.into()); + } +} + #[async_trait::async_trait] impl ConnectionTransport for WebSocketTransport { type Error = FieldError; @@ -51,14 +56,15 @@ impl ConnectionTransport for WebSocketTransport { &mut self, schema: &Schema, streams: &mut SubscriptionStreams, - data: Bytes, - ) -> std::result::Result>, Self::Error> + request: Bytes, + send_buf: &mut VecDeque, + ) -> std::result::Result<(), Self::Error> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Sync + Send + 'static, { - match serde_json::from_slice::(&data) { + match serde_json::from_slice::(&request) { Ok(msg) => match msg.ty.as_str() { "connection_init" => { if let Some(payload) = msg.payload { @@ -66,13 +72,15 @@ impl ConnectionTransport for WebSocketTransport { self.data = Arc::new(init_context_data(payload)?); } } - Ok(Some(smallvec![serde_json::to_vec(&OperationMessage { - ty: "connection_ack".to_string(), - id: None, - payload: None, - }) - .unwrap() - .into()])) + send_message( + send_buf, + &OperationMessage { + ty: "connection_ack".to_string(), + id: None, + payload: None, + }, + ); + Ok(()) } "start" => { if let (Some(id), Some(payload)) = (msg.id, msg.payload) { @@ -95,7 +103,6 @@ impl ConnectionTransport for WebSocketTransport { let stream_id = streams.add(stream); self.id_to_sid.insert(id.clone(), stream_id); self.sid_to_id.insert(stream_id, id); - Ok(None) } Err(Error::Query { err, .. }) if err == QueryError::NotSupported => @@ -108,72 +115,78 @@ impl ConnectionTransport for WebSocketTransport { } match builder.execute(schema).await { - Ok(resp) => Ok(Some(smallvec![ - serde_json::to_vec(&OperationMessage { - ty: "data".to_string(), - id: Some(id.clone()), - payload: Some( - serde_json::to_value(&GQLResponse(Ok(resp))) + Ok(resp) => { + send_message( + send_buf, + &OperationMessage { + ty: "data".to_string(), + id: Some(id.clone()), + payload: Some( + serde_json::to_value(&GQLResponse(Ok( + resp, + ))) .unwrap(), - ), - }) - .unwrap() - .into(), - serde_json::to_vec(&OperationMessage { - ty: "complete".to_string(), - id: Some(id), - payload: None, - }) - .unwrap() - .into(), - ])), - Err(err) => Ok(Some(smallvec![serde_json::to_vec( - &OperationMessage { - ty: "error".to_string(), - id: Some(id), - payload: Some( - serde_json::to_value(GQLError(&err)).unwrap(), - ), - } - ) - .unwrap() - .into()])), + ), + }, + ); + + send_message( + send_buf, + &OperationMessage { + ty: "complete".to_string(), + id: Some(id), + payload: None, + }, + ); + } + Err(err) => { + send_message( + send_buf, + &OperationMessage { + ty: "error".to_string(), + id: Some(id), + payload: Some( + serde_json::to_value(GQLError(&err)) + .unwrap(), + ), + }, + ); + } } } Err(err) => { - Ok(Some(smallvec![serde_json::to_vec(&OperationMessage { - ty: "error".to_string(), - id: Some(id), - payload: Some( - serde_json::to_value(GQLError(&err)).unwrap() - ), - }) - .unwrap() - .into()])) + send_message( + send_buf, + &OperationMessage { + ty: "error".to_string(), + id: Some(id), + payload: Some( + serde_json::to_value(GQLError(&err)).unwrap(), + ), + }, + ); } } - } else { - Ok(None) } - } else { - Ok(None) } + Ok(()) } "stop" => { if let Some(id) = msg.id { if let Some(sid) = self.id_to_sid.remove(&id) { self.sid_to_id.remove(&sid); streams.remove(sid); - return Ok(Some(smallvec![serde_json::to_vec(&OperationMessage { - ty: "complete".to_string(), - id: Some(id), - payload: None, - }) - .unwrap() - .into()])); + send_message( + send_buf, + &OperationMessage { + ty: "complete".to_string(), + id: Some(id), + payload: None, + }, + ); } } - Ok(None) + Ok(()) } "connection_terminate" => Err("connection_terminate".into()), _ => Err("Unknown op".into()),