From a13204273ee8651a6d9764a9944bf8168b74ff37 Mon Sep 17 00:00:00 2001 From: sunli Date: Tue, 31 Mar 2020 11:19:18 +0800 Subject: [PATCH] v1.6.8 Add context data support --- Cargo.toml | 5 +- async-graphql-actix-web/Cargo.toml | 4 +- async-graphql-actix-web/examples/token.rs | 36 +++++++ async-graphql-actix-web/src/lib.rs | 98 +++++++++++++++++-- async-graphql-actix-web/src/session.rs | 23 ++++- async-graphql-derive/Cargo.toml | 2 +- src/context.rs | 16 +-- src/extensions/tracing.rs | 6 +- src/http/mod.rs | 26 ++++- src/lib.rs | 5 +- src/query.rs | 17 ++++ src/schema.rs | 27 ++--- src/subscription/connection.rs | 31 +++--- src/subscription/connection_builder.rs | 43 ++++++++ src/subscription/mod.rs | 6 +- ...subscribe_stub.rs => subscription_stub.rs} | 4 + 16 files changed, 286 insertions(+), 63 deletions(-) create mode 100644 async-graphql-actix-web/examples/token.rs create mode 100644 src/subscription/connection_builder.rs rename src/subscription/{subscribe_stub.rs => subscription_stub.rs} (92%) diff --git a/Cargo.toml b/Cargo.toml index 7564f6fb..aa352b31 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql" -version = "1.6.7" +version = "1.6.8" authors = ["sunli "] edition = "2018" description = "The GraphQL server library implemented by rust" @@ -18,7 +18,7 @@ default = ["bson", "uuid", "url", "chrono-tz", "validators"] validators = ["regex"] [dependencies] -async-graphql-derive = { path = "async-graphql-derive", version = "1.6.7" } +async-graphql-derive = { path = "async-graphql-derive", version = "1.6.8" } graphql-parser = "=0.2.3" anyhow = "1.0.26" thiserror = "1.0.11" @@ -26,7 +26,6 @@ async-trait = "0.1.24" serde = "1.0.104" serde_derive = "1.0.104" serde_json = "1.0.48" -fnv = "1.0.6" bytes = "0.5.4" Inflector = "0.11.4" base64 = "0.12.0" diff --git a/async-graphql-actix-web/Cargo.toml b/async-graphql-actix-web/Cargo.toml index ef9d4c29..6c6ad070 100644 --- a/async-graphql-actix-web/Cargo.toml +++ b/async-graphql-actix-web/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-actix-web" -version = "0.6.10" +version = "0.6.11" authors = ["sunli "] edition = "2018" description = "async-graphql for actix-web" @@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] [dependencies] -async-graphql = { path = "..", version = "1.6.7" } +async-graphql = { path = "..", version = "1.6.8" } actix-web = "2.0.0" actix-multipart = "0.2.0" actix-web-actors = "2.0.0" diff --git a/async-graphql-actix-web/examples/token.rs b/async-graphql-actix-web/examples/token.rs new file mode 100644 index 00000000..b239c316 --- /dev/null +++ b/async-graphql-actix-web/examples/token.rs @@ -0,0 +1,36 @@ +use actix_web::{web, App, HttpServer}; +use async_graphql::{Context, EmptyMutation, EmptySubscription, Schema}; + +struct MyToken(Option); + +struct QueryRoot; + +#[async_graphql::Object] +impl QueryRoot { + #[field] + async fn current_token<'a>(&self, ctx: &'a Context<'_>) -> Option<&'a str> { + ctx.data::().0.as_deref() + } +} + +#[actix_rt::main] +async fn main() -> std::io::Result<()> { + HttpServer::new(move || { + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let handler = async_graphql_actix_web::HandlerBuilder::new(schema) + .enable_subscription() + .enable_ui("http://localhost:8000", None) + .on_request(|req, builder| { + builder.data(MyToken( + req.headers() + .get("Token") + .and_then(|value| value.to_str().map(ToString::to_string).ok()), + )) + }) + .build(); + App::new().service(web::resource("/").to(handler)) + }) + .bind("127.0.0.1:8000")? + .run() + .await +} diff --git a/async-graphql-actix-web/src/lib.rs b/async-graphql-actix-web/src/lib.rs index cc95ef90..85f49b5f 100644 --- a/async-graphql-actix-web/src/lib.rs +++ b/async-graphql-actix-web/src/lib.rs @@ -11,15 +11,31 @@ use actix_web::web::{BytesMut, Payload}; use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder}; use actix_web_actors::ws; use async_graphql::http::{GQLRequest, GQLResponse}; -use async_graphql::{ObjectType, Schema, SubscriptionType}; +use async_graphql::{ + ObjectType, QueryBuilder, Schema, SubscriptionConnectionBuilder, SubscriptionType, + WebSocketTransport, +}; use bytes::Bytes; use futures::StreamExt; use mime::Mime; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; +use std::sync::Arc; -// pub use pubsub::publish_message; +type BoxOnRequestFn = Arc< + dyn for<'a> Fn( + &HttpRequest, + QueryBuilder<'a, Query, Mutation, Subscription>, + ) -> QueryBuilder<'a, Query, Mutation, Subscription>, +>; + +type BoxOnConnectFn = Arc< + dyn Fn( + &HttpRequest, + SubscriptionConnectionBuilder, + ) -> SubscriptionConnectionBuilder, +>; /// Actix-web handler builder pub struct HandlerBuilder { @@ -28,6 +44,8 @@ pub struct HandlerBuilder { max_file_count: usize, enable_subscription: bool, enable_ui: Option<(String, Option)>, + on_request: Option>, + on_connect: Option>, } impl HandlerBuilder @@ -44,6 +62,8 @@ where max_file_count: 9, enable_subscription: false, enable_ui: None, + on_request: None, + on_connect: None, } } @@ -85,6 +105,41 @@ where } } + /// When a new request arrives, you can use this closure to append your own data to the `QueryBuilder`. + pub fn on_request< + F: for<'a> Fn( + &HttpRequest, + QueryBuilder<'a, Query, Mutation, Subscription>, + ) -> QueryBuilder<'a, Query, Mutation, Subscription> + + 'static, + >( + self, + f: F, + ) -> Self { + Self { + on_request: Some(Arc::new(f)), + ..self + } + } + + /// When there is a new subscription connection, you can use this closure to append your own data to the `SubscriptionConnectionBuilder`. + pub fn on_connect< + F: Fn( + &HttpRequest, + SubscriptionConnectionBuilder, + ) + -> SubscriptionConnectionBuilder + + 'static, + >( + self, + f: F, + ) -> Self { + Self { + on_connect: Some(Arc::new(f)), + ..self + } + } + /// Create an HTTP handler. pub fn build( self, @@ -99,10 +154,14 @@ where let max_file_count = self.max_file_count; let enable_ui = self.enable_ui; let enable_subscription = self.enable_subscription; + let on_request = self.on_request; + let on_connect = self.on_connect; move |req: HttpRequest, payload: Payload| { let schema = schema.clone(); let enable_ui = enable_ui.clone(); + let on_request = on_request.clone(); + let on_connect = on_connect.clone(); Box::pin(async move { if req.method() == Method::GET { @@ -111,7 +170,11 @@ where if let Ok(s) = s.to_str() { if s.to_ascii_lowercase().contains("websocket") { return ws::start_with_protocols( - WsSession::new(schema.clone()), + WsSession::new( + schema.clone(), + req.clone(), + on_connect.clone(), + ), &["graphql-ws"], &req, payload, @@ -132,7 +195,15 @@ where } if req.method() == Method::POST { - handle_request(&schema, max_file_size, max_file_count, req, payload).await + handle_request( + &schema, + max_file_size, + max_file_count, + req, + payload, + on_request.as_ref(), + ) + .await } else { Ok(HttpResponse::MethodNotAllowed().finish()) } @@ -147,6 +218,7 @@ async fn handle_request( max_file_count: usize, req: HttpRequest, mut payload: Payload, + on_request: Option<&BoxOnRequestFn>, ) -> actix_web::Result where Query: ObjectType + Send + Sync + 'static, @@ -171,7 +243,15 @@ where .map_err(actix_web::error::ErrorBadRequest)? }; - let mut prepared = match gql_request.prepare(schema) { + let mut builder = gql_request + .builder(schema) + .map_err(actix_web::error::ErrorBadRequest)?; + + if let Some(on_request) = on_request { + builder = on_request(&req, builder); + } + + let mut prepared = match builder.prepare() { Ok(prepared) => prepared, Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?), }; @@ -243,7 +323,13 @@ where let mut gql_req = web::Json::::from_request(&req, &mut payload.0) .await? .into_inner(); - let prepared = match gql_req.prepare(schema) { + let mut builder = gql_req + .builder(schema) + .map_err(actix_web::error::ErrorBadRequest)?; + if let Some(on_request) = on_request { + builder = on_request(&req, builder); + } + let prepared = match builder.prepare() { Ok(prepared) => prepared, Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?), }; diff --git a/async-graphql-actix-web/src/session.rs b/async-graphql-actix-web/src/session.rs index e786156a..cd2b60f4 100644 --- a/async-graphql-actix-web/src/session.rs +++ b/async-graphql-actix-web/src/session.rs @@ -1,6 +1,8 @@ +use crate::BoxOnConnectFn; use actix::{ Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, }; +use actix_web::HttpRequest; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; use bytes::Bytes; @@ -9,9 +11,11 @@ use futures::SinkExt; use std::time::{Duration, Instant}; pub struct WsSession { + req: HttpRequest, schema: Schema, hb: Instant, sink: Option>, + on_connect: Option>, } impl WsSession @@ -20,11 +24,17 @@ where Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - pub fn new(schema: Schema) -> Self { + pub fn new( + schema: Schema, + req: HttpRequest, + on_connect: Option>, + ) -> Self { Self { + req, schema, hb: Instant::now(), sink: None, + on_connect, } } @@ -48,11 +58,16 @@ where fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); let schema = self.schema.clone(); + let on_connect = self.on_connect.clone(); + let req = self.req.clone(); async move { - schema + let mut builder = schema .clone() - .subscription_connection(WebSocketTransport::default()) - .await + .subscription_connection(WebSocketTransport::default()); + if let Some(on_connect) = on_connect { + builder = on_connect(&req, builder); + } + builder.build().await } .into_actor(self) .then(|(sink, stream), actor, ctx| { diff --git a/async-graphql-derive/Cargo.toml b/async-graphql-derive/Cargo.toml index aeb323d5..ee00bacb 100644 --- a/async-graphql-derive/Cargo.toml +++ b/async-graphql-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-derive" -version = "1.6.7" +version = "1.6.8" authors = ["sunli "] edition = "2018" description = "Macros for async-graphql" diff --git a/src/context.rs b/src/context.rs index ea91ec91..9836fc68 100644 --- a/src/context.rs +++ b/src/context.rs @@ -2,13 +2,11 @@ use crate::extensions::BoxExtension; use crate::registry::Registry; use crate::{ErrorWithPosition, InputValueType, QueryError, Result, Type}; use bytes::Bytes; -use fnv::FnvHasher; use graphql_parser::query::{ Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition, }; use std::any::{Any, TypeId}; use std::collections::{BTreeMap, HashMap}; -use std::hash::BuildHasherDefault; use std::ops::{Deref, DerefMut}; use std::sync::atomic::AtomicUsize; @@ -131,7 +129,7 @@ fn json_value_to_gql_value(value: serde_json::Value) -> Value { } #[derive(Default)] -pub struct Data(HashMap, BuildHasherDefault>); +pub struct Data(BTreeMap>); impl Data { pub fn insert(&mut self, data: D) { @@ -220,6 +218,7 @@ pub struct ContextBase<'a, T> { pub(crate) variable_definitions: Option<&'a [VariableDefinition]>, pub(crate) registry: &'a Registry, pub(crate) data: &'a Data, + pub(crate) ctx_data: Option<&'a Data>, pub(crate) fragments: &'a HashMap, } @@ -257,6 +256,7 @@ impl<'a, T> ContextBase<'a, T> { variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, + ctx_data: self.ctx_data, fragments: self.fragments, } } @@ -275,15 +275,16 @@ impl<'a, T> ContextBase<'a, T> { variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, + ctx_data: self.ctx_data, fragments: self.fragments, } } - /// Gets the global data defined in the `Schema`. + /// Gets the global data defined in the `Context` or `Schema`. pub fn data(&self) -> &D { - self.data - .0 - .get(&TypeId::of::()) + self.ctx_data + .and_then(|ctx_data| ctx_data.0.get(&TypeId::of::())) + .or_else(|| self.data.0.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) .expect("The specified data type does not exist.") } @@ -413,6 +414,7 @@ impl<'a> ContextBase<'a, &'a SelectionSet> { variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, + ctx_data: self.ctx_data, fragments: self.fragments, } } diff --git a/src/extensions/tracing.rs b/src/extensions/tracing.rs index 30271a2d..4df064f7 100644 --- a/src/extensions/tracing.rs +++ b/src/extensions/tracing.rs @@ -1,12 +1,10 @@ use crate::extensions::{Extension, ResolveInfo}; use crate::QueryPathSegment; use chrono::{DateTime, Utc}; -use fnv::FnvHasher; use parking_lot::Mutex; use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; -use std::collections::HashMap; -use std::hash::BuildHasherDefault; +use std::collections::BTreeMap; use std::ops::Deref; struct PendingResolve { @@ -50,7 +48,7 @@ impl Serialize for ResolveStat { struct Inner { start_time: DateTime, end_time: DateTime, - pending_resolves: HashMap>, + pending_resolves: BTreeMap, resolves: Vec, } diff --git a/src/http/mod.rs b/src/http/mod.rs index 07919913..1c84f5c8 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -8,7 +8,10 @@ pub use playground_source::playground_source; use crate::error::{ExtendedError, RuleError, RuleErrors}; use crate::query::PreparedQuery; -use crate::{ObjectType, PositionError, QueryResult, Result, Schema, SubscriptionType, Variables}; +use crate::{ + ObjectType, PositionError, QueryBuilder, QueryResult, Result, Schema, SubscriptionType, + Variables, +}; use graphql_parser::Pos; use serde::ser::{SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; @@ -45,11 +48,11 @@ impl GQLRequest { } } - /// Prepare a query and return a `PreparedQuery` object that gets some information about the query. - pub fn prepare<'a, Query, Mutation, Subscription>( + /// Create query builder + pub fn builder<'a, Query, Mutation, Subscription>( &'a mut self, schema: &'a Schema, - ) -> Result> + ) -> Result> where Query: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static, @@ -71,7 +74,20 @@ impl GQLRequest { Some(name) => query.operator_name(name), None => query, }; - query.prepare() + Ok(query) + } + + /// Prepare a query and return a `PreparedQuery` object that gets some information about the query. + pub fn prepare<'a, Query, Mutation, Subscription>( + &'a mut self, + schema: &'a Schema, + ) -> Result> + where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, + { + self.builder(schema)?.prepare() } } diff --git a/src/lib.rs b/src/lib.rs index 017e53b7..dd2fa525 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,6 +39,7 @@ //! * Custom extension //! * Apollo Tracing extension //! * Limit query complexity/depth +//! * Error Extensions //! //! ## Integrations //! @@ -108,8 +109,8 @@ pub use registry::CacheControl; pub use scalars::ID; pub use schema::{publish, Schema}; pub use subscription::{ - SubscriptionStream, SubscriptionStub, SubscriptionStubs, SubscriptionTransport, - WebSocketTransport, + SubscriptionConnectionBuilder, SubscriptionStream, SubscriptionStub, SubscriptionStubs, + SubscriptionTransport, WebSocketTransport, }; pub use types::{ Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation, diff --git a/src/query.rs b/src/query.rs index f461871c..75aca410 100644 --- a/src/query.rs +++ b/src/query.rs @@ -10,6 +10,7 @@ use graphql_parser::parse_query; use graphql_parser::query::{ Definition, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition, }; +use std::any::Any; use std::collections::HashMap; use std::sync::atomic::AtomicUsize; @@ -26,6 +27,7 @@ pub struct QueryBuilder<'a, Query, Mutation, Subscription> { pub(crate) operation_name: Option<&'a str>, pub(crate) variables: Option, pub(crate) data: &'a Data, + pub(crate) ctx_data: Option, } impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscription> { @@ -45,6 +47,18 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr } } + /// Add a context data that can be accessed in the `Context`, you access it with `Context::data`. + pub fn data(mut self, data: D) -> Self { + if let Some(ctx_data) = &mut self.ctx_data { + ctx_data.insert(data); + } else { + let mut ctx_data = Data::default(); + ctx_data.insert(data); + self.ctx_data = Some(ctx_data); + } + self + } + /// Prepare query pub fn prepare(self) -> Result> { self.extensions @@ -119,6 +133,7 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr registry: &self.schema.0.registry, variables: self.variables.unwrap_or_default(), data: self.data, + ctx_data: self.ctx_data, fragments, selection_set: selection_set.ok_or({ if let Some(name) = self.operation_name { @@ -161,6 +176,7 @@ pub struct PreparedQuery<'a, Query, Mutation> { registry: &'a Registry, variables: Variables, data: &'a Data, + ctx_data: Option, fragments: HashMap, selection_set: SelectionSet, variable_definitions: Option>, @@ -210,6 +226,7 @@ impl<'a, Query, Mutation> PreparedQuery<'a, Query, Mutation> { variable_definitions: self.variable_definitions.as_deref(), registry: self.registry, data: self.data, + ctx_data: self.ctx_data.as_ref(), fragments: &self.fragments, }; diff --git a/src/schema.rs b/src/schema.rs index a2149935..4a507518 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,14 +3,13 @@ use crate::extensions::{BoxExtension, Extension}; use crate::model::__DirectiveLocation; use crate::query::QueryBuilder; use crate::registry::{Directive, InputValue, Registry}; -use crate::subscription::{create_connection, SubscriptionStub, SubscriptionTransport}; +use crate::subscription::{SubscriptionConnectionBuilder, SubscriptionStub, SubscriptionTransport}; use crate::types::QueryRoot; use crate::validation::check_rules; use crate::{ - ContextSelectionSet, ObjectType, QueryError, QueryParseError, Result, SubscriptionStream, - SubscriptionType, Type, Variables, + ContextSelectionSet, ObjectType, QueryError, QueryParseError, Result, SubscriptionType, Type, + Variables, }; -use bytes::Bytes; use futures::channel::mpsc; use futures::lock::Mutex; use futures::SinkExt; @@ -75,7 +74,7 @@ impl self } - /// Add a global data that can be accessed in the `Context`. + /// Add a global data that can be accessed in the `Schema`, you access it with `Context::data`. pub fn data(mut self, data: D) -> Self { self.0.data.insert(data); self @@ -220,6 +219,7 @@ where operation_name: None, variables: None, data: &self.0.data, + ctx_data: None, } } @@ -273,6 +273,7 @@ where variable_definitions: Some(&subscription.variable_definitions), registry: &self.0.registry, data: &Default::default(), + ctx_data: None, fragments: &fragments, }; create_subscription_types::(&ctx, &fragments, &mut types)?; @@ -282,18 +283,20 @@ where variables, variable_definitions: subscription.variable_definitions, fragments, + ctx_data: None, }) } - /// Create subscription connection, returns `Sink` and `Stream`. - pub async fn subscription_connection( + /// Create subscription connection, returns `SubscriptionConnectionBuilder`. + pub fn subscription_connection( &self, transport: T, - ) -> ( - mpsc::Sender, - SubscriptionStream, - ) { - create_connection(self, transport).await + ) -> SubscriptionConnectionBuilder { + SubscriptionConnectionBuilder { + schema: self.clone(), + transport, + ctx_data: None, + } } } diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index eeba103d..ace29066 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -1,3 +1,4 @@ +use crate::context::Data; use crate::schema::SUBSCRIPTION_SENDERS; use crate::subscription::SubscriptionStub; use crate::{ObjectType, Result, Schema, SubscriptionType}; @@ -12,24 +13,20 @@ use std::pin::Pin; use std::sync::Arc; /// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection` -pub struct SubscriptionStubs( - Slab>, -); - -impl Default for SubscriptionStubs { - fn default() -> Self { - Self(Slab::new()) - } +pub struct SubscriptionStubs { + stubs: Slab>, + ctx_data: Option>, } #[allow(missing_docs)] impl SubscriptionStubs { - pub fn add(&mut self, stub: SubscriptionStub) -> usize { - self.0.insert(stub) + pub fn add(&mut self, mut stub: SubscriptionStub) -> usize { + stub.ctx_data = self.ctx_data.clone(); + self.stubs.insert(stub) } pub fn remove(&mut self, id: usize) { - self.0.remove(id); + self.stubs.remove(id); } } @@ -56,8 +53,9 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { } pub async fn create_connection( - schema: &Schema, + schema: Schema, transport: T, + ctx_data: Option, ) -> ( mpsc::Sender, SubscriptionStream, @@ -74,9 +72,12 @@ where ( tx_bytes.clone(), SubscriptionStream { - schema: schema.clone(), + schema, transport, - stubs: Default::default(), + stubs: SubscriptionStubs { + stubs: Default::default(), + ctx_data: ctx_data.map(Arc::new), + }, rx_bytes, rx_msg, send_queue: VecDeque::new(), @@ -151,7 +152,7 @@ where let send_queue = &mut this.send_queue as *mut VecDeque; let fut = async move { unsafe { - for (id, stub) in (*stubs).0.iter() { + for (id, stub) in (*stubs).stubs.iter() { if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() { if let Some(bytes) = (*transport).handle_response(id, res) { (*send_queue).push_back(bytes); diff --git a/src/subscription/connection_builder.rs b/src/subscription/connection_builder.rs new file mode 100644 index 00000000..4e8a0fb1 --- /dev/null +++ b/src/subscription/connection_builder.rs @@ -0,0 +1,43 @@ +use crate::context::Data; +use crate::subscription::create_connection; +use crate::{ObjectType, Schema, SubscriptionStream, SubscriptionTransport, SubscriptionType}; +use bytes::Bytes; +use futures::channel::mpsc; +use std::any::Any; + +/// SubscriptionConnection builder +pub struct SubscriptionConnectionBuilder { + pub(crate) schema: Schema, + pub(crate) transport: T, + pub(crate) ctx_data: Option, +} + +impl + SubscriptionConnectionBuilder +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, +{ + /// Add a context data that can be accessed in the `Context`, you access it with `Context::data`. + pub fn data(mut self, data: D) -> Self { + if let Some(ctx_data) = &mut self.ctx_data { + ctx_data.insert(data); + } else { + let mut ctx_data = Data::default(); + ctx_data.insert(data); + self.ctx_data = Some(ctx_data); + } + self + } + + /// Create subscription connection, returns `Sink` and `Stream`. + pub async fn build( + self, + ) -> ( + mpsc::Sender, + SubscriptionStream, + ) { + create_connection(self.schema, self.transport, self.ctx_data).await + } +} diff --git a/src/subscription/mod.rs b/src/subscription/mod.rs index a2861bab..c79b2b53 100644 --- a/src/subscription/mod.rs +++ b/src/subscription/mod.rs @@ -1,11 +1,13 @@ mod connection; -mod subscribe_stub; +mod connection_builder; +mod subscription_stub; mod subscription_type; mod ws_transport; pub use connection::{ create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport, }; -pub use subscribe_stub::SubscriptionStub; +pub use connection_builder::SubscriptionConnectionBuilder; +pub use subscription_stub::SubscriptionStub; pub use subscription_type::SubscriptionType; pub use ws_transport::WebSocketTransport; diff --git a/src/subscription/subscribe_stub.rs b/src/subscription/subscription_stub.rs similarity index 92% rename from src/subscription/subscribe_stub.rs rename to src/subscription/subscription_stub.rs index df409fa8..54689105 100644 --- a/src/subscription/subscribe_stub.rs +++ b/src/subscription/subscription_stub.rs @@ -1,8 +1,10 @@ +use crate::context::Data; use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables}; use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition}; use std::any::{Any, TypeId}; use std::collections::HashMap; use std::sync::atomic::AtomicUsize; +use std::sync::Arc; /// Subscription stub /// @@ -14,6 +16,7 @@ pub struct SubscriptionStub { pub(crate) variables: Variables, pub(crate) variable_definitions: Vec, pub(crate) fragments: HashMap, + pub(crate) ctx_data: Option>, } impl SubscriptionStub @@ -37,6 +40,7 @@ where variable_definitions: Some(&self.variable_definitions), registry: &self.schema.0.registry, data: &self.schema.0.data, + ctx_data: self.ctx_data.as_deref(), fragments: &self.fragments, }; self.schema