diff --git a/derive/src/merged_subscription.rs b/derive/src/merged_subscription.rs index 4728cfce..de93259c 100644 --- a/derive/src/merged_subscription.rs +++ b/derive/src/merged_subscription.rs @@ -82,10 +82,18 @@ pub fn generate(object_args: &args::Object, input: &DeriveInput) -> Result, schema_env: #crate_name::SchemaEnv, query_env: #crate_name::QueryEnv) -> #crate_name::Result<::std::pin::Pin + Send>>> { - #create_merged_obj.create_field_stream(ctx, schema_env, query_env).await + fn create_field_stream<'a>( + &'a self, + ctx: &'a #crate_name::Context<'a> + ) -> ::std::pin::Pin<::std::boxed::Box> + Send + 'a>> { + ::std::boxed::Box::pin(#crate_name::async_stream::stream! { + let obj = #create_merged_obj; + let mut stream = obj.create_field_stream(ctx); + while let Some(item) = #crate_name::futures::stream::StreamExt::next(&mut stream).await { + yield item; + } + }) } } }; diff --git a/derive/src/subscription.rs b/derive/src/subscription.rs index e91b321e..baf2002e 100644 --- a/derive/src/subscription.rs +++ b/derive/src/subscription.rs @@ -233,8 +233,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); let create_field_stream = quote! { - #crate_name::futures::stream::StreamExt::fuse(self.#ident(ctx, #(#use_params),*).await. - map_err(|err| err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref()))?) + self.#ident(ctx, #(#use_params),*) + .await + .map_err(|err| { + err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref()) + })? }; let guard = field.guard.map(|guard| quote! { @@ -247,58 +250,65 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< )); } - create_stream.push(quote! { - if ctx.item.node.name.node == #field_name { - #(#get_params)* - #guard - let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone()); - let field = ::std::sync::Arc::new(ctx.item.clone()); + let stream_fn = quote! { + #(#get_params)* + #guard + let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone()); + let field = ::std::sync::Arc::new(ctx.item.clone()); - let pos = ctx.item.pos; - let schema_env = schema_env.clone(); - let query_env = query_env.clone(); - let stream = #crate_name::futures::StreamExt::then(#create_field_stream, { + let pos = ctx.item.pos; + let schema_env = ctx.schema_env.clone(); + let query_env = ctx.query_env.clone(); + let stream = #crate_name::futures::StreamExt::then(#create_field_stream, { + let field_name = field_name.clone(); + move |msg| { + let schema_env = schema_env.clone(); + let query_env = query_env.clone(); + let field = field.clone(); let field_name = field_name.clone(); - move |msg| { - let schema_env = schema_env.clone(); - let query_env = query_env.clone(); - let field = field.clone(); - let field_name = field_name.clone(); - async move { - let resolve_id = ::std::sync::atomic::AtomicUsize::default(); - let ctx_selection_set = query_env.create_context( - &schema_env, - Some(#crate_name::QueryPathNode { - parent: None, - segment: #crate_name::QueryPathSegment::Name(&field_name), - }), - &field.node.selection_set, - &resolve_id, - ); - #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field).await - } + async move { + let resolve_id = ::std::sync::atomic::AtomicUsize::default(); + let ctx_selection_set = query_env.create_context( + &schema_env, + Some(#crate_name::QueryPathNode { + parent: None, + segment: #crate_name::QueryPathSegment::Name(&field_name), + }), + &field.node.selection_set, + &resolve_id, + ); + #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field) + .await + .map(|value| { + #crate_name::serde_json::json!({ + field_name.as_str(): value + }) + }) } - }); - let stream = #crate_name::futures::TryStreamExt::map_ok(stream, move |value| #crate_name::serde_json::json!({ field_name.as_str(): value })); - let stream = #crate_name::futures::StreamExt::scan(stream, true, |state, item| { - if !*state { + } + }); + #crate_name::Result::Ok(#crate_name::futures::StreamExt::scan( + stream, + false, + |errored, item| { + if *errored { return #crate_name::futures::future::ready(None); } - let resp = match item { - Ok(value) => #crate_name::Response { - data: value, - extensions: None, - cache_control: Default::default(), - error: None, - }, - Err(err) => err.into(), - }; - if resp.is_err() { - *state = false; + if item.is_err() { + *errored = true; } - #crate_name::futures::future::ready(Some(resp)) - }); - return Ok(Box::pin(stream)); + #crate_name::futures::future::ready(Some(item)) + }, + )) + }; + + create_stream.push(quote! { + if ctx.item.node.name.node == #field_name { + return ::std::boxed::Box::pin( + #crate_name::futures::TryStreamExt::try_flatten( + #crate_name::futures::stream::once((move || async move { #stream_fn })()) + ) + ); } }); } @@ -333,7 +343,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#schema_fields)* fields }, - cache_control: Default::default(), + cache_control: ::std::default::Default::default(), extends: false, keys: None, }) @@ -341,20 +351,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } #[allow(clippy::all, clippy::pedantic)] - #[#crate_name::async_trait::async_trait] #[allow(unused_braces, unused_variables)] impl #crate_name::SubscriptionType for #self_ty #where_clause { - async fn create_field_stream( - &self, - ctx: &#crate_name::Context<'_>, - schema_env: #crate_name::SchemaEnv, - query_env: #crate_name::QueryEnv, - ) -> #crate_name::Result<::std::pin::Pin + Send>>> { + fn create_field_stream<'a>( + &'a self, + ctx: &'a #crate_name::Context<'a>, + ) -> ::std::pin::Pin<::std::boxed::Box> + Send + 'a>> { #(#create_stream)* - Err(#crate_name::QueryError::FieldNotFound { + let error = #crate_name::QueryError::FieldNotFound { field_name: ctx.item.node.name.to_string(), object: #gql_typename.to_string(), - }.into_error(ctx.item.pos)) + } + .into_error(ctx.item.pos); + ::std::boxed::Box::pin(#crate_name::futures::stream::once(async { Err(error) })) } } }; diff --git a/src/context.rs b/src/context.rs index 4e6b31a8..7c041879 100644 --- a/src/context.rs +++ b/src/context.rs @@ -232,8 +232,10 @@ pub struct ContextBase<'a, T> { pub(crate) inc_resolve_id: &'a AtomicUsize, #[doc(hidden)] pub item: T, - pub(crate) schema_env: &'a SchemaEnv, - pub(crate) query_env: &'a QueryEnv, + #[doc(hidden)] + pub schema_env: &'a SchemaEnv, + #[doc(hidden)] + pub query_env: &'a QueryEnv, } #[doc(hidden)] diff --git a/src/http/multipart.rs b/src/http/multipart.rs index 3c0b8551..ddcc43ff 100644 --- a/src/http/multipart.rs +++ b/src/http/multipart.rs @@ -140,9 +140,9 @@ impl Stream for ReaderStream { let this = self.project(); Poll::Ready( - match futures::ready!(this.reader.poll_read(cx, &mut self.buf)?) { + match futures::ready!(this.reader.poll_read(cx, this.buf)?) { 0 => None, - size => Some(Ok(Bytes::copy_from_slice(&self.buf[..size]))), + size => Some(Ok(Bytes::copy_from_slice(&this.buf[..size]))), } ) } diff --git a/src/lib.rs b/src/lib.rs index df6e25a7..1752d2ef 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -123,6 +123,8 @@ pub use async_graphql_parser as parser; #[doc(hidden)] pub use async_trait; #[doc(hidden)] +pub use async_stream; +#[doc(hidden)] pub use futures; #[doc(hidden)] pub use indexmap; diff --git a/src/resolver_utils/object.rs b/src/resolver_utils/object.rs index dbfc0e90..8980c90b 100644 --- a/src/resolver_utils/object.rs +++ b/src/resolver_utils/object.rs @@ -12,7 +12,7 @@ use std::pin::Pin; /// `OutputValueType::resolve` implementation. #[async_trait::async_trait] pub trait ObjectType: OutputValueType { - /// This function returns true of type `EmptyMutation` only + /// This function returns true of type `EmptyMutation` only. #[doc(hidden)] fn is_empty() -> bool { false @@ -151,6 +151,7 @@ impl<'a> Fields<'a> { } self.0.push(Box::pin({ + // TODO: investigate removing this let ctx = ctx.clone(); async move { let ctx_field = ctx.with_field(field); diff --git a/src/response.rs b/src/response.rs index bf47e094..99c96886 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,7 +1,7 @@ use crate::{CacheControl, Error, Result}; /// Query response -#[derive(Debug)] +#[derive(Debug, Default)] pub struct Response { /// Data of query result pub data: serde_json::Value, @@ -17,6 +17,51 @@ pub struct Response { } impl Response { + /// Create a new successful response with the data. + #[must_use] + pub fn new(data: impl Into) -> Self { + Self { + data: data.into(), + ..Default::default() + } + } + + /// Create a response from the error. + #[must_use] + pub fn from_error(error: impl Into) -> Self { + Self { + error: Some(error.into()), + ..Default::default() + } + } + + /// Create a response from the result of the data and an error. + #[must_use] + pub fn from_result(result: Result) -> Self { + match result { + Ok(data) => Self::new(data), + Err(e) => Self::from_error(e), + } + } + + /// Set the extensions result of the response. + #[must_use] + pub fn extensions(self, extensions: Option) -> Self { + Self { + extensions, + ..self + } + } + + /// Set the cache control of the response. + #[must_use] + pub fn cache_control(self, cache_control: CacheControl) -> Self { + Self { + cache_control, + ..self + } + } + /// Returns `true` if the response is ok. #[inline] pub fn is_ok(&self) -> bool { @@ -29,7 +74,8 @@ impl Response { self.error.is_some() } - /// Convert response to `Result`. + /// Extract the error from the response. Only if the `error` field is `None` will this return + /// `Ok`. #[inline] pub fn into_result(self) -> Result { if self.is_err() { @@ -42,11 +88,6 @@ impl Response { impl From for Response { fn from(err: Error) -> Self { - Self { - data: serde_json::Value::Null, - extensions: None, - cache_control: CacheControl::default(), - error: Some(err), - } + Self::from_error(err) } } diff --git a/src/schema.rs b/src/schema.rs index 38bcec75..73035a99 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -5,7 +5,7 @@ use crate::parser::parse_query; use crate::parser::types::OperationType; use crate::registry::{MetaDirective, MetaInputValue, Registry}; use crate::resolver_utils::{resolve_object, resolve_object_serial, ObjectType}; -use crate::subscription::create_subscription_stream; +use crate::subscription::collect_subscription_streams; use crate::types::QueryRoot; use crate::validation::{check_rules, CheckResult, ValidationMode}; use crate::{ @@ -13,7 +13,7 @@ use crate::{ SubscriptionType, Type, Variables, ID, }; use async_graphql_parser::types::ExecutableDocumentData; -use futures::{Stream, StreamExt}; +use futures::stream::{self, Stream, StreamExt}; use indexmap::map::IndexMap; use itertools::Itertools; use std::any::Any; @@ -21,15 +21,6 @@ use std::ops::Deref; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -macro_rules! try_query_result { - ($res:expr) => { - match $res { - Ok(resp) => resp, - Err(err) => return err.into(), - } - }; -} - /// Schema builder pub struct SchemaBuilder { validation_mode: ValidationMode, @@ -411,10 +402,11 @@ where }; env.extensions.lock().execution_start(); + let data = match &env.document.operation.node.ty { - OperationType::Query => try_query_result!(resolve_object(&ctx, &self.query).await), + OperationType::Query => resolve_object(&ctx, &self.query).await, OperationType::Mutation => { - try_query_result!(resolve_object_serial(&ctx, &self.mutation).await) + resolve_object_serial(&ctx, &self.mutation).await } OperationType::Subscription => { return Error::Query { @@ -428,24 +420,21 @@ where env.extensions.lock().execution_end(); let extensions = env.extensions.lock().result(); - Response { - data, - extensions, - cache_control: Default::default(), - error: None, - } + + Response::from_result(data) + .extensions(extensions) } /// Execute an GraphQL query. pub async fn execute(&self, request: impl Into) -> Response { let request = request.into(); - let (document, cache_control, extensions) = - try_query_result!(self.prepare_request(&request)); - let mut resp = self - .execute_once(document, extensions, request.variables, request.data) - .await; - resp.cache_control = cache_control; - resp + match self.prepare_request(&request) { + Ok((document, cache_control, extensions)) => self + .execute_once(document, extensions, request.variables, request.data) + .await + .cache_control(cache_control), + Err(e) => Response::from_error(e), + } } pub(crate) fn execute_stream_with_ctx_data( @@ -454,9 +443,10 @@ where ctx_data: Arc, ) -> impl Stream { let schema = self.clone(); + async_stream::stream! { let request = request.into(); - let (document, cache_control, extensions) = match schema.prepare_request(& request) { + let (document, cache_control, extensions) = match schema.prepare_request(&request) { Ok(res) => res, Err(err) => { yield Response::from(err); @@ -465,11 +455,10 @@ where }; if document.operation.node.ty != OperationType::Subscription { - let mut resp = schema + yield schema .execute_once(document, extensions, request.variables, request.data) - .await; - resp.cache_control = cache_control; - yield resp; + .await + .cache_control(cache_control); return; } @@ -488,17 +477,19 @@ where &resolve_id, ); - let mut streams = Vec::new(); + // TODO: Invoke extensions - if let Err(err) = create_subscription_stream(&schema, env.clone(), &ctx, &mut streams).await { - yield err.into(); + let mut streams = Vec::new(); + if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { + yield Response::from(e); return; } - let mut stream = futures::stream::select_all(streams); - while let Some(resp) = stream.next().await { - let is_err = resp.is_err(); - yield resp; + let mut stream = stream::select_all(streams); + while let Some(data) = stream.next().await { + let is_err = data.is_err(); + let extensions = env.extensions.lock().result(); + yield Response::from_result(data).extensions(extensions); if is_err { break; } diff --git a/src/subscription.rs b/src/subscription.rs index 2a2eb821..11aaa782 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -1,112 +1,89 @@ -use crate::context::QueryEnv; use crate::parser::types::{Selection, TypeCondition}; -use crate::{Context, ContextSelectionSet, Response, Result, Schema, SchemaEnv, Type}; -use futures::{Future, Stream}; +use crate::{Context, ContextSelectionSet, Result, Type}; +use futures::{Stream, StreamExt}; use std::pin::Pin; /// Represents a GraphQL subscription object -#[async_trait::async_trait] pub trait SubscriptionType: Type { - /// This function returns true of type `EmptySubscription` only + /// This function returns true of type `EmptySubscription` only. #[doc(hidden)] fn is_empty() -> bool { false } #[doc(hidden)] - async fn create_field_stream( - &self, - ctx: &Context<'_>, - schema_env: SchemaEnv, - query_env: QueryEnv, - ) -> Result + Send>>>; + fn create_field_stream<'a>( + &'a self, + ctx: &'a Context<'a>, + ) -> Pin> + Send + 'a>>; } -type BoxCreateStreamFuture<'a> = Pin> + Send + 'a>>; - -pub(crate) fn create_subscription_stream<'a, Query, Mutation, Subscription>( - schema: &'a Schema, - environment: QueryEnv, - ctx: &'a ContextSelectionSet<'_>, - streams: &'a mut Vec + Send>>>, -) -> BoxCreateStreamFuture<'a> -where - Query: Send + Sync, - Mutation: Send + Sync, - Subscription: SubscriptionType + Send + Sync + 'static + Sized, -{ - Box::pin(async move { - for selection in &ctx.item.node.items { - if ctx.is_skip(selection.node.directives())? { - continue; - } - match &selection.node { - Selection::Field(field) => streams.push( - schema - .subscription - .create_field_stream( - &ctx.with_field(field), - schema.env.clone(), - environment.clone(), - ) - .await?, - ), - Selection::FragmentSpread(fragment_spread) => { - if let Some(fragment) = ctx - .query_env - .document - .fragments - .get(&fragment_spread.node.fragment_name.node) - { - create_subscription_stream( - schema, - environment.clone(), - &ctx.with_selection_set(&fragment.node.selection_set), - streams, - ) - .await?; +pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + Send + Sync + 'static>( + ctx: &ContextSelectionSet<'a>, + root: &'a T, + streams: &mut Vec> + Send + 'a>>>, +) -> Result<()> { + for selection in &ctx.item.node.items { + if ctx.is_skip(selection.node.directives())? { + continue; + } + match &selection.node { + Selection::Field(field) => streams.push(Box::pin({ + let ctx = ctx.clone(); + async_stream::stream! { + let ctx = ctx.with_field(field); + let mut stream = root.create_field_stream(&ctx); + while let Some(item) = stream.next().await { + yield item; } } - Selection::InlineFragment(inline_fragment) => { - if let Some(TypeCondition { on: name }) = inline_fragment - .node - .type_condition - .as_ref() - .map(|v| &v.node) - { - if name.node.as_str() == Subscription::type_name() { - create_subscription_stream( - schema, - environment.clone(), - &ctx.with_selection_set(&inline_fragment.node.selection_set), - streams, - ) - .await?; - } - } else { - create_subscription_stream( - schema, - environment.clone(), + })), + Selection::FragmentSpread(fragment_spread) => { + if let Some(fragment) = ctx + .query_env + .document + .fragments + .get(&fragment_spread.node.fragment_name.node) + { + collect_subscription_streams( + &ctx.with_selection_set(&fragment.node.selection_set), + root, + streams, + )?; + } + } + Selection::InlineFragment(inline_fragment) => { + if let Some(TypeCondition { on: name }) = inline_fragment + .node + .type_condition + .as_ref() + .map(|v| &v.node) + { + if name.node.as_str() == T::type_name() { + collect_subscription_streams( &ctx.with_selection_set(&inline_fragment.node.selection_set), + root, streams, - ) - .await?; + )?; } + } else { + collect_subscription_streams( + &ctx.with_selection_set(&inline_fragment.node.selection_set), + root, + streams, + )?; } } } - Ok(()) - }) + } + Ok(()) } -#[async_trait::async_trait] impl SubscriptionType for &T { - async fn create_field_stream( - &self, - ctx: &Context<'_>, - schema_env: SchemaEnv, - query_env: QueryEnv, - ) -> Result + Send>>> { - T::create_field_stream(*self, ctx, schema_env, query_env).await + fn create_field_stream<'a>( + &'a self, + ctx: &'a Context<'a>, + ) -> Pin> + Send + 'a>> { + T::create_field_stream(*self, ctx) } } diff --git a/src/types/empty_subscription.rs b/src/types/empty_subscription.rs index 53be40a2..961908da 100644 --- a/src/types/empty_subscription.rs +++ b/src/types/empty_subscription.rs @@ -1,8 +1,7 @@ -use crate::context::QueryEnv; use crate::{ - registry, Context, Error, Pos, QueryError, Response, Result, SchemaEnv, SubscriptionType, Type, + registry, Context, Error, Pos, QueryError, Result, SubscriptionType, Type, }; -use futures::Stream; +use futures::{stream, Stream}; use std::borrow::Cow; use std::pin::Pin; @@ -29,25 +28,24 @@ impl Type for EmptySubscription { } } -#[async_trait::async_trait] impl SubscriptionType for EmptySubscription { fn is_empty() -> bool { true } - async fn create_field_stream( - &self, - _ctx: &Context<'_>, - _schema_env: SchemaEnv, - _query_env: QueryEnv, - ) -> Result + Send>>> + fn create_field_stream<'a>( + &'a self, + _ctx: &'a Context<'a>, + ) -> Pin> + Send + 'a>> where Self: Send + Sync + 'static + Sized, { - Err(Error::Query { - pos: Pos::default(), - path: None, - err: QueryError::NotConfiguredSubscriptions, - }) + Box::pin(stream::once(async { + Err(Error::Query { + pos: Pos::default(), + path: None, + err: QueryError::NotConfiguredSubscriptions, + }) + })) } } diff --git a/src/types/merged_object.rs b/src/types/merged_object.rs index b4799e1c..b8f76d32 100644 --- a/src/types/merged_object.rs +++ b/src/types/merged_object.rs @@ -3,10 +3,10 @@ use crate::registry::{MetaType, Registry}; use crate::resolver_utils::{resolve_object, ObjectType}; use crate::{ CacheControl, Context, ContextSelectionSet, Error, GQLSimpleObject, GQLSubscription, - OutputValueType, Positioned, QueryEnv, QueryError, Response, Result, SchemaEnv, + OutputValueType, Positioned, QueryError, Result, SubscriptionType, Type, }; -use futures::Stream; +use futures::{stream, Stream, StreamExt, future::Either}; use indexmap::IndexMap; use std::borrow::Cow; use std::pin::Pin; @@ -101,34 +101,31 @@ where } } -#[async_trait::async_trait] impl SubscriptionType for MergedObject where A: SubscriptionType + Send + Sync, B: SubscriptionType + Send + Sync, { - async fn create_field_stream( - &self, - ctx: &Context<'_>, - schema_env: SchemaEnv, - query_env: QueryEnv, - ) -> Result + Send>>> { - match self - .0 - .create_field_stream(ctx, schema_env.clone(), query_env.clone()) - .await - { - Ok(value) => Ok(value), - Err(Error::Query { - err: QueryError::FieldNotFound { .. }, - .. - }) => { - self.1 - .create_field_stream(ctx, schema_env, query_env) - .await - } - Err(err) => Err(err), - } + fn create_field_stream<'a>( + &'a self, + ctx: &'a Context<'a>, + ) -> Pin> + Send + 'a>> { + let left_stream = self.0.create_field_stream(ctx); + let mut right_stream = Some(self.1.create_field_stream(ctx)); + Box::pin( + left_stream + .flat_map(move |res| { + match res { + Err(Error::Query { + err: QueryError::FieldNotFound { .. }, + .. + }) if right_stream.is_some() => { + Either::Right(right_stream.take().unwrap()) + } + other => Either::Left(stream::once(async { other })), + } + }) + ) } } diff --git a/tests/field_features.rs b/tests/field_features.rs index 89eaf48c..7be0aaeb 100644 --- a/tests/field_features.rs +++ b/tests/field_features.rs @@ -126,7 +126,7 @@ pub async fn test_field_features() { let mut stream = schema.execute_stream("subscription { values }").boxed(); assert_eq!( - stream.next().await.map(|resp| resp.data), + stream.next().await.map(|resp| resp.into_result().unwrap().data), Some(serde_json::json!({ "values": 10 }))