diff --git a/Cargo.toml b/Cargo.toml index ad972752..a52686b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql" -version = "1.9.10" +version = "1.9.11" authors = ["sunli "] edition = "2018" description = "The GraphQL server library implemented by rust" @@ -18,7 +18,7 @@ default = ["bson", "uuid", "url", "chrono-tz", "validators"] validators = ["regex"] [dependencies] -async-graphql-derive = { path = "async-graphql-derive", version = "1.9.10" } +async-graphql-derive = { path = "async-graphql-derive", version = "1.9.11" } graphql-parser = "=0.2.3" anyhow = "1.0.26" thiserror = "1.0.11" diff --git a/async-graphql-actix-web/Cargo.toml b/async-graphql-actix-web/Cargo.toml index 8c7dedcd..9b6086e2 100644 --- a/async-graphql-actix-web/Cargo.toml +++ b/async-graphql-actix-web/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-actix-web" -version = "1.0.10" +version = "1.1.0" authors = ["sunli "] edition = "2018" description = "async-graphql for actix-web" @@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] [dependencies] -async-graphql = { path = "..", version = "1.9.10" } +async-graphql = { path = "..", version = "1.9.11" } actix-web = "2.0.0" actix-web-actors = "2.0.0" actix = "0.9.0" diff --git a/async-graphql-actix-web/src/subscription.rs b/async-graphql-actix-web/src/subscription.rs index a7334f32..489cc4db 100644 --- a/async-graphql-actix-web/src/subscription.rs +++ b/async-graphql-actix-web/src/subscription.rs @@ -2,10 +2,12 @@ use actix::{ Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, }; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; +use async_graphql::{Data, ObjectType, Schema, SubscriptionType, WebSocketTransport}; use bytes::Bytes; use futures::channel::mpsc; use futures::SinkExt; +use std::any::Any; +use std::sync::Arc; use std::time::{Duration, Instant}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); @@ -16,6 +18,7 @@ pub struct WSSubscription { schema: Schema, hb: Instant, sink: Option>, + data: Data, } impl WSSubscription @@ -30,9 +33,18 @@ where schema: schema.clone(), hb: Instant::now(), sink: None, + data: Default::default(), } } + /// Add a context data that can be accessed in the `Context`, you access it with `Context::data`. + /// + /// **This data is only valid for this subscription** + pub fn data(mut self, data: D) -> Self { + self.data.insert(data); + self + } + fn hb(&self, ctx: &mut WebsocketContext) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { @@ -54,7 +66,10 @@ where fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); let schema = self.schema.clone(); - let (sink, stream) = schema.subscription_connection(WebSocketTransport::default()); + let (sink, stream) = schema.subscription_connection( + WebSocketTransport::default(), + Some(Arc::new(std::mem::take(&mut self.data))), + ); ctx.add_stream(stream); self.sink = Some(sink); } diff --git a/async-graphql-derive/Cargo.toml b/async-graphql-derive/Cargo.toml index 2ac4dfda..b8effe49 100644 --- a/async-graphql-derive/Cargo.toml +++ b/async-graphql-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-derive" -version = "1.9.10" +version = "1.9.11" authors = ["sunli "] edition = "2018" description = "Macros for async-graphql" diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index e1cf9873..c9550c1f 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -3,7 +3,9 @@ use crate::utils::{build_value_repr, check_reserved_name, get_crate_name}; use inflector::Inflector; use proc_macro::TokenStream; use quote::quote; -use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait}; +use syn::{ + Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait, TypeReference, +}; pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result { let crate_name = get_crate_name(object_args.internal); @@ -61,20 +63,25 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< )); } - match &method.sig.inputs[0] { - FnArg::Receiver(_) => {} - _ => { - return Err(Error::new_spanned( - &method.sig.inputs[0], - "The first argument must be self receiver", - )); - } - } - + let mut arg_ctx = false; let mut args = Vec::new(); - for arg in method.sig.inputs.iter_mut().skip(1) { - if let FnArg::Typed(pat) = arg { + for (idx, arg) in method.sig.inputs.iter_mut().enumerate() { + if let FnArg::Receiver(receiver) = arg { + if idx != 0 { + return Err(Error::new_spanned( + receiver, + "The self receiver must be the first parameter.", + )); + } + } else if let FnArg::Typed(pat) = arg { + if idx == 0 { + return Err(Error::new_spanned( + pat, + "The self receiver must be the first parameter.", + )); + } + match (&*pat.pat, &*pat.ty) { (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { args.push(( @@ -84,6 +91,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< )); pat.attrs.clear(); } + (_, Type::Reference(TypeReference { elem, .. })) => { + if let Type::Path(path) = elem.as_ref() { + if idx != 1 + || path.path.segments.last().unwrap().ident != "Context" + { + return Err(Error::new_spanned( + arg, + "The Context must be the second argument.", + )); + } + arg_ctx = true; + } + } _ => { return Err(Error::new_spanned(arg, "Incorrect argument type")); } @@ -182,6 +202,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); }); + let ctx_param = if arg_ctx { + quote! { &ctx, } + } else { + quote! {} + }; + create_stream.push(quote! { if ctx.name.as_str() == #field_name { let field_name = ctx.result_name().to_string(); @@ -190,21 +216,20 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< let schema = schema.clone(); let pos = ctx.position; let environment = environment.clone(); - let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params),*).await.fuse(), move |msg| { + let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#ctx_param #(#use_params),*).await.fuse(), move |msg| { let environment = environment.clone(); let field_selection_set = field_selection_set.clone(); let schema = schema.clone(); async move { let resolve_id = std::sync::atomic::AtomicUsize::default(); let ctx_selection_set = environment.create_context( - &*field_selection_set, + &schema, Some(#crate_name::QueryPathNode { parent: None, segment: #crate_name::QueryPathSegment::Name("time"), }), + &*field_selection_set, &resolve_id, - schema.registry(), - schema.data(), ); #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await } diff --git a/async-graphql-warp/Cargo.toml b/async-graphql-warp/Cargo.toml index a4996b0c..70cdb81d 100644 --- a/async-graphql-warp/Cargo.toml +++ b/async-graphql-warp/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-warp" -version = "1.0.11" +version = "1.1.0" authors = ["sunli "] edition = "2018" description = "async-graphql for warp" @@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] [dependencies] -async-graphql = { path = "..", version = "1.9.10" } +async-graphql = { path = "..", version = "1.9.11" } warp = "0.2.2" futures = "0.3.0" bytes = "0.5.4" diff --git a/async-graphql-warp/src/lib.rs b/async-graphql-warp/src/lib.rs index be0c29ff..7eba8b7f 100644 --- a/async-graphql-warp/src/lib.rs +++ b/async-graphql-warp/src/lib.rs @@ -6,8 +6,8 @@ use async_graphql::http::StreamBody; use async_graphql::{ - IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, SubscriptionType, - WebSocketTransport, + Data, IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, + SubscriptionType, WebSocketTransport, }; use bytes::Bytes; use futures::select; @@ -145,27 +145,30 @@ where /// #[tokio::main] /// async fn main() { /// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); -/// let filter = async_graphql_warp::graphql_subscription(schema); +/// let filter = async_graphql_warp::graphql_subscription(schema, None); /// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await; /// } /// ``` pub fn graphql_subscription( schema: Schema, + ctx_data: Option, ) -> BoxedFilter<(impl Reply,)> where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { + let ctx_data = ctx_data.map(Arc::new); warp::any() .and(warp::ws()) .and(warp::any().map(move || schema.clone())) + .and(warp::any().map(move || ctx_data.clone())) .map( - |ws: warp::ws::Ws, schema: Schema| { + |ws: warp::ws::Ws, schema: Schema, ctx_data: Option>| { ws.on_upgrade(move |websocket| { let (mut tx, rx) = websocket.split(); let (mut stx, srx) = - schema.subscription_connection(WebSocketTransport::default()); + schema.subscription_connection(WebSocketTransport::default(),ctx_data.clone()); let mut rx = rx.fuse(); let mut srx = srx.fuse(); diff --git a/src/context.rs b/src/context.rs index 45e4e10e..10101133 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,6 +1,6 @@ use crate::extensions::BoxExtension; use crate::registry::Registry; -use crate::{InputValueType, Pos, QueryError, Result, Type}; +use crate::{InputValueType, Pos, QueryError, Result, Schema, Type}; use graphql_parser::query::{ Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition, }; @@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap}; use std::ops::{Deref, DerefMut}; use std::path::Path; use std::sync::atomic::AtomicUsize; +use std::sync::Arc; /// Variables of query #[derive(Debug, Clone)] @@ -128,9 +129,11 @@ fn json_value_to_gql_value(value: serde_json::Value) -> Value { } #[derive(Default)] +/// Schema/Context data pub struct Data(BTreeMap>); impl Data { + #[allow(missing_docs)] pub fn insert(&mut self, data: D) { self.0.insert(TypeId::of::(), Box::new(data)); } @@ -247,17 +250,17 @@ pub struct Environment { pub variables: Variables, pub variable_definitions: Vec, pub fragments: HashMap, + pub ctx_data: Arc, } impl Environment { #[doc(hidden)] - pub fn create_context<'a, T>( + pub fn create_context<'a, T, Query, Mutation, Subscription>( &'a self, - item: T, + schema: &'a Schema, path_node: Option>, + item: T, resolve_id: &'a AtomicUsize, - registry: &'a Registry, - data: &'a Data, ) -> ContextBase<'a, T> { ContextBase { path_node, @@ -266,24 +269,15 @@ impl Environment { item, variables: &self.variables, variable_definitions: &self.variable_definitions, - registry, - data, - ctx_data: None, + registry: &schema.0.registry, + data: &schema.0.data, + ctx_data: Some(&self.ctx_data), fragments: &self.fragments, } } } impl<'a, T> ContextBase<'a, T> { - #[doc(hidden)] - pub fn create_environment(&self) -> Environment { - Environment { - variables: self.variables.clone(), - variable_definitions: self.variable_definitions.to_vec(), - fragments: self.fragments.clone(), - } - } - #[doc(hidden)] pub fn get_resolve_id(&self) -> usize { self.resolve_id diff --git a/src/lib.rs b/src/lib.rs index 81464666..88e221cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -106,7 +106,9 @@ pub use serde_json; pub mod http; pub use base::{Scalar, Type}; -pub use context::{Context, ContextBase, Environment, QueryPathNode, QueryPathSegment, Variables}; +pub use context::{ + Context, ContextBase, Data, Environment, QueryPathNode, QueryPathSegment, Variables, +}; pub use error::{ Error, ErrorExtensions, FieldError, FieldResult, ParseRequestError, QueryError, ResultExt, }; diff --git a/src/schema.rs b/src/schema.rs index 27e0f2ec..c26b826c 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -7,8 +7,8 @@ use crate::subscription::{create_connection, create_subscription_stream, Subscri use crate::types::QueryRoot; use crate::validation::{check_rules, ValidationMode}; use crate::{ - ContextSelectionSet, Error, ObjectType, Pos, QueryError, QueryResponse, Result, - SubscriptionStream, SubscriptionType, Type, Variables, + Environment, Error, ObjectType, Pos, QueryError, QueryResponse, Result, SubscriptionStream, + SubscriptionType, Type, Variables, }; use bytes::Bytes; use futures::channel::mpsc; @@ -240,6 +240,7 @@ where source: &str, operation_name: Option<&str>, variables: Variables, + ctx_data: Option>, ) -> Result + Send> { let document = parse_query(source).map_err(Into::::into)?; check_rules(&self.0.registry, &document, self.0.validation_mode)?; @@ -272,22 +273,15 @@ where })?; let resolve_id = AtomicUsize::default(); - let ctx = ContextSelectionSet { - path_node: None, - extensions: &[], - item: &subscription.selection_set, - resolve_id: &resolve_id, - variables: &variables, - variable_definitions: &subscription.variable_definitions, - registry: &self.0.registry, - data: &Default::default(), - ctx_data: None, - fragments: &fragments, - }; - + let environment = Arc::new(Environment { + variables, + variable_definitions: subscription.variable_definitions, + fragments, + ctx_data: ctx_data.unwrap_or_default(), + }); + let ctx = environment.create_context(self, None, &subscription.selection_set, &resolve_id); let mut streams = Vec::new(); - create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams) - .await?; + create_subscription_stream(self, environment.clone(), &ctx, &mut streams).await?; Ok(futures::stream::select_all(streams)) } @@ -295,10 +289,11 @@ where pub fn subscription_connection( &self, transport: T, + ctx_data: Option>, ) -> ( mpsc::Sender, SubscriptionStream, ) { - create_connection(self.clone(), transport) + create_connection(self.clone(), transport, ctx_data.unwrap_or_default()) } } diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index 738d9c93..a27bb54b 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -1,3 +1,4 @@ +use crate::context::Data; use crate::{ObjectType, Schema, SubscriptionType}; use bytes::Bytes; use futures::channel::mpsc; @@ -6,6 +7,7 @@ use futures::Stream; use slab::Slab; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; /// Use to hold all subscription stream for the `SubscriptionConnection` pub struct SubscriptionStreams { @@ -42,6 +44,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, + ctx_data: Arc, ) -> std::result::Result, Self::Error> where Query: ObjectType + Sync + Send + 'static, @@ -55,6 +58,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { pub fn create_connection( schema: Schema, transport: T, + ctx_data: Arc, ) -> ( mpsc::Sender, SubscriptionStream, @@ -69,6 +73,7 @@ where tx_bytes, SubscriptionStream { schema, + ctx_data, transport, streams: SubscriptionStreams { streams: Default::default(), @@ -83,6 +88,7 @@ where #[allow(clippy::type_complexity)] pub struct SubscriptionStream { schema: Schema, + ctx_data: Arc, transport: T, streams: SubscriptionStreams, rx_bytes: mpsc::Receiver, @@ -127,11 +133,13 @@ where let transport = &mut this.transport as *mut T; let schema = &this.schema as *const Schema; let streams = &mut this.streams as *mut SubscriptionStreams; + let ctx_data = this.ctx_data.clone(); unsafe { this.handle_request_fut = Some(Box::pin((*transport).handle_request( &*schema, &mut *streams, data, + ctx_data.clone(), ))); } continue; diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index 4dea3f7e..80f7ac68 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -1,3 +1,4 @@ +use crate::context::Data; use crate::http::{GQLError, GQLRequest, GQLResponse}; use crate::{ ObjectType, QueryResponse, Schema, SubscriptionStreams, SubscriptionTransport, @@ -5,12 +6,17 @@ use crate::{ }; use bytes::Bytes; use std::collections::HashMap; +use std::sync::Arc; #[derive(Serialize, Deserialize)] struct OperationMessage { #[serde(rename = "type")] ty: String, + + #[serde(skip_serializing_if = "Option::is_none")] id: Option, + + #[serde(skip_serializing_if = "Option::is_none")] payload: Option, } @@ -30,6 +36,7 @@ impl SubscriptionTransport for WebSocketTransport { schema: &Schema, streams: &mut SubscriptionStreams, data: Bytes, + ctx_data: Arc, ) -> std::result::Result, Self::Error> where Query: ObjectType + Sync + Send + 'static, @@ -60,6 +67,7 @@ impl SubscriptionTransport for WebSocketTransport { &request.query, request.operation_name.as_deref(), variables, + Some(ctx_data), ) .await { diff --git a/tests/subscription.rs b/tests/subscription.rs index f9a80876..6f9192b9 100644 --- a/tests/subscription.rs +++ b/tests/subscription.rs @@ -1,5 +1,6 @@ use async_graphql::*; -use futures::{Stream, StreamExt}; +use futures::{SinkExt, Stream, StreamExt}; +use std::sync::Arc; #[async_std::test] pub async fn test_subscription() { @@ -40,6 +41,7 @@ pub async fn test_subscription() { "subscription { values(start: 10, end: 20) }", None, Default::default(), + None, ) .await .unwrap(); @@ -58,6 +60,7 @@ pub async fn test_subscription() { "subscription { events(start: 10, end: 20) { a b } }", None, Default::default(), + None, ) .await .unwrap(); @@ -113,6 +116,7 @@ pub async fn test_simple_broker() { "subscription { events1 { value } }", None, Default::default(), + None, ) .await .unwrap(); @@ -121,6 +125,7 @@ pub async fn test_simple_broker() { "subscription { events2 { value } }", None, Default::default(), + None, ) .await .unwrap(); @@ -148,3 +153,116 @@ pub async fn test_simple_broker() { Some(serde_json::json!({ "events2": {"value": 99} })) ); } + +#[async_std::test] +pub async fn test_subscription_with_ctx_data() { + struct QueryRoot; + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + #[field] + async fn values(&self, ctx: &Context<'_>) -> impl Stream { + let value = *ctx.data::(); + futures::stream::once(async move { value }) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + + { + let mut stream = schema + .create_subscription_stream( + "subscription { values }", + None, + Default::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(100i32); + data + })), + ) + .await + .unwrap(); + assert_eq!( + Some(serde_json::json!({ "values": 100 })), + stream.next().await + ); + assert!(stream.next().await.is_none()); + } +} + +#[async_std::test] +pub async fn test_subscription_ws_transport() { + struct QueryRoot; + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + #[field] + async fn values(&self, ctx: &Context<'_>) -> impl Stream { + let step = *ctx.data::(); + futures::stream::iter((0..10).map(move |n| n * step)) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (mut sink, mut stream) = schema.subscription_connection( + WebSocketTransport::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(5); + data + })), + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init", + })) + .unwrap() + .into(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap() + .into(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "values": i * 5 } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + } +}