Rework subscriptions

The main change in this commit is changing the return type of
SubscriptionType::create_stream from
Future<Result<Stream<Result<Response>>>> to just Stream<Result<Json>>. I
also allowed the returned stream to borrow from self and context.
This commit is contained in:
Koxiaet 2020-09-13 18:52:36 +01:00
parent 91ec3486ce
commit d404e756bc
12 changed files with 265 additions and 239 deletions

View File

@ -82,10 +82,18 @@ pub fn generate(object_args: &args::Object, input: &DeriveInput) -> Result<Token
} }
#[allow(clippy::all, clippy::pedantic)] #[allow(clippy::all, clippy::pedantic)]
#[#crate_name::async_trait::async_trait]
impl #crate_name::SubscriptionType for #ident { impl #crate_name::SubscriptionType for #ident {
async fn create_field_stream(&self, ctx: &#crate_name::Context<'_>, schema_env: #crate_name::SchemaEnv, query_env: #crate_name::QueryEnv) -> #crate_name::Result<::std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::Response> + Send>>> { fn create_field_stream<'a>(
#create_merged_obj.create_field_stream(ctx, schema_env, query_env).await &'a self,
ctx: &'a #crate_name::Context<'a>
) -> ::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures::Stream<Item = #crate_name::Result<#crate_name::serde_json::Value>> + Send + 'a>> {
::std::boxed::Box::pin(#crate_name::async_stream::stream! {
let obj = #create_merged_obj;
let mut stream = obj.create_field_stream(ctx);
while let Some(item) = #crate_name::futures::stream::StreamExt::next(&mut stream).await {
yield item;
}
})
} }
} }
}; };

View File

@ -233,8 +233,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}); });
let create_field_stream = quote! { let create_field_stream = quote! {
#crate_name::futures::stream::StreamExt::fuse(self.#ident(ctx, #(#use_params),*).await. self.#ident(ctx, #(#use_params),*)
map_err(|err| err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref()))?) .await
.map_err(|err| {
err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref())
})?
}; };
let guard = field.guard.map(|guard| quote! { let guard = field.guard.map(|guard| quote! {
@ -247,58 +250,65 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
)); ));
} }
create_stream.push(quote! { let stream_fn = quote! {
if ctx.item.node.name.node == #field_name { #(#get_params)*
#(#get_params)* #guard
#guard let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone());
let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone()); let field = ::std::sync::Arc::new(ctx.item.clone());
let field = ::std::sync::Arc::new(ctx.item.clone());
let pos = ctx.item.pos; let pos = ctx.item.pos;
let schema_env = schema_env.clone(); let schema_env = ctx.schema_env.clone();
let query_env = query_env.clone(); let query_env = ctx.query_env.clone();
let stream = #crate_name::futures::StreamExt::then(#create_field_stream, { let stream = #crate_name::futures::StreamExt::then(#create_field_stream, {
let field_name = field_name.clone();
move |msg| {
let schema_env = schema_env.clone();
let query_env = query_env.clone();
let field = field.clone();
let field_name = field_name.clone(); let field_name = field_name.clone();
move |msg| { async move {
let schema_env = schema_env.clone(); let resolve_id = ::std::sync::atomic::AtomicUsize::default();
let query_env = query_env.clone(); let ctx_selection_set = query_env.create_context(
let field = field.clone(); &schema_env,
let field_name = field_name.clone(); Some(#crate_name::QueryPathNode {
async move { parent: None,
let resolve_id = ::std::sync::atomic::AtomicUsize::default(); segment: #crate_name::QueryPathSegment::Name(&field_name),
let ctx_selection_set = query_env.create_context( }),
&schema_env, &field.node.selection_set,
Some(#crate_name::QueryPathNode { &resolve_id,
parent: None, );
segment: #crate_name::QueryPathSegment::Name(&field_name), #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field)
}), .await
&field.node.selection_set, .map(|value| {
&resolve_id, #crate_name::serde_json::json!({
); field_name.as_str(): value
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field).await })
} })
} }
}); }
let stream = #crate_name::futures::TryStreamExt::map_ok(stream, move |value| #crate_name::serde_json::json!({ field_name.as_str(): value })); });
let stream = #crate_name::futures::StreamExt::scan(stream, true, |state, item| { #crate_name::Result::Ok(#crate_name::futures::StreamExt::scan(
if !*state { stream,
false,
|errored, item| {
if *errored {
return #crate_name::futures::future::ready(None); return #crate_name::futures::future::ready(None);
} }
let resp = match item { if item.is_err() {
Ok(value) => #crate_name::Response { *errored = true;
data: value,
extensions: None,
cache_control: Default::default(),
error: None,
},
Err(err) => err.into(),
};
if resp.is_err() {
*state = false;
} }
#crate_name::futures::future::ready(Some(resp)) #crate_name::futures::future::ready(Some(item))
}); },
return Ok(Box::pin(stream)); ))
};
create_stream.push(quote! {
if ctx.item.node.name.node == #field_name {
return ::std::boxed::Box::pin(
#crate_name::futures::TryStreamExt::try_flatten(
#crate_name::futures::stream::once((move || async move { #stream_fn })())
)
);
} }
}); });
} }
@ -333,7 +343,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#(#schema_fields)* #(#schema_fields)*
fields fields
}, },
cache_control: Default::default(), cache_control: ::std::default::Default::default(),
extends: false, extends: false,
keys: None, keys: None,
}) })
@ -341,20 +351,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
} }
#[allow(clippy::all, clippy::pedantic)] #[allow(clippy::all, clippy::pedantic)]
#[#crate_name::async_trait::async_trait]
#[allow(unused_braces, unused_variables)] #[allow(unused_braces, unused_variables)]
impl #crate_name::SubscriptionType for #self_ty #where_clause { impl #crate_name::SubscriptionType for #self_ty #where_clause {
async fn create_field_stream( fn create_field_stream<'a>(
&self, &'a self,
ctx: &#crate_name::Context<'_>, ctx: &'a #crate_name::Context<'a>,
schema_env: #crate_name::SchemaEnv, ) -> ::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures::Stream<Item = #crate_name::Result<#crate_name::serde_json::Value>> + Send + 'a>> {
query_env: #crate_name::QueryEnv,
) -> #crate_name::Result<::std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::Response> + Send>>> {
#(#create_stream)* #(#create_stream)*
Err(#crate_name::QueryError::FieldNotFound { let error = #crate_name::QueryError::FieldNotFound {
field_name: ctx.item.node.name.to_string(), field_name: ctx.item.node.name.to_string(),
object: #gql_typename.to_string(), object: #gql_typename.to_string(),
}.into_error(ctx.item.pos)) }
.into_error(ctx.item.pos);
::std::boxed::Box::pin(#crate_name::futures::stream::once(async { Err(error) }))
} }
} }
}; };

