Rework subscriptions

The main change in this commit is changing the return type of
SubscriptionType::create_stream from
Future<Result<Stream<Result<Response>>>> to just Stream<Result<Json>>. I
also allowed the returned stream to borrow from self and context.
This commit is contained in:
Koxiaet 2020-09-13 18:52:36 +01:00
parent 91ec3486ce
commit d404e756bc
12 changed files with 265 additions and 239 deletions

View File

@ -82,10 +82,18 @@ pub fn generate(object_args: &args::Object, input: &DeriveInput) -> Result<Token
}
#[allow(clippy::all, clippy::pedantic)]
#[#crate_name::async_trait::async_trait]
impl #crate_name::SubscriptionType for #ident {
async fn create_field_stream(&self, ctx: &#crate_name::Context<'_>, schema_env: #crate_name::SchemaEnv, query_env: #crate_name::QueryEnv) -> #crate_name::Result<::std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::Response> + Send>>> {
#create_merged_obj.create_field_stream(ctx, schema_env, query_env).await
fn create_field_stream<'a>(
&'a self,
ctx: &'a #crate_name::Context<'a>
) -> ::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures::Stream<Item = #crate_name::Result<#crate_name::serde_json::Value>> + Send + 'a>> {
::std::boxed::Box::pin(#crate_name::async_stream::stream! {
let obj = #create_merged_obj;
let mut stream = obj.create_field_stream(ctx);
while let Some(item) = #crate_name::futures::stream::StreamExt::next(&mut stream).await {
yield item;
}
})
}
}
};

View File

@ -233,8 +233,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
});
let create_field_stream = quote! {
#crate_name::futures::stream::StreamExt::fuse(self.#ident(ctx, #(#use_params),*).await.
map_err(|err| err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref()))?)
self.#ident(ctx, #(#use_params),*)
.await
.map_err(|err| {
err.into_error_with_path(ctx.item.pos, ctx.path_node.as_ref())
})?
};
let guard = field.guard.map(|guard| quote! {
@ -247,58 +250,65 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
));
}
create_stream.push(quote! {
if ctx.item.node.name.node == #field_name {
#(#get_params)*
#guard
let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone());
let field = ::std::sync::Arc::new(ctx.item.clone());
let stream_fn = quote! {
#(#get_params)*
#guard
let field_name = ::std::sync::Arc::new(ctx.item.node.response_key().node.clone());
let field = ::std::sync::Arc::new(ctx.item.clone());
let pos = ctx.item.pos;
let schema_env = schema_env.clone();
let query_env = query_env.clone();
let stream = #crate_name::futures::StreamExt::then(#create_field_stream, {
let pos = ctx.item.pos;
let schema_env = ctx.schema_env.clone();
let query_env = ctx.query_env.clone();
let stream = #crate_name::futures::StreamExt::then(#create_field_stream, {
let field_name = field_name.clone();
move |msg| {
let schema_env = schema_env.clone();
let query_env = query_env.clone();
let field = field.clone();
let field_name = field_name.clone();
move |msg| {
let schema_env = schema_env.clone();
let query_env = query_env.clone();
let field = field.clone();
let field_name = field_name.clone();
async move {
let resolve_id = ::std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = query_env.create_context(
&schema_env,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name(&field_name),
}),
&field.node.selection_set,
&resolve_id,
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field).await
}
async move {
let resolve_id = ::std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = query_env.create_context(
&schema_env,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name(&field_name),
}),
&field.node.selection_set,
&resolve_id,
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, &*field)
.await
.map(|value| {
#crate_name::serde_json::json!({
field_name.as_str(): value
})
})
}
});
let stream = #crate_name::futures::TryStreamExt::map_ok(stream, move |value| #crate_name::serde_json::json!({ field_name.as_str(): value }));
let stream = #crate_name::futures::StreamExt::scan(stream, true, |state, item| {
if !*state {
}
});
#crate_name::Result::Ok(#crate_name::futures::StreamExt::scan(
stream,
false,
|errored, item| {
if *errored {
return #crate_name::futures::future::ready(None);
}
let resp = match item {
Ok(value) => #crate_name::Response {
data: value,
extensions: None,
cache_control: Default::default(),
error: None,
},
Err(err) => err.into(),
};
if resp.is_err() {
*state = false;
if item.is_err() {
*errored = true;
}
#crate_name::futures::future::ready(Some(resp))
});
return Ok(Box::pin(stream));
#crate_name::futures::future::ready(Some(item))
},
))
};
create_stream.push(quote! {
if ctx.item.node.name.node == #field_name {
return ::std::boxed::Box::pin(
#crate_name::futures::TryStreamExt::try_flatten(
#crate_name::futures::stream::once((move || async move { #stream_fn })())
)
);
}
});
}
@ -333,7 +343,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#(#schema_fields)*
fields
},
cache_control: Default::default(),
cache_control: ::std::default::Default::default(),
extends: false,
keys: None,
})
@ -341,20 +351,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}
#[allow(clippy::all, clippy::pedantic)]
#[#crate_name::async_trait::async_trait]
#[allow(unused_braces, unused_variables)]
impl #crate_name::SubscriptionType for #self_ty #where_clause {
async fn create_field_stream(
&self,
ctx: &#crate_name::Context<'_>,
schema_env: #crate_name::SchemaEnv,
query_env: #crate_name::QueryEnv,
) -> #crate_name::Result<::std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::Response> + Send>>> {
fn create_field_stream<'a>(
&'a self,
ctx: &'a #crate_name::Context<'a>,
) -> ::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures::Stream<Item = #crate_name::Result<#crate_name::serde_json::Value>> + Send + 'a>> {
#(#create_stream)*
Err(#crate_name::QueryError::FieldNotFound {
let error = #crate_name::QueryError::FieldNotFound {
field_name: ctx.item.node.name.to_string(),
object: #gql_typename.to_string(),
}.into_error(ctx.item.pos))
}
.into_error(ctx.item.pos);
::std::boxed::Box::pin(#crate_name::futures::stream::once(async { Err(error) }))
}
}
};

