From 998e9b7b85e5cb156f17e897f75da9f721263faf Mon Sep 17 00:00:00 2001 From: sunli Date: Thu, 23 Apr 2020 14:52:22 +0800 Subject: [PATCH] Websocket transport creates context data from the connect_init.payload property --- async-graphql-actix-web/src/subscription.rs | 19 +-- async-graphql-warp/src/lib.rs | 13 +- src/context.rs | 7 +- src/schema.rs | 3 +- src/subscription/connection.rs | 22 ++- src/subscription/ws_transport.rs | 44 +++-- tests/subscription.rs | 170 ++++++++++---------- 7 files changed, 139 insertions(+), 139 deletions(-) diff --git a/async-graphql-actix-web/src/subscription.rs b/async-graphql-actix-web/src/subscription.rs index 489cc4db..a7334f32 100644 --- a/async-graphql-actix-web/src/subscription.rs +++ b/async-graphql-actix-web/src/subscription.rs @@ -2,12 +2,10 @@ use actix::{ Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, }; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::{Data, ObjectType, Schema, SubscriptionType, WebSocketTransport}; +use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; use bytes::Bytes; use futures::channel::mpsc; use futures::SinkExt; -use std::any::Any; -use std::sync::Arc; use std::time::{Duration, Instant}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -18,7 +16,6 @@ pub struct WSSubscription { schema: Schema, hb: Instant, sink: Option>, - data: Data, } impl WSSubscription @@ -33,18 +30,9 @@ where schema: schema.clone(), hb: Instant::now(), sink: None, - data: Default::default(), } } - /// Add a context data that can be accessed in the `Context`, you access it with `Context::data`. - /// - /// **This data is only valid for this subscription** - pub fn data(mut self, data: D) -> Self { - self.data.insert(data); - self - } - fn hb(&self, ctx: &mut WebsocketContext) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { @@ -66,10 +54,7 @@ where fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); let schema = self.schema.clone(); - let (sink, stream) = schema.subscription_connection( - WebSocketTransport::default(), - Some(Arc::new(std::mem::take(&mut self.data))), - ); + let (sink, stream) = schema.subscription_connection(WebSocketTransport::default()); ctx.add_stream(stream); self.sink = Some(sink); } diff --git a/async-graphql-warp/src/lib.rs b/async-graphql-warp/src/lib.rs index 7eba8b7f..be0c29ff 100644 --- a/async-graphql-warp/src/lib.rs +++ b/async-graphql-warp/src/lib.rs @@ -6,8 +6,8 @@ use async_graphql::http::StreamBody; use async_graphql::{ - Data, IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, - SubscriptionType, WebSocketTransport, + IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, SubscriptionType, + WebSocketTransport, }; use bytes::Bytes; use futures::select; @@ -145,30 +145,27 @@ where /// #[tokio::main] /// async fn main() { /// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); -/// let filter = async_graphql_warp::graphql_subscription(schema, None); +/// let filter = async_graphql_warp::graphql_subscription(schema); /// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await; /// } /// ``` pub fn graphql_subscription( schema: Schema, - ctx_data: Option, ) -> BoxedFilter<(impl Reply,)> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - let ctx_data = ctx_data.map(Arc::new); warp::any() .and(warp::ws()) .and(warp::any().map(move || schema.clone())) - .and(warp::any().map(move || ctx_data.clone())) .map( - |ws: warp::ws::Ws, schema: Schema, ctx_data: Option>| { + |ws: warp::ws::Ws, schema: Schema| { ws.on_upgrade(move |websocket| { let (mut tx, rx) = websocket.split(); let (mut stx, srx) = - schema.subscription_connection(WebSocketTransport::default(),ctx_data.clone()); + schema.subscription_connection(WebSocketTransport::default()); let mut rx = rx.fuse(); let mut srx = srx.fuse(); diff --git a/src/context.rs b/src/context.rs index 10101133..bb5a5c16 100644 --- a/src/context.rs +++ b/src/context.rs @@ -329,11 +329,16 @@ impl<'a, T> ContextBase<'a, T> { /// Gets the global data defined in the `Context` or `Schema`. pub fn data(&self) -> &D { + self.data_opt::() + .expect("The specified data type does not exist.") + } + + /// Gets the global data defined in the `Context` or `Schema`, returns `None` if the specified type data does not exist. + pub fn data_opt(&self) -> Option<&D> { self.ctx_data .and_then(|ctx_data| ctx_data.0.get(&TypeId::of::())) .or_else(|| self.data.0.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) - .expect("The specified data type does not exist.") } fn var_value(&self, name: &str, pos: Pos) -> Result { diff --git a/src/schema.rs b/src/schema.rs index c26b826c..d6314534 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -289,11 +289,10 @@ where pub fn subscription_connection( &self, transport: T, - ctx_data: Option>, ) -> ( mpsc::Sender, SubscriptionStream, ) { - create_connection(self.clone(), transport, ctx_data.unwrap_or_default()) + create_connection(self.clone(), transport) } } diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index a27bb54b..6cb9d7f0 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -1,4 +1,3 @@ -use crate::context::Data; use crate::{ObjectType, Schema, SubscriptionType}; use bytes::Bytes; use futures::channel::mpsc; @@ -7,7 +6,6 @@ use futures::Stream; use slab::Slab; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; /// Use to hold all subscription stream for the `SubscriptionConnection` pub struct SubscriptionStreams { @@ -44,7 +42,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, - ctx_data: Arc, ) -> std::result::Result, Self::Error> where Query: ObjectType + Sync + Send + 'static, @@ -58,7 +55,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { pub fn create_connection( schema: Schema, transport: T, - ctx_data: Arc, ) -> ( mpsc::Sender, SubscriptionStream, @@ -73,7 +69,6 @@ where tx_bytes, SubscriptionStream { schema, - ctx_data, transport, streams: SubscriptionStreams { streams: Default::default(), @@ -84,19 +79,22 @@ where ) } +type HandleRequestBoxFut = Pin< + Box< + dyn Future, ::Error>> + + Send + + 'static, + >, +>; + #[allow(missing_docs)] #[allow(clippy::type_complexity)] pub struct SubscriptionStream { schema: Schema, - ctx_data: Arc, transport: T, streams: SubscriptionStreams, rx_bytes: mpsc::Receiver, - handle_request_fut: Option< - Pin< - Box, T::Error>> + Send + 'static>, - >, - >, + handle_request_fut: Option>, } impl Stream @@ -133,13 +131,11 @@ where let transport = &mut this.transport as *mut T; let schema = &this.schema as *const Schema; let streams = &mut this.streams as *mut SubscriptionStreams; - let ctx_data = this.ctx_data.clone(); unsafe { this.handle_request_fut = Some(Box::pin((*transport).handle_request( &*schema, &mut *streams, data, - ctx_data.clone(), ))); } continue; diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index 80f7ac68..66af1302 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -25,6 +25,20 @@ struct OperationMessage { pub struct WebSocketTransport { id_to_sid: HashMap, sid_to_id: HashMap, + data: Arc, + init_with_payload: Option Data + Send + Sync>>, +} + +impl WebSocketTransport { + /// Creates a websocket transport and sets the function that converts the `payload` of the `connect_init` message to `Data`. + pub fn new Data + Send + Sync + 'static>( + init_with_payload: F, + ) -> Self { + WebSocketTransport { + init_with_payload: Some(Box::new(init_with_payload)), + ..WebSocketTransport::default() + } + } } #[async_trait::async_trait] @@ -36,7 +50,6 @@ impl SubscriptionTransport for WebSocketTransport { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, - ctx_data: Arc, ) -> std::result::Result, Self::Error> where Query: ObjectType + Sync + Send + 'static, @@ -45,15 +58,22 @@ impl SubscriptionTransport for WebSocketTransport { { match serde_json::from_slice::(&data) { Ok(msg) => match msg.ty.as_str() { - "connection_init" => Ok(Some( - serde_json::to_vec(&OperationMessage { - ty: "connection_ack".to_string(), - id: None, - payload: None, - }) - .unwrap() - .into(), - )), + "connection_init" => { + if let Some(payload) = msg.payload { + if let Some(init_with_payload) = &self.init_with_payload { + self.data = Arc::new(init_with_payload(payload)); + } + } + Ok(Some( + serde_json::to_vec(&OperationMessage { + ty: "connection_ack".to_string(), + id: None, + payload: None, + }) + .unwrap() + .into(), + )) + } "start" => { if let (Some(id), Some(payload)) = (msg.id, msg.payload) { if let Ok(request) = serde_json::from_value::(payload) { @@ -67,7 +87,7 @@ impl SubscriptionTransport for WebSocketTransport { &request.query, request.operation_name.as_deref(), variables, - Some(ctx_data), + Some(self.data.clone()), ) .await { @@ -106,7 +126,7 @@ impl SubscriptionTransport for WebSocketTransport { Ok(None) } "connection_terminate" => Err("connection_terminate".to_string()), - _ => Err("unknown op".to_string()), + _ => Err("Unknown op".to_string()), }, Err(err) => Err(err.to_string()), } diff --git a/tests/subscription.rs b/tests/subscription.rs index 372cb51d..5cfb8b04 100644 --- a/tests/subscription.rs +++ b/tests/subscription.rs @@ -270,100 +270,98 @@ pub async fn test_subscription_ws_transport() { #[Subscription] impl SubscriptionRoot { #[field] - async fn values(&self, ctx: &Context<'_>) -> impl Stream { - let step = *ctx.data::(); - futures::stream::iter((0..10).map(move |n| n * step)) + async fn values(&self) -> impl Stream { + futures::stream::iter(0..10) } } let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); - let (mut sink, mut stream) = schema.subscription_connection( - WebSocketTransport::default(), - Some(Arc::new({ - let mut data = Data::default(); - data.insert(5); - data - })), - ); - - sink.send( - serde_json::to_vec(&serde_json::json!({ - "type": "connection_init", - })) - .unwrap() - .into(), - ) - .await - .unwrap(); - - assert_eq!( - Some(serde_json::json!({ - "type": "connection_ack", - })), - serde_json::from_slice(&stream.next().await.unwrap()).unwrap() - ); - - sink.send( - serde_json::to_vec(&serde_json::json!({ - "type": "start", - "id": "1", - "payload": { - "query": "subscription { values }" - }, - })) - .unwrap() - .into(), - ) - .await - .unwrap(); - - for i in 0..10 { - assert_eq!( - Some(serde_json::json!({ - "type": "data", - "id": "1", - "payload": { "data": { "values": i * 5 } }, - })), - serde_json::from_slice(&stream.next().await.unwrap()).unwrap() - ); - } -} - -#[async_std::test] -pub async fn test_subscription_ws_transport_with_token() { - struct Token(String); - - struct QueryRoot; - - #[Object] - impl QueryRoot {} - - struct SubscriptionRoot; - - #[Subscription] - impl SubscriptionRoot { - #[field] - async fn values(&self, ctx: &Context<'_>) -> FieldResult> { - if ctx.data::().0 != "123456" { - return Err("forbidden".into()); - } - Ok(futures::stream::iter(0..10)) - } - } - - let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); - let (mut sink, mut stream) = schema.subscription_connection( - WebSocketTransport::default(), - Some(Arc::new({ - let mut data = Data::default(); - data.insert(Token("123456".to_string())); - data - })), - ); + let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default()); sink.send( serde_json::to_vec(&serde_json::json!({ "type": "connection_init", + "payload": { "token": "123456" } + })) + .unwrap() + .into(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap() + .into(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "values": i } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + } +} + +#[async_std::test] +pub async fn test_subscription_ws_transport_with_token() { + struct Token(String); + + struct QueryRoot; + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + #[field] + async fn values(&self, ctx: &Context<'_>) -> FieldResult> { + if ctx.data::().0 != "123456" { + return Err("forbidden".into()); + } + Ok(futures::stream::iter(0..10)) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + + let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| { + #[derive(serde_derive::Deserialize)] + struct Payload { + token: String, + } + + let payload: Payload = serde_json::from_value(value).unwrap(); + let mut data = Data::default(); + data.insert(Token(payload.token)); + data + })); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init", + "payload": { "token": "123456" } })) .unwrap() .into(),