View File

@ -232,8 +232,10 @@ pub struct ContextBase<'a, T> {
pub(crate) inc_resolve_id: &'a AtomicUsize, pub(crate) inc_resolve_id: &'a AtomicUsize,
#[doc(hidden)] #[doc(hidden)]
pub item: T, pub item: T,
pub(crate) schema_env: &'a SchemaEnv, #[doc(hidden)]
pub(crate) query_env: &'a QueryEnv, pub schema_env: &'a SchemaEnv,
#[doc(hidden)]
pub query_env: &'a QueryEnv,
} }
#[doc(hidden)] #[doc(hidden)]

View File

@ -140,9 +140,9 @@ impl<T: AsyncRead> Stream for ReaderStream<T> {
let this = self.project(); let this = self.project();
Poll::Ready( Poll::Ready(
match futures::ready!(this.reader.poll_read(cx, &mut self.buf)?) { match futures::ready!(this.reader.poll_read(cx, this.buf)?) {
0 => None, 0 => None,
size => Some(Ok(Bytes::copy_from_slice(&self.buf[..size]))), size => Some(Ok(Bytes::copy_from_slice(&this.buf[..size]))),
} }
) )
} }

View File

@ -123,6 +123,8 @@ pub use async_graphql_parser as parser;
#[doc(hidden)] #[doc(hidden)]
pub use async_trait; pub use async_trait;
#[doc(hidden)] #[doc(hidden)]
pub use async_stream;
#[doc(hidden)]
pub use futures; pub use futures;
#[doc(hidden)] #[doc(hidden)]
pub use indexmap; pub use indexmap;

View File

