From 631de524000b458206abafc5ff5928e2908e7239 Mon Sep 17 00:00:00 2001 From: sunli Date: Mon, 6 Apr 2020 13:49:39 +0800 Subject: [PATCH] The subscription field now returns a stream --- .../examples/subscription.rs | 121 ++------------ async-graphql-actix-web/src/lib.rs | 40 +---- async-graphql-actix-web/src/session.rs | 33 +--- async-graphql-derive/src/subscription.rs | 154 +++++++++--------- src/context.rs | 41 +++++ src/lib.rs | 12 +- src/schema.rs | 132 ++++----------- src/subscription/connection.rs | 110 +++++-------- src/subscription/connection_builder.rs | 43 ----- src/subscription/mod.rs | 8 +- src/subscription/subscription_stub.rs | 52 ------ src/subscription/subscription_type.rs | 92 +++++++++-- src/subscription/ws_transport.rs | 23 ++- src/types/empty_subscription.rs | 38 +++-- src/validators/mod.rs | 2 + 15 files changed, 331 insertions(+), 570 deletions(-) delete mode 100644 src/subscription/connection_builder.rs delete mode 100644 src/subscription/subscription_stub.rs diff --git a/async-graphql-actix-web/examples/subscription.rs b/async-graphql-actix-web/examples/subscription.rs index b096885f..10e7af1a 100644 --- a/async-graphql-actix-web/examples/subscription.rs +++ b/async-graphql-actix-web/examples/subscription.rs @@ -1,109 +1,15 @@ +use actix::clock::Duration; use actix_web::{web, App, HttpServer}; -use async_graphql::{publish, Context, FieldResult, Schema, ID}; -use futures::lock::Mutex; -use slab::Slab; -use std::sync::Arc; - -#[derive(Clone)] -struct Book { - id: ID, - name: String, - author: String, -} - -#[async_graphql::Object] -impl Book { - #[field] - async fn id(&self) -> &str { - &self.id - } - - #[field] - async fn name(&self) -> &str { - &self.name - } - - #[field] - async fn author(&self) -> &str { - &self.author - } -} - -type Storage = Arc>>; +use async_graphql::{EmptyMutation, Schema}; +use futures::{Stream, StreamExt}; struct QueryRoot; -#[async_graphql::Object(cache_control(max_age = 5))] +#[async_graphql::Object] impl QueryRoot { #[field] - async fn books(&self, ctx: &Context<'_>) -> Vec { - let books = ctx.data::().lock().await; - books.iter().map(|(_, book)| book).cloned().collect() - } -} - -struct MutationRoot; - -#[async_graphql::Object] -impl MutationRoot { - #[field] - async fn create_book(&self, ctx: &Context<'_>, name: String, author: String) -> ID { - let mut books = ctx.data::().lock().await; - let entry = books.vacant_entry(); - let id: ID = entry.key().into(); - let book = Book { - id: id.clone(), - name, - author, - }; - entry.insert(book); - publish(BookChanged { - mutation_type: MutationType::Created, - id: id.clone(), - }) - .await; - id - } - - #[field] - async fn delete_book(&self, ctx: &Context<'_>, id: ID) -> FieldResult { - let mut books = ctx.data::().lock().await; - let id = id.parse::()?; - if books.contains(id) { - books.remove(id); - publish(BookChanged { - mutation_type: MutationType::Deleted, - id: id.into(), - }) - .await; - Ok(true) - } else { - Ok(false) - } - } -} - -#[async_graphql::Enum] -enum MutationType { - Created, - Deleted, -} - -struct BookChanged { - mutation_type: MutationType, - id: ID, -} - -#[async_graphql::Object] -impl BookChanged { - #[field] - async fn mutation_type(&self) -> &MutationType { - &self.mutation_type - } - - #[field] - async fn id(&self) -> &ID { - &self.id + async fn value(&self) -> i32 { + 0 } } @@ -112,20 +18,19 @@ struct SubscriptionRoot; #[async_graphql::Subscription] impl SubscriptionRoot { #[field] - fn books(&self, changed: &BookChanged, mutation_type: Option) -> bool { - if let Some(mutation_type) = mutation_type { - return changed.mutation_type == mutation_type; - } - true + fn interval(&self, n: i32) -> impl Stream { + let mut value = 0; + actix_rt::time::interval(Duration::from_secs(1)).map(move |_| { + value += n; + value + }) } } #[actix_rt::main] async fn main() -> std::io::Result<()> { HttpServer::new(move || { - let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot) - .data(Storage::default()) - .finish(); + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let handler = async_graphql_actix_web::HandlerBuilder::new(schema) .enable_ui("http://localhost:8000", Some("ws://localhost:8000")) .enable_subscription() diff --git a/async-graphql-actix-web/src/lib.rs b/async-graphql-actix-web/src/lib.rs index 224e295b..d4fd8586 100644 --- a/async-graphql-actix-web/src/lib.rs +++ b/async-graphql-actix-web/src/lib.rs @@ -11,10 +11,7 @@ use actix_web::web::{BytesMut, Payload}; use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder}; use actix_web_actors::ws; use async_graphql::http::{GQLRequest, GQLResponse}; -use async_graphql::{ - ObjectType, QueryBuilder, Schema, SubscriptionConnectionBuilder, SubscriptionType, - WebSocketTransport, -}; +use async_graphql::{ObjectType, QueryBuilder, Schema, SubscriptionType}; use bytes::Bytes; use futures::StreamExt; use mime::Mime; @@ -30,13 +27,6 @@ type BoxOnRequestFn = Arc< ) -> QueryBuilder, >; -type BoxOnConnectFn = Arc< - dyn Fn( - &HttpRequest, - SubscriptionConnectionBuilder, - ) -> SubscriptionConnectionBuilder, ->; - /// Actix-web handler builder pub struct HandlerBuilder { schema: Schema, @@ -45,7 +35,6 @@ pub struct HandlerBuilder { enable_subscription: bool, enable_ui: Option<(String, Option)>, on_request: Option>, - on_connect: Option>, } impl HandlerBuilder @@ -63,7 +52,6 @@ where enable_subscription: false, enable_ui: None, on_request: None, - on_connect: None, } } @@ -122,24 +110,6 @@ where } } - /// When there is a new subscription connection, you can use this closure to append your own data to the `SubscriptionConnectionBuilder`. - pub fn on_connect< - F: Fn( - &HttpRequest, - SubscriptionConnectionBuilder, - ) - -> SubscriptionConnectionBuilder - + 'static, - >( - self, - f: F, - ) -> Self { - Self { - on_connect: Some(Arc::new(f)), - ..self - } - } - /// Create an HTTP handler. pub fn build( self, @@ -155,13 +125,11 @@ where let enable_ui = self.enable_ui; let enable_subscription = self.enable_subscription; let on_request = self.on_request; - let on_connect = self.on_connect; move |req: HttpRequest, payload: Payload| { let schema = schema.clone(); let enable_ui = enable_ui.clone(); let on_request = on_request.clone(); - let on_connect = on_connect.clone(); Box::pin(async move { if req.method() == Method::GET { @@ -170,11 +138,7 @@ where if let Ok(s) = s.to_str() { if s.to_ascii_lowercase().contains("websocket") { return ws::start_with_protocols( - WsSession::new( - schema.clone(), - req.clone(), - on_connect.clone(), - ), + WsSession::new(schema.clone()), &["graphql-ws"], &req, payload, diff --git a/async-graphql-actix-web/src/session.rs b/async-graphql-actix-web/src/session.rs index cd2b60f4..43b8288d 100644 --- a/async-graphql-actix-web/src/session.rs +++ b/async-graphql-actix-web/src/session.rs @@ -1,8 +1,6 @@ -use crate::BoxOnConnectFn; use actix::{ Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, }; -use actix_web::HttpRequest; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; use bytes::Bytes; @@ -11,11 +9,9 @@ use futures::SinkExt; use std::time::{Duration, Instant}; pub struct WsSession { - req: HttpRequest, schema: Schema, hb: Instant, sink: Option>, - on_connect: Option>, } impl WsSession @@ -24,17 +20,11 @@ where Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - pub fn new( - schema: Schema, - req: HttpRequest, - on_connect: Option>, - ) -> Self { + pub fn new(schema: Schema) -> Self { Self { - req, schema, hb: Instant::now(), sink: None, - on_connect, } } @@ -58,24 +48,9 @@ where fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); let schema = self.schema.clone(); - let on_connect = self.on_connect.clone(); - let req = self.req.clone(); - async move { - let mut builder = schema - .clone() - .subscription_connection(WebSocketTransport::default()); - if let Some(on_connect) = on_connect { - builder = on_connect(&req, builder); - } - builder.build().await - } - .into_actor(self) - .then(|(sink, stream), actor, ctx| { - actor.sink = Some(sink); - ctx.add_stream(stream); - async {}.into_actor(actor) - }) - .wait(ctx); + let (sink, stream) = schema.subscription_connection(WebSocketTransport::default()); + ctx.add_stream(stream); + self.sink = Some(sink); } } diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index 0f6dca38..544b9d31 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -3,7 +3,7 @@ use crate::utils::{build_value_repr, check_reserved_name, get_crate_name}; use inflector::Inflector; use proc_macro::TokenStream; use quote::quote; -use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type}; +use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait}; pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result { let crate_name = get_crate_name(object_args.internal); @@ -32,8 +32,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< .map(|s| quote! {Some(#s)}) .unwrap_or_else(|| quote! {None}); - let mut create_types = Vec::new(); - let mut filters = Vec::new(); + let mut create_stream = Vec::new(); let mut schema_fields = Vec::new(); for item in &mut item_impl.items { @@ -55,32 +54,10 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< .map(|s| quote! {Some(#s)}) .unwrap_or_else(|| quote! {None}); - if method.sig.inputs.len() < 2 { - return Err(Error::new_spanned( - &method.sig.inputs, - "The filter function needs at least two arguments", - )); - } - if method.sig.asyncness.is_some() { return Err(Error::new_spanned( - &method.sig.inputs, - "The filter function must be synchronous", - )); - } - - let mut res_typ_ok = false; - if let ReturnType::Type(_, res_ty) = &method.sig.output { - if let Type::Path(p) = res_ty.as_ref() { - if p.path.is_ident("bool") { - res_typ_ok = true; - } - } - } - if !res_typ_ok { - return Err(Error::new_spanned( - &method.sig.output, - "The filter function must return a boolean value", + &method.sig.asyncness, + "The subscription stream function must be synchronous", )); } @@ -94,23 +71,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } } - let ty = if let FnArg::Typed(ty) = &method.sig.inputs[1] { - match ty.ty.as_ref() { - Type::Reference(r) => r.elem.as_ref().clone(), - _ => { - return Err(Error::new_spanned(ty, "Incorrect object type")); - } - } - } else { - return Err(Error::new_spanned( - &method.sig.inputs[1], - "Incorrect object type", - )); - }; - let mut args = Vec::new(); - for arg in method.sig.inputs.iter_mut().skip(2) { + for arg in method.sig.inputs.iter_mut().skip(1) { if let FnArg::Typed(pat) = arg { match (&*pat.pat, &*pat.ty) { (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { @@ -181,10 +144,26 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }; get_params.push(quote! { - let #ident: #ty = ctx_field.param_value(#name, field.position, #default)?; + let #ident: #ty = ctx.param_value(#name, ctx.position, #default)?; }); } + let stream_ty = match &method.sig.output { + ReturnType::Default => { + return Err(Error::new_spanned( + &method.sig.output, + "Must be return a stream type", + )) + } + ReturnType::Type(_, ty) => { + if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = ty.as_ref() { + quote! { #bounds } + } else { + quote! { #ty } + } + } + }; + schema_fields.push(quote! { fields.insert(#field_name.to_string(), #crate_name::registry::Field { name: #field_name.to_string(), @@ -194,30 +173,46 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#schema_args)* args }, - ty: <#ty as #crate_name::Type>::create_type_info(registry), + ty: <#stream_ty as #crate_name::futures::stream::Stream>::Item::create_type_info(registry), deprecation: #field_deprecation, cache_control: Default::default(), }); }); - create_types.push(quote! { - if field.name.as_str() == #field_name { - types.insert(std::any::TypeId::of::<#ty>(), field.clone()); - return Ok(()); - } - }); - - filters.push(quote! { - if let Some(msg) = msg.downcast_ref::<#ty>() { + create_stream.push(quote! { + if ctx.name.as_str() == #field_name { + let field_name = ctx.result_name().to_string(); #(#get_params)* - if self.#ident(msg, #(#use_params)*) { - let ctx_selection_set = ctx_field.with_selection_set(&field.selection_set); - let value = - #crate_name::OutputValueType::resolve(msg, &ctx_selection_set, field.position).await?; - let mut res = #crate_name::serde_json::Map::new(); - res.insert(ctx_field.result_name().to_string(), value); - return Ok(Some(res.into())); - } + let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone()); + let schema = schema.clone(); + let pos = ctx.position; + let environment = environment.clone(); + let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params)*).fuse(), move |msg| { + let environment = environment.clone(); + let field_selection_set = field_selection_set.clone(); + let schema = schema.clone(); + async move { + let resolve_id = std::sync::atomic::AtomicUsize::default(); + let ctx_selection_set = environment.create_context( + &*field_selection_set, + Some(#crate_name::QueryPathNode { + parent: None, + segment: #crate_name::QueryPathSegment::Name("time"), + }), + &resolve_id, + schema.registry(), + schema.data(), + ); + #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await + } + }). + filter_map(move |res| { + let res = res.ok().map(|value| { + #crate_name::serde_json::json!({ &field_name: value }) + }); + async move { res } + }); + return Ok(Box::pin(stream)); } }); @@ -234,6 +229,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< std::borrow::Cow::Borrowed(#gql_typename) } + #[allow(bare_trait_objects)] fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { registry.create_type::(|registry| #crate_name::registry::Type::Object { name: #gql_typename.to_string(), @@ -250,26 +246,24 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #[#crate_name::async_trait::async_trait] impl #crate_name::SubscriptionType for SubscriptionRoot { - fn create_type(field: &#crate_name::graphql_parser::query::Field, types: &mut std::collections::HashMap) -> #crate_name::Result<()> { - #(#create_types)* - Err(#crate_name::QueryError::FieldNotFound { - field_name: field.name.clone(), - object: #gql_typename.to_string(), - }.into_error(field.position)) - } - - async fn resolve( + #[allow(unused_variables)] + #[allow(bare_trait_objects)] + fn create_field_stream( &self, - ctx: &#crate_name::ContextBase<'_, ()>, - types: &std::collections::HashMap, - msg: &(dyn std::any::Any + Send + Sync), - ) -> #crate_name::Result> { - let tid = msg.type_id(); - if let Some(field) = types.get(&tid) { - let ctx_field = ctx.with_field(field); - #(#filters)* - } - Ok(None) + ctx: &#crate_name::Context<'_>, + schema: &#crate_name::Schema, + environment: std::sync::Arc<#crate_name::Environment>, + ) -> #crate_name::Result>>> + where + Query: #crate_name::ObjectType + Send + Sync + 'static, + Mutation: #crate_name::ObjectType + Send + Sync + 'static, + Self: Send + Sync + 'static + Sized, + { + #(#create_stream)* + Err(#crate_name::QueryError::FieldNotFound { + field_name: ctx.name.clone(), + object: #gql_typename.to_string(), + }.into_error(ctx.position)) } } }; diff --git a/src/context.rs b/src/context.rs index 9858672e..d94ecc5f 100644 --- a/src/context.rs +++ b/src/context.rs @@ -242,7 +242,48 @@ impl<'a, T> Deref for ContextBase<'a, T> { } } +#[doc(hidden)] +pub struct Environment { + pub variables: Variables, + pub variable_definitions: Vec, + pub fragments: HashMap, +} + +impl Environment { + #[doc(hidden)] + pub fn create_context<'a, T>( + &'a self, + item: T, + path_node: Option>, + resolve_id: &'a AtomicUsize, + registry: &'a Registry, + data: &'a Data, + ) -> ContextBase<'a, T> { + ContextBase { + path_node, + resolve_id, + extensions: &[], + item, + variables: &self.variables, + variable_definitions: &self.variable_definitions, + registry, + data, + ctx_data: None, + fragments: &self.fragments, + } + } +} + impl<'a, T> ContextBase<'a, T> { + #[doc(hidden)] + pub fn create_environment(&self) -> Environment { + Environment { + variables: self.variables.clone(), + variable_definitions: self.variable_definitions.to_vec(), + fragments: self.fragments.clone(), + } + } + #[doc(hidden)] pub fn get_resolve_id(&self) -> usize { self.resolve_id diff --git a/src/lib.rs b/src/lib.rs index f07ca362..ac6d15cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,6 @@ extern crate serde_derive; mod base; mod context; mod error; -pub mod extensions; mod model; mod mutation_resolver; mod query; @@ -87,7 +86,7 @@ mod subscription; mod types; mod validation; -/// Input value validators +pub mod extensions; pub mod validators; #[doc(hidden)] @@ -95,6 +94,8 @@ pub use anyhow; #[doc(hidden)] pub use async_trait; #[doc(hidden)] +pub use futures; +#[doc(hidden)] pub use graphql_parser; #[doc(hidden)] pub use serde_json; @@ -102,17 +103,16 @@ pub use serde_json; pub mod http; pub use base::{Scalar, Type}; -pub use context::{Context, QueryPathSegment, Variables}; +pub use context::{Context, Environment, QueryPathNode, QueryPathSegment, Variables}; pub use error::{Error, ErrorExtensions, FieldError, FieldResult, QueryError, ResultExt}; pub use graphql_parser::query::Value; pub use graphql_parser::Pos; pub use query::{QueryBuilder, QueryResponse}; pub use registry::CacheControl; pub use scalars::ID; -pub use schema::{publish, Schema}; +pub use schema::Schema; pub use subscription::{ - SubscriptionConnectionBuilder, SubscriptionStream, SubscriptionStub, SubscriptionStubs, - SubscriptionTransport, WebSocketTransport, + SubscriptionStream, SubscriptionStreams, SubscriptionTransport, WebSocketTransport, }; pub use types::{ Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation, diff --git a/src/schema.rs b/src/schema.rs index bc1bbeb1..ac715301 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,32 +3,25 @@ use crate::extensions::{BoxExtension, Extension}; use crate::model::__DirectiveLocation; use crate::query::QueryBuilder; use crate::registry::{Directive, InputValue, Registry}; -use crate::subscription::{SubscriptionConnectionBuilder, SubscriptionStub, SubscriptionTransport}; +use crate::subscription::{create_connection, create_subscription_stream, SubscriptionTransport}; use crate::types::QueryRoot; use crate::validation::{check_rules, CheckResult}; use crate::{ ContextSelectionSet, Error, ObjectType, Pos, QueryError, QueryResponse, Result, - SubscriptionType, Type, Variables, + SubscriptionStream, SubscriptionType, Type, Variables, }; +use bytes::Bytes; use futures::channel::mpsc; -use futures::lock::Mutex; -use futures::{SinkExt, TryFutureExt}; +use futures::{Stream, TryFutureExt}; use graphql_parser::parse_query; -use graphql_parser::query::{ - Definition, Field, FragmentDefinition, OperationDefinition, Selection, -}; +use graphql_parser::query::{Definition, OperationDefinition}; use itertools::Itertools; -use once_cell::sync::Lazy; -use slab::Slab; -use std::any::{Any, TypeId}; +use std::any::Any; use std::collections::HashMap; +use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -type MsgSender = mpsc::Sender>; - -pub(crate) static SUBSCRIPTION_SENDERS: Lazy>> = Lazy::new(Default::default); - pub(crate) struct SchemaInner { pub(crate) query: QueryRoot, pub(crate) mutation: Mutation, @@ -211,6 +204,16 @@ where Self::build(query, mutation, subscription).finish() } + #[doc(hidden)] + pub fn data(&self) -> &Data { + &self.0.data + } + + #[doc(hidden)] + pub fn registry(&self) -> &Registry { + &self.0.registry + } + /// Start a query and return `QueryBuilder`. pub fn query(&self, source: &str) -> Result> { let extensions = self @@ -233,13 +236,13 @@ where if let Some(limit_complexity) = self.0.complexity { if complexity > limit_complexity { - return Err(QueryError::TooComplex.into_error(Pos { line: 0, column: 0 })); + return Err(QueryError::TooComplex.into_error(Pos::default())); } } if let Some(limit_depth) = self.0.depth { if depth > limit_depth { - return Err(QueryError::TooDeep.into_error(Pos { line: 0, column: 0 })); + return Err(QueryError::TooDeep.into_error(Pos::default())); } } @@ -261,16 +264,13 @@ where .await } - /// Create subscription stub, typically called inside the `SubscriptionTransport::handle_request` method/ - pub fn create_subscription_stub( + /// Create subscription stream, typically called inside the `SubscriptionTransport::handle_request` method + pub fn create_subscription_stream( &self, source: &str, operation_name: Option<&str>, variables: Variables, - ) -> Result> - where - Self: Sized, - { + ) -> Result>>> { let document = parse_query(source).map_err(Into::::into)?; check_rules(&self.0.registry, &document)?; @@ -301,7 +301,6 @@ where QueryError::MissingOperation.into_error(Pos::default()) })?; - let mut types = HashMap::new(); let resolve_id = AtomicUsize::default(); let ctx = ContextSelectionSet { path_node: None, @@ -315,87 +314,20 @@ where ctx_data: None, fragments: &fragments, }; - create_subscription_types::(&ctx, &fragments, &mut types)?; - Ok(SubscriptionStub { - schema: self.clone(), - types, - variables, - variable_definitions: subscription.variable_definitions, - fragments, - ctx_data: None, - }) + + let mut streams = Vec::new(); + create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)?; + Ok(Box::pin(futures::stream::select_all(streams))) } - /// Create subscription connection, returns `SubscriptionConnectionBuilder`. + /// Create subscription connection, returns `Sink` and `Stream`. pub fn subscription_connection( &self, transport: T, - ) -> SubscriptionConnectionBuilder { - SubscriptionConnectionBuilder { - schema: self.clone(), - transport, - ctx_data: None, - } - } -} - -fn create_subscription_types( - ctx: &ContextSelectionSet<'_>, - fragments: &HashMap, - types: &mut HashMap, -) -> Result<()> { - for selection in &ctx.items { - match selection { - Selection::Field(field) => { - if ctx.is_skip(&field.directives)? { - continue; - } - T::create_type(field, types)?; - } - Selection::FragmentSpread(fragment_spread) => { - if ctx.is_skip(&fragment_spread.directives)? { - continue; - } - - if let Some(fragment) = fragments.get(&fragment_spread.fragment_name) { - create_subscription_types::( - &ctx.with_selection_set(&fragment.selection_set), - fragments, - types, - )?; - } else { - return Err(QueryError::UnknownFragment { - name: fragment_spread.fragment_name.clone(), - } - .into_error(fragment_spread.position)); - } - } - Selection::InlineFragment(inline_fragment) => { - if ctx.is_skip(&inline_fragment.directives)? { - continue; - } - create_subscription_types::( - &ctx.with_selection_set(&inline_fragment.selection_set), - fragments, - types, - )?; - } - } - } - Ok(()) -} - -/// Publish a message that will be pushed to all subscribed clients. -pub async fn publish(msg: T) { - let mut senders = SUBSCRIPTION_SENDERS.lock().await; - let msg = Arc::new(msg); - let mut remove = Vec::new(); - for (id, sender) in senders.iter_mut() { - if sender.send(msg.clone()).await.is_err() { - remove.push(id); - } - } - for id in remove { - senders.remove(id); + ) -> ( + mpsc::Sender, + SubscriptionStream, + ) { + create_connection(self.clone(), transport) } } diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index bfb2db5a..99d19d11 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -1,32 +1,25 @@ -use crate::context::Data; -use crate::schema::SUBSCRIPTION_SENDERS; -use crate::subscription::SubscriptionStub; -use crate::{ObjectType, Result, Schema, SubscriptionType}; +use crate::{ObjectType, Schema, SubscriptionType}; use bytes::Bytes; use futures::channel::mpsc; use futures::task::{Context, Poll}; -use futures::{Future, FutureExt, Stream}; +use futures::Stream; use slab::Slab; -use std::any::Any; use std::collections::VecDeque; use std::pin::Pin; -use std::sync::Arc; -/// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection` -pub struct SubscriptionStubs { - stubs: Slab>, - ctx_data: Option>, +/// Use to hold all subscription stream for the `SubscriptionConnection` +pub struct SubscriptionStreams { + streams: Slab>>>, } #[allow(missing_docs)] -impl SubscriptionStubs { - pub fn add(&mut self, mut stub: SubscriptionStub) -> usize { - stub.ctx_data = self.ctx_data.clone(); - self.stubs.insert(stub) +impl SubscriptionStreams { + pub fn add(&mut self, stream: Pin>>) -> usize { + self.streams.insert(stream) } pub fn remove(&mut self, id: usize) { - self.stubs.remove(id); + self.streams.remove(id); } } @@ -38,12 +31,12 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { type Error; /// Parse the request data here. - /// If you have a new request, create a `SubscriptionStub` with the `Schema::create_subscription_stub`, and then call `SubscriptionStubs::add`. + /// If you have a new subscribe, create a stream with the `Schema::create_subscription_stream`, and then call `SubscriptionStreams::add`. /// You can return a `Byte`, which will be sent to the client. If it returns an error, the connection will be broken. fn handle_request( &mut self, schema: &Schema, - stubs: &mut SubscriptionStubs, + streams: &mut SubscriptionStreams, data: Bytes, ) -> std::result::Result, Self::Error> where @@ -52,13 +45,12 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { Subscription: SubscriptionType + Sync + Send + 'static; /// When a response message is generated, you can convert the message to the format you want here. - fn handle_response(&mut self, id: usize, result: Result) -> Option; + fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option; } -pub async fn create_connection( +pub fn create_connection( schema: Schema, transport: T, - ctx_data: Option, ) -> ( mpsc::Sender, SubscriptionStream, @@ -69,23 +61,16 @@ where Subscription: SubscriptionType + Sync + Send + 'static, { let (tx_bytes, rx_bytes) = mpsc::channel(8); - let (tx_msg, rx_msg) = mpsc::channel(8); - let mut senders = SUBSCRIPTION_SENDERS.lock().await; - senders.insert(tx_msg); ( tx_bytes.clone(), SubscriptionStream { schema, transport, - stubs: SubscriptionStubs { - stubs: Default::default(), - ctx_data: ctx_data.map(Arc::new), + streams: SubscriptionStreams { + streams: Default::default(), }, rx_bytes, - rx_msg, send_queue: VecDeque::new(), - resolve_queue: VecDeque::default(), - resolve_fut: None, }, ) } @@ -94,12 +79,9 @@ where pub struct SubscriptionStream { schema: Schema, transport: T, - stubs: SubscriptionStubs, + streams: SubscriptionStreams, rx_bytes: mpsc::Receiver, - rx_msg: mpsc::Receiver>, send_queue: VecDeque, - resolve_queue: VecDeque>, - resolve_fut: Option>>>, } impl Stream @@ -125,7 +107,7 @@ where let this = &mut *self; match this .transport - .handle_request(&this.schema, &mut this.stubs, data) + .handle_request(&this.schema, &mut this.streams, data) { Ok(Some(bytes)) => { this.send_queue.push_back(bytes); @@ -139,44 +121,38 @@ where Poll::Pending => {} } - if let Some(resolve_fut) = &mut self.resolve_fut { - match resolve_fut.poll_unpin(cx) { - Poll::Ready(_) => { - self.resolve_fut = None; - } - Poll::Pending => return Poll::Pending, - } - } else if let Some(msg) = self.resolve_queue.pop_front() { - // FIXME: I think this code is safe, but I don't know how to implement it in safe code. - let this = &mut *self; - let stubs = &this.stubs as *const SubscriptionStubs; - let transport = &mut this.transport as *mut T; - let send_queue = &mut this.send_queue as *mut VecDeque; - let fut = async move { - unsafe { - for (id, stub) in (*stubs).stubs.iter() { - if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() { - if let Some(bytes) = (*transport).handle_response(id, res) { - (*send_queue).push_back(bytes); + // receive msg + let this = &mut *self; + if !this.streams.streams.is_empty() { + loop { + let mut num_closed = 0; + let mut num_pending = 0; + + for (id, incoming_stream) in &mut this.streams.streams { + match incoming_stream.as_mut().poll_next(cx) { + Poll::Ready(Some(value)) => { + if let Some(bytes) = this.transport.handle_response(id, value) { + this.send_queue.push_back(bytes); } } + Poll::Ready(None) => { + num_closed += 1; + } + Poll::Pending => { + num_pending += 1; + } } } - }; - self.resolve_fut = Some(Box::pin(fut)); - continue; - } - // receive msg - match Pin::new(&mut self.rx_msg).poll_next(cx) { - Poll::Ready(Some(msg)) => { - self.resolve_queue.push_back(msg); - } - Poll::Ready(None) => return Poll::Ready(None), - Poll::Pending => { - // all pending - return Poll::Pending; + if num_closed == this.streams.streams.len() { + // all closed + return Poll::Ready(None); + } else if num_pending == this.streams.streams.len() { + return Poll::Pending; + } } + } else { + return Poll::Pending; } } } diff --git a/src/subscription/connection_builder.rs b/src/subscription/connection_builder.rs deleted file mode 100644 index 4e8a0fb1..00000000 --- a/src/subscription/connection_builder.rs +++ /dev/null @@ -1,43 +0,0 @@ -use crate::context::Data; -use crate::subscription::create_connection; -use crate::{ObjectType, Schema, SubscriptionStream, SubscriptionTransport, SubscriptionType}; -use bytes::Bytes; -use futures::channel::mpsc; -use std::any::Any; - -/// SubscriptionConnection builder -pub struct SubscriptionConnectionBuilder { - pub(crate) schema: Schema, - pub(crate) transport: T, - pub(crate) ctx_data: Option, -} - -impl - SubscriptionConnectionBuilder -where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, -{ - /// Add a context data that can be accessed in the `Context`, you access it with `Context::data`. - pub fn data(mut self, data: D) -> Self { - if let Some(ctx_data) = &mut self.ctx_data { - ctx_data.insert(data); - } else { - let mut ctx_data = Data::default(); - ctx_data.insert(data); - self.ctx_data = Some(ctx_data); - } - self - } - - /// Create subscription connection, returns `Sink` and `Stream`. - pub async fn build( - self, - ) -> ( - mpsc::Sender, - SubscriptionStream, - ) { - create_connection(self.schema, self.transport, self.ctx_data).await - } -} diff --git a/src/subscription/mod.rs b/src/subscription/mod.rs index c79b2b53..9bfaa027 100644 --- a/src/subscription/mod.rs +++ b/src/subscription/mod.rs @@ -1,13 +1,9 @@ mod connection; -mod connection_builder; -mod subscription_stub; mod subscription_type; mod ws_transport; pub use connection::{ - create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport, + create_connection, SubscriptionStream, SubscriptionStreams, SubscriptionTransport, }; -pub use connection_builder::SubscriptionConnectionBuilder; -pub use subscription_stub::SubscriptionStub; -pub use subscription_type::SubscriptionType; +pub use subscription_type::{create_subscription_stream, SubscriptionType}; pub use ws_transport::WebSocketTransport; diff --git a/src/subscription/subscription_stub.rs b/src/subscription/subscription_stub.rs deleted file mode 100644 index f0a0396e..00000000 --- a/src/subscription/subscription_stub.rs +++ /dev/null @@ -1,52 +0,0 @@ -use crate::context::Data; -use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables}; -use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition}; -use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; -use std::sync::Arc; - -/// Subscription stub -/// -/// When a new push message is generated, a JSON object that needs to be pushed can be obtained by -/// `Subscribe::resolve`, and if None is returned, the Subscribe is not subscribed to a message of this type. -pub struct SubscriptionStub { - pub(crate) schema: Schema, - pub(crate) types: HashMap, - pub(crate) variables: Variables, - pub(crate) variable_definitions: Vec, - pub(crate) fragments: HashMap, - pub(crate) ctx_data: Option>, -} - -impl SubscriptionStub -where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, -{ - #[doc(hidden)] - pub async fn resolve( - &self, - msg: &(dyn Any + Send + Sync), - ) -> Result> { - let resolve_id = AtomicUsize::default(); - let ctx = ContextBase::<()> { - path_node: None, - extensions: &[], - item: (), - resolve_id: &resolve_id, - variables: &self.variables, - variable_definitions: &self.variable_definitions, - registry: &self.schema.0.registry, - data: &self.schema.0.data, - ctx_data: self.ctx_data.as_deref(), - fragments: &self.fragments, - }; - self.schema - .0 - .subscription - .resolve(&ctx, &self.types, msg) - .await - } -} diff --git a/src/subscription/subscription_type.rs b/src/subscription/subscription_type.rs index 8168136c..aa17bd31 100644 --- a/src/subscription/subscription_type.rs +++ b/src/subscription/subscription_type.rs @@ -1,7 +1,9 @@ -use crate::{ContextBase, Result, Type}; -use graphql_parser::query::Field; -use std::any::{Any, TypeId}; -use std::collections::HashMap; +use crate::context::Environment; +use crate::{Context, ContextSelectionSet, ObjectType, Result, Schema, Type}; +use futures::Stream; +use graphql_parser::query::{Selection, TypeCondition}; +use std::pin::Pin; +use std::sync::Arc; /// Represents a GraphQL subscription object #[async_trait::async_trait] @@ -13,13 +15,79 @@ pub trait SubscriptionType: Type { } #[doc(hidden)] - fn create_type(field: &Field, types: &mut HashMap) -> Result<()>; - - /// Resolve a subscription message, If no message of this type is subscribed, None is returned. - async fn resolve( + fn create_field_stream( &self, - ctx: &ContextBase<'_, ()>, - types: &HashMap, - msg: &(dyn Any + Send + Sync), - ) -> Result>; + ctx: &Context<'_>, + schema: &Schema, + environment: Arc, + ) -> Result>>> + where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Self: Send + Sync + 'static + Sized; +} + +pub fn create_subscription_stream( + schema: &Schema, + environment: Arc, + ctx: &ContextSelectionSet<'_>, + streams: &mut Vec>>>, +) -> Result<()> +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static + Sized, +{ + for selection in &ctx.items { + match selection { + Selection::Field(field) => { + if ctx.is_skip(&field.directives)? { + continue; + } + streams.push(schema.0.subscription.create_field_stream( + &ctx.with_field(field), + schema, + environment.clone(), + )?) + } + Selection::FragmentSpread(fragment_spread) => { + if ctx.is_skip(&fragment_spread.directives)? { + continue; + } + + if let Some(fragment) = ctx.fragments.get(fragment_spread.fragment_name.as_str()) { + create_subscription_stream( + schema, + environment.clone(), + &ctx.with_selection_set(&fragment.selection_set), + streams, + )?; + } + } + Selection::InlineFragment(inline_fragment) => { + if ctx.is_skip(&inline_fragment.directives)? { + continue; + } + + if let Some(TypeCondition::On(name)) = &inline_fragment.type_condition { + if name.as_str() == Subscription::type_name() { + create_subscription_stream( + schema, + environment.clone(), + &ctx.with_selection_set(&inline_fragment.selection_set), + streams, + )?; + } + } else { + create_subscription_stream( + schema, + environment.clone(), + &ctx.with_selection_set(&inline_fragment.selection_set), + streams, + )?; + } + } + } + } + Ok(()) } diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index b1a16b50..549be3b8 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -1,6 +1,6 @@ use crate::http::{GQLError, GQLRequest, GQLResponse}; use crate::{ - ObjectType, QueryResponse, Result, Schema, SubscriptionStubs, SubscriptionTransport, + ObjectType, QueryResponse, Schema, SubscriptionStreams, SubscriptionTransport, SubscriptionType, Variables, }; use bytes::Bytes; @@ -27,7 +27,7 @@ impl SubscriptionTransport for WebSocketTransport { fn handle_request( &mut self, schema: &Schema, - stubs: &mut SubscriptionStubs, + streams: &mut SubscriptionStreams, data: Bytes, ) -> std::result::Result, Self::Error> where @@ -54,16 +54,15 @@ impl SubscriptionTransport for WebSocketTransport { .map(|value| Variables::parse_from_json(value).ok()) .flatten() .unwrap_or_default(); - - match schema.create_subscription_stub( + match schema.create_subscription_stream( &request.query, request.operation_name.as_deref(), variables, ) { - Ok(stub) => { - let stub_id = stubs.add(stub); - self.id_to_sid.insert(id.clone(), stub_id); - self.sid_to_id.insert(stub_id, id); + Ok(stream) => { + let stream_id = streams.add(stream); + self.id_to_sid.insert(id.clone(), stream_id); + self.sid_to_id.insert(stream_id, id); Ok(None) } Err(err) => Ok(Some( @@ -89,7 +88,7 @@ impl SubscriptionTransport for WebSocketTransport { if let Some(id) = msg.id { if let Some(id) = self.id_to_sid.remove(&id) { self.sid_to_id.remove(&id); - stubs.remove(id); + streams.remove(id); } } Ok(None) @@ -101,15 +100,15 @@ impl SubscriptionTransport for WebSocketTransport { } } - fn handle_response(&mut self, id: usize, result: Result) -> Option { + fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option { if let Some(id) = self.sid_to_id.get(&id) { Some( serde_json::to_vec(&OperationMessage { ty: "data".to_string(), id: Some(id.clone()), payload: Some( - serde_json::to_value(GQLResponse(result.map(|data| QueryResponse { - data, + serde_json::to_value(GQLResponse(Ok(QueryResponse { + data: value, extensions: None, }))) .unwrap(), diff --git a/src/types/empty_subscription.rs b/src/types/empty_subscription.rs index 77864c4f..f7f81261 100644 --- a/src/types/empty_subscription.rs +++ b/src/types/empty_subscription.rs @@ -1,14 +1,13 @@ +use crate::context::Environment; use crate::{ - registry, ContextBase, ContextSelectionSet, Error, OutputValueType, QueryError, Result, - SubscriptionType, Type, + registry, Context, ContextSelectionSet, Error, ObjectType, OutputValueType, QueryError, Result, + Schema, SubscriptionType, Type, }; -use graphql_parser::query::Field; +use futures::Stream; use graphql_parser::Pos; -use serde_json::Value; -use std::any::{Any, TypeId}; use std::borrow::Cow; -use std::collections::hash_map::RandomState; -use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; /// Empty subscription /// @@ -36,17 +35,22 @@ impl SubscriptionType for EmptySubscription { true } - fn create_type(_field: &Field, _types: &mut HashMap) -> Result<()> { - unreachable!() - } - - async fn resolve( + fn create_field_stream( &self, - _ctx: &ContextBase<'_, ()>, - _types: &HashMap, - _msg: &(dyn Any + Send + Sync), - ) -> Result> { - unreachable!() + _ctx: &Context<'_>, + _schema: &Schema, + _environment: Arc, + ) -> Result>>> + where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Self: Send + Sync + 'static + Sized, + { + Err(Error::Query { + pos: Pos::default(), + path: None, + err: QueryError::NotConfiguredSubscriptions, + }) } } diff --git a/src/validators/mod.rs b/src/validators/mod.rs index ed09bd71..c8250f7c 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -1,3 +1,5 @@ +//! Input value validators + mod int_validators; mod list_validators; mod string_validators;