From b85815d84a870394c6aa93c7e4eb26a96082bb8c Mon Sep 17 00:00:00 2001 From: Sunli Date: Tue, 9 Nov 2021 17:01:51 +0800 Subject: [PATCH] Rework websocket --- src/context.rs | 4 + src/http/websocket.rs | 137 ++++++++---------- tests/subscription_websocket_graphql_ws.rs | 72 +++++++-- ...on_websocket_subscriptions_transport_ws.rs | 10 +- 4 files changed, 127 insertions(+), 96 deletions(-) diff --git a/src/context.rs b/src/context.rs index be7bf5de..682a78d4 100644 --- a/src/context.rs +++ b/src/context.rs @@ -41,6 +41,10 @@ impl Data { pub fn insert(&mut self, data: D) { self.0.insert(TypeId::of::(), Box::new(data)); } + + pub(crate) fn merge(&mut self, other: Data) { + self.0.extend(other.0); + } } impl Debug for Data { diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 984edb5d..af7a7461 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -8,10 +8,7 @@ use std::task::{Context, Poll}; use futures_util::stream::Stream; use futures_util::FutureExt; -use futures_util::{ - future::{BoxFuture, Ready}, - StreamExt, -}; +use futures_util::{future::BoxFuture, StreamExt}; use pin_project_lite::pin_project; use serde::{Deserialize, Serialize}; @@ -60,14 +57,18 @@ impl WsMessage { } } +type BoxInitializer = + Box<(dyn FnOnce(serde_json::Value) -> BoxFuture<'static, Result> + Send + 'static)>; + pin_project! { /// A GraphQL connection over websocket. /// /// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). /// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). - pub struct WebSocket { - data_initializer: Option, + pub struct WebSocket { + data_initializer: Option, init_fut: Option>>, + connection_data: Option, data: Option>, schema: Schema, streams: HashMap + Send>>>, @@ -80,80 +81,20 @@ pin_project! { type MessageMapStream = futures_util::stream::Map::Item) -> serde_json::Result>; -impl - WebSocket< - MessageMapStream, - fn(serde_json::Value) -> Ready>, - Query, - Mutation, - Subscription, - > -where - S: Stream, - S::Item: AsRef<[u8]>, -{ - /// Create a new websocket. - #[must_use] - pub fn new( - schema: Schema, - stream: S, - protocol: Protocols, - ) -> Self { - Self::with_data( - schema, - stream, - |_| futures_util::future::ready(Ok(Default::default())), - protocol, - ) - } -} - -impl - WebSocket, F, Query, Mutation, Subscription> -where - S: Stream, - S::Item: AsRef<[u8]>, -{ - /// Create a new websocket with a data initialization function. - /// - /// This function, if present, will be called with the data sent by the client in the - /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init). - /// From that point on the returned data will be accessible to all requests. - #[must_use] - pub fn with_data( - schema: Schema, - stream: S, - data_initializer: F, - protocol: Protocols, - ) -> Self { - // let stream = stream.map(|message| serde_json::from_slice(message.as_ref())); - let stream = stream - .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result); - - Self::with_message_stream(schema, stream, data_initializer, protocol) - } -} - -impl WebSocket +impl WebSocket where S: Stream>, { - /// Create a new websocket with a data initialization function from a stream of ClientMessage - /// structs. - /// - /// This function, if present, will be called with the data sent by the client in the - /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init). - /// From that point on the returned data will be accessible to all requests. - #[must_use] - pub fn with_message_stream( + /// Create a new websocket from [`ClientMessage`] stream. + pub fn from_message_stream( schema: Schema, stream: S, - data_initializer: F, protocol: Protocols, ) -> Self { WebSocket { - data_initializer: Some(data_initializer), + data_initializer: Some(Box::new(|_| Box::pin(async move { Ok(Data::default()) }))), init_fut: None, + connection_data: None, data: None, schema, streams: HashMap::new(), @@ -161,14 +102,56 @@ where protocol, } } + + /// Specify a connection initializer. + /// + /// This function, if present, will be called with the data sent by the client in the + /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init). + /// From that point on the returned data will be accessible to all requests. + pub fn with_initializer(mut self, initializer: F) -> Self + where + F: FnOnce(serde_json::Value) -> R + Send + 'static, + R: Future> + Send + 'static, + { + self.data_initializer = Some(Box::new(move |value| Box::pin(initializer(value)))); + self + } + + /// Specify a connection data. + /// + /// This data usually comes from HTTP requests. + /// When the `GQL_CONNECTION_INIT` message is received, this data will be merged with the data + /// returned by the closure specified by `with_initializer` into the final subscription context data. + pub fn connection_data(mut self, data: Data) -> Self { + self.connection_data = Some(data); + self + } } -impl Stream - for WebSocket +impl WebSocket, Query, Mutation, Subscription> +where + S: Stream, + S::Item: AsRef<[u8]>, +{ + /// Create a new websocket from bytes stream. + pub fn new( + schema: Schema, + stream: S, + protocol: Protocols, + ) -> Self + where + S: Stream, + S::Item: AsRef<[u8]>, + { + let stream = stream + .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result); + WebSocket::from_message_stream(schema, stream, protocol) + } +} + +impl Stream for WebSocket where S: Stream>, - F: FnOnce(serde_json::Value) -> R + Send + 'static, - R: Future> + Send + 'static, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, @@ -266,7 +249,9 @@ where *this.init_fut = None; return match res { Ok(data) => { - *this.data = Some(Arc::new(data)); + let mut ctx_data = this.connection_data.take().unwrap_or_default(); + ctx_data.merge(data); + *this.data = Some(Arc::new(ctx_data)); Poll::Ready(Some(WsMessage::Text( serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), ))) diff --git a/tests/subscription_websocket_graphql_ws.rs b/tests/subscription_websocket_graphql_ws.rs index 6164d798..f72ba7cb 100644 --- a/tests/subscription_websocket_graphql_ws.rs +++ b/tests/subscription_websocket_graphql_ws.rs @@ -111,10 +111,8 @@ pub async fn test_subscription_ws_transport_with_token() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (mut tx, rx) = mpsc::unbounded(); - let mut stream = http::WebSocket::with_data( - schema, - rx, - |value| async { + let mut stream = http::WebSocket::new(schema.clone(), rx, WebSocketProtocols::GraphQLWS) + .with_initializer(|value| async { #[derive(serde::Deserialize)] struct Payload { token: String, @@ -124,9 +122,7 @@ pub async fn test_subscription_ws_transport_with_token() { let mut data = Data::default(); data.insert(Token(payload.token)); Ok(data) - }, - WebSocketProtocols::GraphQLWS, - ); + }); tx.send( serde_json::to_string(&value!({ @@ -176,6 +172,60 @@ pub async fn test_subscription_ws_transport_with_token() { })), serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() ); + + let (mut tx, rx) = mpsc::unbounded(); + let mut data = Data::default(); + data.insert(Token("123456".to_string())); + let mut stream = + http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS).connection_data(data); + + tx.send( + serde_json::to_string(&value!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(value!({ + "type": "connection_ack", + })), + serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(value!({ + "type": "next", + "id": "1", + "payload": { "data": { "values": i } }, + })), + serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() + ); + } + + assert_eq!( + Some(value!({ + "type": "complete", + "id": "1", + })), + serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() + ); } #[tokio::test] @@ -296,12 +346,8 @@ pub async fn test_subscription_init_error() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (mut tx, rx) = mpsc::unbounded(); - let mut stream = http::WebSocket::with_data( - schema, - rx, - |_| async move { Err("Error!".into()) }, - WebSocketProtocols::GraphQLWS, - ); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS) + .with_initializer(|_| async move { Err("Error!".into()) }); tx.send( serde_json::to_string(&value!({ diff --git a/tests/subscription_websocket_subscriptions_transport_ws.rs b/tests/subscription_websocket_subscriptions_transport_ws.rs index 25a3258f..56e110c3 100644 --- a/tests/subscription_websocket_subscriptions_transport_ws.rs +++ b/tests/subscription_websocket_subscriptions_transport_ws.rs @@ -107,10 +107,8 @@ pub async fn test_subscription_ws_transport_with_token() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (mut tx, rx) = mpsc::unbounded(); - let mut stream = http::WebSocket::with_data( - schema, - rx, - |value| async { + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS) + .with_initializer(|value| async { #[derive(serde::Deserialize)] struct Payload { token: String, @@ -120,9 +118,7 @@ pub async fn test_subscription_ws_transport_with_token() { let mut data = Data::default(); data.insert(Token(payload.token)); Ok(data) - }, - WebSocketProtocols::SubscriptionsTransportWS, - ); + }); tx.send( serde_json::to_string(&value!({