@ -12,7 +12,7 @@ use std::pin::Pin;
/// `OutputValueType::resolve` implementation. /// `OutputValueType::resolve` implementation.
#[async_trait::async_trait] #[async_trait::async_trait]
pub trait ObjectType: OutputValueType { pub trait ObjectType: OutputValueType {
/// This function returns true of type `EmptyMutation` only /// This function returns true of type `EmptyMutation` only.
#[doc(hidden)] #[doc(hidden)]
fn is_empty() -> bool { fn is_empty() -> bool {
false false
@ -151,6 +151,7 @@ impl<'a> Fields<'a> {
} }
self.0.push(Box::pin({ self.0.push(Box::pin({
// TODO: investigate removing this
let ctx = ctx.clone(); let ctx = ctx.clone();
async move { async move {
let ctx_field = ctx.with_field(field); let ctx_field = ctx.with_field(field);

View File

@ -1,7 +1,7 @@
use crate::{CacheControl, Error, Result}; use crate::{CacheControl, Error, Result};
/// Query response /// Query response
#[derive(Debug)] #[derive(Debug, Default)]
pub struct Response { pub struct Response {
/// Data of query result /// Data of query result
pub data: serde_json::Value, pub data: serde_json::Value,
@ -17,6 +17,51 @@ pub struct Response {
} }
impl Response { impl Response {
/// Create a new successful response with the data.
#[must_use]
pub fn new(data: impl Into<serde_json::Value>) -> Self {
Self {
data: data.into(),
..Default::default()
}
}
/// Create a response from the error.
#[must_use]
pub fn from_error(error: impl Into<Error>) -> Self {
Self {
error: Some(error.into()),
..Default::default()
}
}
/// Create a response from the result of the data and an error.
#[must_use]
pub fn from_result(result: Result<serde_json::Value>) -> Self {
match result {
Ok(data) => Self::new(data),
Err(e) => Self::from_error(e),
}
}
/// Set the extensions result of the response.
#[must_use]
pub fn extensions(self, extensions: Option<serde_json::Value>) -> Self {
Self {
extensions,
..self
}
}
/// Set the cache control of the response.
#[must_use]
pub fn cache_control(self, cache_control: CacheControl) -> Self {
Self {
cache_control,
..self
}
}
/// Returns `true` if the response is ok. /// Returns `true` if the response is ok.
#[inline] #[inline]
pub fn is_ok(&self) -> bool { pub fn is_ok(&self) -> bool {
@ -29,7 +74,8 @@ impl Response {
self.error.is_some() self.error.is_some()
} }
/// Convert response to `Result<Response>`. /// Extract the error from the response. Only if the `error` field is `None` will this return
/// `Ok`.
#[inline] #[inline]
pub fn into_result(self) -> Result<Self> { pub fn into_result(self) -> Result<Self> {
if self.is_err() { if self.is_err() {
@ -42,11 +88,6 @@ impl Response {
impl From<Error> for Response { impl From<Error> for Response {
fn from(err: Error) -> Self { fn from(err: Error) -> Self {
Self { Self::from_error(err)
data: serde_json::Value::Null,
extensions: None,
cache_control: CacheControl::default(),
error: Some(err),
}
} }
} }

View File

@ -5,7 +5,7 @@ use crate::parser::parse_query;
use crate::parser::types::OperationType; use crate::parser::types::OperationType;
use crate::registry::{MetaDirective, MetaInputValue, Registry}; use crate::registry::{MetaDirective, MetaInputValue, Registry};
use crate::resolver_utils::{resolve_object, resolve_object_serial, ObjectType}; use crate::resolver_utils::{resolve_object, resolve_object_serial, ObjectType};
use crate::subscription::create_subscription_stream; use crate::subscription::collect_subscription_streams;
use crate::types::QueryRoot; use crate::types::QueryRoot;
use crate::validation::{check_rules, CheckResult, ValidationMode}; use crate::validation::{check_rules, CheckResult, ValidationMode};
use crate::{ use crate::{
@ -13,7 +13,7 @@ use crate::{
SubscriptionType, Type, Variables, ID, SubscriptionType, Type, Variables, ID,
}; };
use async_graphql_parser::types::ExecutableDocumentData; use async_graphql_parser::types::ExecutableDocumentData;
use futures::{Stream, StreamExt}; use futures::stream::{self, Stream, StreamExt};
use indexmap::map::IndexMap; use indexmap::map::IndexMap;
use itertools::Itertools; use itertools::Itertools;
use std::any::Any; use std::any::Any;
@ -21,15 +21,6 @@ use std::ops::Deref;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
macro_rules! try_query_result {
($res:expr) => {
match $res {
Ok(resp) => resp,
Err(err) => return err.into(),
}
};
}
/// Schema builder /// Schema builder
pub struct SchemaBuilder<Query, Mutation, Subscription> { pub struct SchemaBuilder<Query, Mutation, Subscription> {
validation_mode: ValidationMode, validation_mode: ValidationMode,
@ -411,10 +402,11 @@ where
}; };
env.extensions.lock().execution_start(); env.extensions.lock().execution_start();
let data = match &env.document.operation.node.ty { let data = match &env.document.operation.node.ty {
OperationType::Query => try_query_result!(resolve_object(&ctx, &self.query).await), OperationType::Query => resolve_object(&ctx, &self.query).await,
OperationType::Mutation => { OperationType::Mutation => {
try_query_result!(resolve_object_serial(&ctx, &self.mutation).await) resolve_object_serial(&ctx, &self.mutation).await
} }
OperationType::Subscription => { OperationType::Subscription => {
return Error::Query { return Error::Query {
@ -428,24 +420,21 @@ where
env.extensions.lock().execution_end(); env.extensions.lock().execution_end();
let extensions = env.extensions.lock().result(); let extensions = env.extensions.lock().result();
Response {
data, Response::from_result(data)
extensions, .extensions(extensions)
cache_control: Default::default(),
error: None,
}
} }
/// Execute an GraphQL query. /// Execute an GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response { pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into(); let request = request.into();
let (document, cache_control, extensions) = match self.prepare_request(&request) {
try_query_result!(self.prepare_request(&request)); Ok((document, cache_control, extensions)) => self
let mut resp = self .execute_once(document, extensions, request.variables, request.data)
.execute_once(document, extensions, request.variables, request.data) .await
.await; .cache_control(cache_control),
resp.cache_control = cache_control; Err(e) => Response::from_error(e),
resp }
} }
pub(crate) fn execute_stream_with_ctx_data( pub(crate) fn execute_stream_with_ctx_data(
@ -454,9 +443,10 @@ where
ctx_data: Arc<Data>, ctx_data: Arc<Data>,
) -> impl Stream<Item = Response> { ) -> impl Stream<Item = Response> {
let schema = self.clone(); let schema = self.clone();
async_stream::stream! { async_stream::stream! {
let request = request.into(); let request = request.into();
let (document, cache_control, extensions) = match schema.prepare_request(& request) { let (document, cache_control, extensions) = match schema.prepare_request(&request) {
Ok(res) => res, Ok(res) => res,
Err(err) => { Err(err) => {
yield Response::from(err); yield Response::from(err);
@ -465,11 +455,10 @@ where
}; };
if document.operation.node.ty != OperationType::Subscription { if document.operation.node.ty != OperationType::Subscription {
let mut resp = schema yield schema
.execute_once(document, extensions, request.variables, request.data) .execute_once(document, extensions, request.variables, request.data)
.await; .await
resp.cache_control = cache_control; .cache_control(cache_control);
yield resp;
return; return;
} }
@ -488,17 +477,19 @@ where
&resolve_id, &resolve_id,
); );
let mut streams = Vec::new(); // TODO: Invoke extensions
if let Err(err) = create_subscription_stream(&schema, env.clone(), &ctx, &mut streams).await { let mut streams = Vec::new();
yield err.into(); if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
yield Response::from(e);
return; return;
} }
let mut stream = futures::stream::select_all(streams); let mut stream = stream::select_all(streams);
while let Some(resp) = stream.next().await { while let Some(data) = stream.next().await {
let is_err = resp.is_err(); let is_err = data.is_err();
yield resp; let extensions = env.extensions.lock().result();
yield Response::from_result(data).extensions(extensions);
if is_err { if is_err {
break; break;
} }

View File

@ -1,112 +1,89 @@
use crate::context::QueryEnv;
use crate::parser::types::{Selection, TypeCondition}; use crate::parser::types::{Selection, TypeCondition};
use crate::{Context, ContextSelectionSet, Response, Result, Schema, SchemaEnv, Type}; use crate::{Context, ContextSelectionSet, Result, Type};
use futures::{Future, Stream}; use futures::{Stream, StreamExt};
use std::pin::Pin; use std::pin::Pin;
/// Represents a GraphQL subscription object /// Represents a GraphQL subscription object
#[async_trait::async_trait]
pub trait SubscriptionType: Type { pub trait SubscriptionType: Type {
/// This function returns true of type `EmptySubscription` only /// This function returns true of type `EmptySubscription` only.
#[doc(hidden)] #[doc(hidden)]
fn is_empty() -> bool { fn is_empty() -> bool {
false false
} }
#[doc(hidden)] #[doc(hidden)]
async fn create_field_stream( fn create_field_stream<'a>(
&self, &'a self,
ctx: &Context<'_>, ctx: &'a Context<'a>,
schema_env: SchemaEnv, ) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>;
query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>>;
} }
type BoxCreateStreamFuture<'a> = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>; pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + Send + Sync + 'static>(
ctx: &ContextSelectionSet<'a>,
pub(crate) fn create_subscription_stream<'a, Query, Mutation, Subscription>( root: &'a T,
schema: &'a Schema<Query, Mutation, Subscription>, streams: &mut Vec<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>>,
environment: QueryEnv, ) -> Result<()> {
ctx: &'a ContextSelectionSet<'_>, for selection in &ctx.item.node.items {
streams: &'a mut Vec<Pin<Box<dyn Stream<Item = Response> + Send>>>, if ctx.is_skip(selection.node.directives())? {
) -> BoxCreateStreamFuture<'a> continue;
where }
Query: Send + Sync, match &selection.node {
Mutation: Send + Sync, Selection::Field(field) => streams.push(Box::pin({
Subscription: SubscriptionType + Send + Sync + 'static + Sized, let ctx = ctx.clone();
{ async_stream::stream! {
Box::pin(async move { let ctx = ctx.with_field(field);
for selection in &ctx.item.node.items { let mut stream = root.create_field_stream(&ctx);
if ctx.is_skip(selection.node.directives())? { while let Some(item) = stream.next().await {
continue; yield item;
}
match &selection.node {
Selection::Field(field) => streams.push(
schema
.subscription
.create_field_stream(
&ctx.with_field(field),
schema.env.clone(),
environment.clone(),
)
.await?,
),
Selection::FragmentSpread(fragment_spread) => {
if let Some(fragment) = ctx
.query_env
.document
.fragments
.get(&fragment_spread.node.fragment_name.node)
{
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&fragment.node.selection_set),
streams,
)
.await?;
} }
} }
Selection::InlineFragment(inline_fragment) => { })),
if let Some(TypeCondition { on: name }) = inline_fragment Selection::FragmentSpread(fragment_spread) => {
.node if let Some(fragment) = ctx
.type_condition .query_env
.as_ref() .document
.map(|v| &v.node) .fragments
{ .get(&fragment_spread.node.fragment_name.node)
if name.node.as_str() == Subscription::type_name() { {
create_subscription_stream( collect_subscription_streams(
schema, &ctx.with_selection_set(&fragment.node.selection_set),
environment.clone(), root,
&ctx.with_selection_set(&inline_fragment.node.selection_set), streams,
streams, )?;
) }
.await?; }
} Selection::InlineFragment(inline_fragment) => {
} else { if let Some(TypeCondition { on: name }) = inline_fragment
create_subscription_stream( .node
schema, .type_condition
environment.clone(), .as_ref()
.map(|v| &v.node)
{
if name.node.as_str() == T::type_name() {
collect_subscription_streams(
&ctx.with_selection_set(&inline_fragment.node.selection_set), &ctx.with_selection_set(&inline_fragment.node.selection_set),
root,
streams, streams,
) )?;
.await?;
} }
} else {
collect_subscription_streams(
&ctx.with_selection_set(&inline_fragment.node.selection_set),
root,
streams,
)?;
} }
} }
} }
Ok(()) }
}) Ok(())
} }
#[async_trait::async_trait]
impl<T: SubscriptionType + Send + Sync> SubscriptionType for &T { impl<T: SubscriptionType + Send + Sync> SubscriptionType for &T {
async fn create_field_stream( fn create_field_stream<'a>(
&self, &'a self,
ctx: &Context<'_>, ctx: &'a Context<'a>,
schema_env: SchemaEnv, ) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>> {
query_env: QueryEnv, T::create_field_stream(*self, ctx)
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>> {
T::create_field_stream(*self, ctx, schema_env, query_env).await
} }
} }

View File

@ -1,8 +1,7 @@
use crate::context::QueryEnv;
use crate::{ use crate::{
registry, Context, Error, Pos, QueryError, Response, Result, SchemaEnv, SubscriptionType, Type, registry, Context, Error, Pos, QueryError, Result, SubscriptionType, Type,
}; };
use futures::Stream; use futures::{stream, Stream};
use std::borrow::Cow; use std::borrow::Cow;
use std::pin::Pin; use std::pin::Pin;
@ -29,25 +28,24 @@ impl Type for EmptySubscription {
} }
} }
#[async_trait::async_trait]
impl SubscriptionType for EmptySubscription { impl SubscriptionType for EmptySubscription {
fn is_empty() -> bool { fn is_empty() -> bool {
true true
} }
async fn create_field_stream( fn create_field_stream<'a>(
&self, &'a self,
_ctx: &Context<'_>, _ctx: &'a Context<'a>,
_schema_env: SchemaEnv, ) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>
_query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>>
where where
Self: Send + Sync + 'static + Sized, Self: Send + Sync + 'static + Sized,
{ {
Err(Error::Query { Box::pin(stream::once(async {
pos: Pos::default(), Err(Error::Query {
path: None, pos: Pos::default(),
err: QueryError::NotConfiguredSubscriptions, path: None,
}) err: QueryError::NotConfiguredSubscriptions,
})
}))
} }
} }

View File

@ -3,10 +3,10 @@ use crate::registry::{MetaType, Registry};
use crate::resolver_utils::{resolve_object, ObjectType}; use crate::resolver_utils::{resolve_object, ObjectType};
use crate::{ use crate::{
CacheControl, Context, ContextSelectionSet, Error, GQLSimpleObject, GQLSubscription, CacheControl, Context, ContextSelectionSet, Error, GQLSimpleObject, GQLSubscription,
OutputValueType, Positioned, QueryEnv, QueryError, Response, Result, SchemaEnv, OutputValueType, Positioned, QueryError, Result,
SubscriptionType, Type, SubscriptionType, Type,
}; };
use futures::Stream; use futures::{stream, Stream, StreamExt, future::Either};
use indexmap::IndexMap; use indexmap::IndexMap;
use std::borrow::Cow; use std::borrow::Cow;
use std::pin::Pin; use std::pin::Pin;
@ -101,34 +101,31 @@ where
} }
} }
#[async_trait::async_trait]
impl<A, B> SubscriptionType for MergedObject<A, B> impl<A, B> SubscriptionType for MergedObject<A, B>
where where
A: SubscriptionType + Send + Sync, A: SubscriptionType + Send + Sync,
B: SubscriptionType + Send + Sync, B: SubscriptionType + Send + Sync,
{ {
async fn create_field_stream( fn create_field_stream<'a>(
&self, &'a self,
ctx: &Context<'_>, ctx: &'a Context<'a>,
schema_env: SchemaEnv, ) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>> {
query_env: QueryEnv, let left_stream = self.0.create_field_stream(ctx);
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>> { let mut right_stream = Some(self.1.create_field_stream(ctx));
match self Box::pin(
.0 left_stream
.create_field_stream(ctx, schema_env.clone(), query_env.clone()) .flat_map(move |res| {
.await match res {
{ Err(Error::Query {
Ok(value) => Ok(value), err: QueryError::FieldNotFound { .. },
Err(Error::Query { ..
err: QueryError::FieldNotFound { .. }, }) if right_stream.is_some() => {
.. Either::Right(right_stream.take().unwrap())
}) => { }
self.1 other => Either::Left(stream::once(async { other })),
.create_field_stream(ctx, schema_env, query_env) }
.await })
} )
Err(err) => Err(err),
}
} }
} }

View File

@ -126,7 +126,7 @@ pub async fn test_field_features() {
let mut stream = schema.execute_stream("subscription { values }").boxed(); let mut stream = schema.execute_stream("subscription { values }").boxed();
assert_eq!( assert_eq!(
stream.next().await.map(|resp| resp.data), stream.next().await.map(|resp| resp.into_result().unwrap().data),
Some(serde_json::json!({ Some(serde_json::json!({
"values": 10 "values": 10
})) }))