View File

@ -232,8 +232,10 @@ pub struct ContextBase<'a, T> {
pub(crate) inc_resolve_id: &'a AtomicUsize,
#[doc(hidden)]
pub item: T,
pub(crate) schema_env: &'a SchemaEnv,
pub(crate) query_env: &'a QueryEnv,
#[doc(hidden)]
pub schema_env: &'a SchemaEnv,
#[doc(hidden)]
pub query_env: &'a QueryEnv,
}
#[doc(hidden)]

View File

@ -140,9 +140,9 @@ impl<T: AsyncRead> Stream for ReaderStream<T> {
let this = self.project();
Poll::Ready(
match futures::ready!(this.reader.poll_read(cx, &mut self.buf)?) {
match futures::ready!(this.reader.poll_read(cx, this.buf)?) {
0 => None,
size => Some(Ok(Bytes::copy_from_slice(&self.buf[..size]))),
size => Some(Ok(Bytes::copy_from_slice(&this.buf[..size]))),
}
)
}

View File

@ -123,6 +123,8 @@ pub use async_graphql_parser as parser;
#[doc(hidden)]
pub use async_trait;
#[doc(hidden)]
pub use async_stream;
#[doc(hidden)]
pub use futures;
#[doc(hidden)]
pub use indexmap;

View File

@ -12,7 +12,7 @@ use std::pin::Pin;
/// `OutputValueType::resolve` implementation.
#[async_trait::async_trait]
pub trait ObjectType: OutputValueType {
/// This function returns true of type `EmptyMutation` only
/// This function returns true of type `EmptyMutation` only.
#[doc(hidden)]
fn is_empty() -> bool {
false
@ -151,6 +151,7 @@ impl<'a> Fields<'a> {
}
self.0.push(Box::pin({
// TODO: investigate removing this
let ctx = ctx.clone();
async move {
let ctx_field = ctx.with_field(field);

View File

@ -1,7 +1,7 @@
use crate::{CacheControl, Error, Result};
/// Query response
#[derive(Debug)]
#[derive(Debug, Default)]
pub struct Response {
/// Data of query result
pub data: serde_json::Value,
@ -17,6 +17,51 @@ pub struct Response {
}
impl Response {
/// Create a new successful response with the data.
#[must_use]
pub fn new(data: impl Into<serde_json::Value>) -> Self {
Self {
data: data.into(),
..Default::default()
}
}
/// Create a response from the error.
#[must_use]
pub fn from_error(error: impl Into<Error>) -> Self {
Self {
error: Some(error.into()),
..Default::default()
}
}
/// Create a response from the result of the data and an error.
#[must_use]
pub fn from_result(result: Result<serde_json::Value>) -> Self {
match result {
Ok(data) => Self::new(data),
Err(e) => Self::from_error(e),
}
}
/// Set the extensions result of the response.
#[must_use]
pub fn extensions(self, extensions: Option<serde_json::Value>) -> Self {
Self {
extensions,
..self
}
}
/// Set the cache control of the response.
#[must_use]
pub fn cache_control(self, cache_control: CacheControl) -> Self {
Self {
cache_control,
..self
}
}
/// Returns `true` if the response is ok.
#[inline]
pub fn is_ok(&self) -> bool {
@ -29,7 +74,8 @@ impl Response {
self.error.is_some()
}
/// Convert response to `Result<Response>`.
/// Extract the error from the response. Only if the `error` field is `None` will this return
/// `Ok`.
#[inline]
pub fn into_result(self) -> Result<Self> {
if self.is_err() {
@ -42,11 +88,6 @@ impl Response {
impl From<Error> for Response {
fn from(err: Error) -> Self {
Self {
data: serde_json::Value::Null,
extensions: None,
cache_control: CacheControl::default(),
error: Some(err),
}
Self::from_error(err)
}
}

View File

@ -5,7 +5,7 @@ use crate::parser::parse_query;
use crate::parser::types::OperationType;
use crate::registry::{MetaDirective, MetaInputValue, Registry};
use crate::resolver_utils::{resolve_object, resolve_object_serial, ObjectType};
use crate::subscription::create_subscription_stream;
use crate::subscription::collect_subscription_streams;
use crate::types::QueryRoot;
use crate::validation::{check_rules, CheckResult, ValidationMode};
use crate::{
@ -13,7 +13,7 @@ use crate::{
SubscriptionType, Type, Variables, ID,
};
use async_graphql_parser::types::ExecutableDocumentData;
use futures::{Stream, StreamExt};
use futures::stream::{self, Stream, StreamExt};
use indexmap::map::IndexMap;
use itertools::Itertools;
use std::any::Any;
@ -21,15 +21,6 @@ use std::ops::Deref;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
macro_rules! try_query_result {
($res:expr) => {
match $res {
Ok(resp) => resp,
Err(err) => return err.into(),
}
};
}
/// Schema builder
pub struct SchemaBuilder<Query, Mutation, Subscription> {
validation_mode: ValidationMode,
@ -411,10 +402,11 @@ where
};
env.extensions.lock().execution_start();
let data = match &env.document.operation.node.ty {
OperationType::Query => try_query_result!(resolve_object(&ctx, &self.query).await),
OperationType::Query => resolve_object(&ctx, &self.query).await,
OperationType::Mutation => {
try_query_result!(resolve_object_serial(&ctx, &self.mutation).await)
resolve_object_serial(&ctx, &self.mutation).await
}
OperationType::Subscription => {
return Error::Query {
@ -428,24 +420,21 @@ where
env.extensions.lock().execution_end();
let extensions = env.extensions.lock().result();
Response {
data,
extensions,
cache_control: Default::default(),
error: None,
}
Response::from_result(data)
.extensions(extensions)
}
/// Execute an GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into();
let (document, cache_control, extensions) =
try_query_result!(self.prepare_request(&request));
let mut resp = self
.execute_once(document, extensions, request.variables, request.data)
.await;
resp.cache_control = cache_control;
resp
match self.prepare_request(&request) {
Ok((document, cache_control, extensions)) => self
.execute_once(document, extensions, request.variables, request.data)
.await
.cache_control(cache_control),
Err(e) => Response::from_error(e),
}
}
pub(crate) fn execute_stream_with_ctx_data(
@ -454,9 +443,10 @@ where
ctx_data: Arc<Data>,
) -> impl Stream<Item = Response> {
let schema = self.clone();
async_stream::stream! {
let request = request.into();
let (document, cache_control, extensions) = match schema.prepare_request(& request) {
let (document, cache_control, extensions) = match schema.prepare_request(&request) {
Ok(res) => res,
Err(err) => {
yield Response::from(err);
@ -465,11 +455,10 @@ where
};
if document.operation.node.ty != OperationType::Subscription {
let mut resp = schema
yield schema
.execute_once(document, extensions, request.variables, request.data)
.await;
resp.cache_control = cache_control;
yield resp;
.await
.cache_control(cache_control);
return;
}
@ -488,17 +477,19 @@ where
&resolve_id,
);
let mut streams = Vec::new();
// TODO: Invoke extensions
if let Err(err) = create_subscription_stream(&schema, env.clone(), &ctx, &mut streams).await {
yield err.into();
let mut streams = Vec::new();
if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
yield Response::from(e);
return;
}
let mut stream = futures::stream::select_all(streams);
while let Some(resp) = stream.next().await {
let is_err = resp.is_err();
yield resp;
let mut stream = stream::select_all(streams);
while let Some(data) = stream.next().await {
let is_err = data.is_err();
let extensions = env.extensions.lock().result();
yield Response::from_result(data).extensions(extensions);
if is_err {
break;
}

View File

@ -1,112 +1,89 @@
use crate::context::QueryEnv;
use crate::parser::types::{Selection, TypeCondition};
use crate::{Context, ContextSelectionSet, Response, Result, Schema, SchemaEnv, Type};
use futures::{Future, Stream};
use crate::{Context, ContextSelectionSet, Result, Type};
use futures::{Stream, StreamExt};
use std::pin::Pin;
/// Represents a GraphQL subscription object
#[async_trait::async_trait]
pub trait SubscriptionType: Type {
/// This function returns true of type `EmptySubscription` only
/// This function returns true of type `EmptySubscription` only.
#[doc(hidden)]
fn is_empty() -> bool {
false
}
#[doc(hidden)]
async fn create_field_stream(
&self,
ctx: &Context<'_>,
schema_env: SchemaEnv,
query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>>;
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'a>,
) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>;
}
type BoxCreateStreamFuture<'a> = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
pub(crate) fn create_subscription_stream<'a, Query, Mutation, Subscription>(
schema: &'a Schema<Query, Mutation, Subscription>,
environment: QueryEnv,
ctx: &'a ContextSelectionSet<'_>,
streams: &'a mut Vec<Pin<Box<dyn Stream<Item = Response> + Send>>>,
) -> BoxCreateStreamFuture<'a>
where
Query: Send + Sync,
Mutation: Send + Sync,
Subscription: SubscriptionType + Send + Sync + 'static + Sized,
{
Box::pin(async move {
for selection in &ctx.item.node.items {
if ctx.is_skip(selection.node.directives())? {
continue;
}
match &selection.node {
Selection::Field(field) => streams.push(
schema
.subscription
.create_field_stream(
&ctx.with_field(field),
schema.env.clone(),
environment.clone(),
)
.await?,
),
Selection::FragmentSpread(fragment_spread) => {
if let Some(fragment) = ctx
.query_env
.document
.fragments
.get(&fragment_spread.node.fragment_name.node)
{
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&fragment.node.selection_set),
streams,
)
.await?;
pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + Send + Sync + 'static>(
ctx: &ContextSelectionSet<'a>,
root: &'a T,
streams: &mut Vec<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>>,
) -> Result<()> {
for selection in &ctx.item.node.items {
if ctx.is_skip(selection.node.directives())? {
continue;
}
match &selection.node {
Selection::Field(field) => streams.push(Box::pin({
let ctx = ctx.clone();
async_stream::stream! {
let ctx = ctx.with_field(field);
let mut stream = root.create_field_stream(&ctx);
while let Some(item) = stream.next().await {
yield item;
}
}
Selection::InlineFragment(inline_fragment) => {
if let Some(TypeCondition { on: name }) = inline_fragment
.node
.type_condition
.as_ref()
.map(|v| &v.node)
{
if name.node.as_str() == Subscription::type_name() {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.node.selection_set),
streams,
)
.await?;
}
} else {
create_subscription_stream(
schema,
environment.clone(),
})),
Selection::FragmentSpread(fragment_spread) => {
if let Some(fragment) = ctx
.query_env
.document
.fragments
.get(&fragment_spread.node.fragment_name.node)
{
collect_subscription_streams(
&ctx.with_selection_set(&fragment.node.selection_set),
root,
streams,
)?;
}
}
Selection::InlineFragment(inline_fragment) => {
if let Some(TypeCondition { on: name }) = inline_fragment
.node
.type_condition
.as_ref()
.map(|v| &v.node)
{
if name.node.as_str() == T::type_name() {
collect_subscription_streams(
&ctx.with_selection_set(&inline_fragment.node.selection_set),
root,
streams,
)
.await?;
)?;
}
} else {
collect_subscription_streams(
&ctx.with_selection_set(&inline_fragment.node.selection_set),
root,
streams,
)?;
}
}
}
Ok(())
})
}
Ok(())
}
#[async_trait::async_trait]
impl<T: SubscriptionType + Send + Sync> SubscriptionType for &T {
async fn create_field_stream(
&self,
ctx: &Context<'_>,
schema_env: SchemaEnv,
query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>> {
T::create_field_stream(*self, ctx, schema_env, query_env).await
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'a>,
) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>> {
T::create_field_stream(*self, ctx)
}
}

View File

@ -1,8 +1,7 @@
use crate::context::QueryEnv;
use crate::{
registry, Context, Error, Pos, QueryError, Response, Result, SchemaEnv, SubscriptionType, Type,
registry, Context, Error, Pos, QueryError, Result, SubscriptionType, Type,
};
use futures::Stream;
use futures::{stream, Stream};
use std::borrow::Cow;
use std::pin::Pin;
@ -29,25 +28,24 @@ impl Type for EmptySubscription {
}
}
#[async_trait::async_trait]
impl SubscriptionType for EmptySubscription {
fn is_empty() -> bool {
true
}
async fn create_field_stream(
&self,
_ctx: &Context<'_>,
_schema_env: SchemaEnv,
_query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>>
fn create_field_stream<'a>(
&'a self,
_ctx: &'a Context<'a>,
) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>>
where
Self: Send + Sync + 'static + Sized,
{
Err(Error::Query {
pos: Pos::default(),
path: None,
err: QueryError::NotConfiguredSubscriptions,
})
Box::pin(stream::once(async {
Err(Error::Query {
pos: Pos::default(),
path: None,
err: QueryError::NotConfiguredSubscriptions,
})
}))
}
}

View File

@ -3,10 +3,10 @@ use crate::registry::{MetaType, Registry};
use crate::resolver_utils::{resolve_object, ObjectType};
use crate::{
CacheControl, Context, ContextSelectionSet, Error, GQLSimpleObject, GQLSubscription,
OutputValueType, Positioned, QueryEnv, QueryError, Response, Result, SchemaEnv,
OutputValueType, Positioned, QueryError, Result,
SubscriptionType, Type,
};
use futures::Stream;
use futures::{stream, Stream, StreamExt, future::Either};
use indexmap::IndexMap;
use std::borrow::Cow;
use std::pin::Pin;
@ -101,34 +101,31 @@ where
}
}
#[async_trait::async_trait]
impl<A, B> SubscriptionType for MergedObject<A, B>
where
A: SubscriptionType + Send + Sync,
B: SubscriptionType + Send + Sync,
{
async fn create_field_stream(
&self,
ctx: &Context<'_>,
schema_env: SchemaEnv,
query_env: QueryEnv,
) -> Result<Pin<Box<dyn Stream<Item = Response> + Send>>> {
match self
.0
.create_field_stream(ctx, schema_env.clone(), query_env.clone())
.await
{
Ok(value) => Ok(value),
Err(Error::Query {
err: QueryError::FieldNotFound { .. },
..
}) => {
self.1
.create_field_stream(ctx, schema_env, query_env)
.await
}
Err(err) => Err(err),
}
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'a>,
) -> Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send + 'a>> {
let left_stream = self.0.create_field_stream(ctx);
let mut right_stream = Some(self.1.create_field_stream(ctx));
Box::pin(
left_stream
.flat_map(move |res| {
match res {
Err(Error::Query {
err: QueryError::FieldNotFound { .. },
..
}) if right_stream.is_some() => {
Either::Right(right_stream.take().unwrap())
}
other => Either::Left(stream::once(async { other })),
}
})
)
}
}

View File

@ -126,7 +126,7 @@ pub async fn test_field_features() {
let mut stream = schema.execute_stream("subscription { values }").boxed();
assert_eq!(
stream.next().await.map(|resp| resp.data),
stream.next().await.map(|resp| resp.into_result().unwrap().data),
Some(serde_json::json!({
"values": 10
}))