From d300e27fa2397d75d32437baf83e7a9d03c95e2f Mon Sep 17 00:00:00 2001 From: Sunli Date: Tue, 19 May 2020 16:24:15 +0800 Subject: [PATCH] Add Deferred type and @defer directive. #51 --- async-graphql-derive/src/subscription.rs | 32 ++--- async-graphql-parser/src/query.rs | 34 +++--- src/base.rs | 1 + src/context.rs | 125 +++++++++++++------- src/http/mod.rs | 1 + src/lib.rs | 8 +- src/mutation_resolver.rs | 2 + src/query.rs | 137 +++++++++++++++++---- src/resolver.rs | 2 + src/schema.rs | 144 +++++++++++++++-------- src/subscription/subscription_type.rs | 19 ++- src/subscription/ws_transport.rs | 1 + src/types/deferred.rs | 61 ++++++++++ src/types/empty_subscription.rs | 15 +-- src/types/mod.rs | 2 + src/types/query_root.rs | 9 +- src/validation/test_harness.rs | 4 +- tests/defer.rs | 26 ++++ 18 files changed, 444 insertions(+), 179 deletions(-) create mode 100644 src/types/deferred.rs create mode 100644 tests/defer.rs diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index 9beeef12..dd50dd38 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -242,35 +242,29 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#get_params)* #guard let field_name = std::sync::Arc::new(ctx.result_name().to_string()); + let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone()); - // I think the code here is safe because the lifetime of selection_set is always less than the environment. - let field_selection_set = unsafe { - (&ctx.selection_set - as *const #crate_name::Positioned<#crate_name::parser::query::SelectionSet>) - .as_ref() - .unwrap() - }; - - let schema = schema.clone(); let pos = ctx.position(); - let environment = environment.clone(); + let schema_env = schema_env.clone(); + let query_env = query_env.clone(); let stream = #create_field_stream.then({ let field_name = field_name.clone(); move |msg| { - let environment = environment.clone(); + let schema_env = schema_env.clone(); + let query_env = query_env.clone(); let field_selection_set = field_selection_set.clone(); - let schema = schema.clone(); let field_name = field_name.clone(); async move { let resolve_id = std::sync::atomic::AtomicUsize::default(); - let ctx_selection_set = environment.create_context( - &schema, + let ctx_selection_set = query_env.create_context( + &schema_env, Some(#crate_name::QueryPathNode { parent: None, segment: #crate_name::QueryPathSegment::Name(&field_name), }), - field_selection_set, + &*field_selection_set, &resolve_id, + None, ); #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await } @@ -331,16 +325,14 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< impl #crate_name::SubscriptionType for #self_ty #where_clause { #[allow(unused_variables)] #[allow(bare_trait_objects)] - async fn create_field_stream( + async fn create_field_stream( &self, idx: usize, ctx: &#crate_name::Context<'_>, - schema: &#crate_name::Schema, - environment: std::sync::Arc<#crate_name::Environment>, + schema_env: #crate_name::SchemaEnv, + query_env: #crate_name::QueryEnv, ) -> #crate_name::Result> + Send>>> where - Query: #crate_name::ObjectType + Send + Sync + 'static, - Mutation: #crate_name::ObjectType + Send + Sync + 'static, Self: Send + Sync + 'static + Sized, { use #crate_name::futures::StreamExt; diff --git a/async-graphql-parser/src/query.rs b/async-graphql-parser/src/query.rs index 730918a9..2657e3d5 100644 --- a/async-graphql-parser/src/query.rs +++ b/async-graphql-parser/src/query.rs @@ -3,7 +3,7 @@ use crate::value::Value; use std::collections::HashMap; use std::fmt; -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] pub enum Type { Named(String), List(Box), @@ -20,7 +20,7 @@ impl fmt::Display for Type { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Directive { pub name: Positioned, pub arguments: Vec<(Positioned, Positioned)>, @@ -44,14 +44,14 @@ pub enum OperationType { Subscription, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct CurrentOperation { pub ty: OperationType, pub variable_definitions: Vec>, pub selection_set: Positioned, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Document { pub(crate) definitions: Vec>, pub(crate) fragments: FragmentsMap, @@ -140,18 +140,18 @@ impl Document { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Definition { Operation(Positioned), Fragment(Positioned), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum TypeCondition { On(Positioned), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FragmentDefinition { pub name: Positioned, pub type_condition: Positioned, @@ -159,7 +159,7 @@ pub struct FragmentDefinition { pub selection_set: Positioned, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum OperationDefinition { SelectionSet(Positioned), Query(Positioned), @@ -167,7 +167,7 @@ pub enum OperationDefinition { Subscription(Positioned), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Query { pub name: Option>, pub variable_definitions: Vec>, @@ -175,7 +175,7 @@ pub struct Query { pub selection_set: Positioned, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Mutation { pub name: Option>, pub variable_definitions: Vec>, @@ -183,7 +183,7 @@ pub struct Mutation { pub selection_set: Positioned, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Subscription { pub name: Option>, pub variable_definitions: Vec>, @@ -191,26 +191,26 @@ pub struct Subscription { pub selection_set: Positioned, } -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct SelectionSet { pub items: Vec>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct VariableDefinition { pub name: Positioned, pub var_type: Positioned, pub default_value: Option>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Selection { Field(Positioned), FragmentSpread(Positioned), InlineFragment(Positioned), } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Field { pub alias: Option>, pub name: Positioned, @@ -228,13 +228,13 @@ impl Field { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct FragmentSpread { pub fragment_name: Positioned, pub directives: Vec>, } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct InlineFragment { pub type_condition: Option>, pub directives: Vec>, diff --git a/src/base.rs b/src/base.rs index d416b4a0..164f9ac7 100644 --- a/src/base.rs +++ b/src/base.rs @@ -90,6 +90,7 @@ pub trait ObjectType: OutputValueType { { if name == Self::type_name().as_ref() || ctx + .schema_env .registry .implements .get(Self::type_name().as_ref()) diff --git a/src/context.rs b/src/context.rs index 1ed4f8c3..bfc2c096 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,14 +1,19 @@ use crate::extensions::BoxExtension; use crate::parser::query::{Directive, Field, SelectionSet}; -use crate::registry::Registry; -use crate::{InputValueType, Lookahead, Pos, Positioned, QueryError, Result, Schema, Type, Value}; +use crate::schema::SchemaEnv; +use crate::{ + InputValueType, Lookahead, Pos, Positioned, QueryError, QueryResponse, Result, Type, Value, +}; use async_graphql_parser::query::Document; use async_graphql_parser::UploadValue; use fnv::FnvHashMap; +use futures::Future; +use parking_lot::Mutex; use std::any::{Any, TypeId}; use std::collections::BTreeMap; use std::fs::File; use std::ops::{Deref, DerefMut}; +use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -232,6 +237,27 @@ impl std::fmt::Display for ResolveId { } } +#[doc(hidden)] +pub type BoxDeferFuture = + Pin> + Send + 'static>>; + +#[doc(hidden)] +#[derive(Default)] +pub struct DeferList(pub Mutex>); + +impl DeferList { + pub(crate) fn into_inner(self) -> Vec { + self.0.into_inner() + } + + pub(crate) fn append(&self, fut: F) + where + F: Future> + Send + 'static, + { + self.0.lock().push(Box::pin(fut)); + } +} + /// Query context #[derive(Clone)] pub struct ContextBase<'a, T> { @@ -241,11 +267,9 @@ pub struct ContextBase<'a, T> { pub(crate) inc_resolve_id: &'a AtomicUsize, pub(crate) extensions: &'a [BoxExtension], pub(crate) item: T, - pub(crate) variables: &'a Variables, - pub(crate) registry: &'a Registry, - pub(crate) data: &'a Data, - pub(crate) ctx_data: Option<&'a Data>, - pub(crate) document: &'a Document, + pub(crate) schema_env: &'a SchemaEnv, + pub(crate) query_env: &'a QueryEnv, + pub(crate) defer_list: Option<&'a DeferList>, } impl<'a, T> Deref for ContextBase<'a, T> { @@ -257,20 +281,42 @@ impl<'a, T> Deref for ContextBase<'a, T> { } #[doc(hidden)] -pub struct Environment { +pub struct QueryEnvInner { pub variables: Variables, - pub document: Box, + pub document: Document, pub ctx_data: Arc, } -impl Environment { +#[doc(hidden)] +#[derive(Clone)] +pub struct QueryEnv(Arc); + +impl Deref for QueryEnv { + type Target = QueryEnvInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl QueryEnv { #[doc(hidden)] - pub fn create_context<'a, T, Query, Mutation, Subscription>( + pub fn new(variables: Variables, document: Document, ctx_data: Arc) -> QueryEnv { + QueryEnv(Arc::new(QueryEnvInner { + variables, + document, + ctx_data, + })) + } + + #[doc(hidden)] + pub fn create_context<'a, T>( &'a self, - schema: &'a Schema, + schema_env: &'a SchemaEnv, path_node: Option>, item: T, inc_resolve_id: &'a AtomicUsize, + defer_list: Option<&'a DeferList>, ) -> ContextBase<'a, T> { ContextBase { path_node, @@ -278,11 +324,9 @@ impl Environment { inc_resolve_id, extensions: &[], item, - variables: &self.variables, - registry: &schema.0.registry, - data: &schema.0.data, - ctx_data: Some(&self.ctx_data), - document: &self.document, + schema_env, + query_env: self, + defer_list, } } } @@ -319,11 +363,9 @@ impl<'a, T> ContextBase<'a, T> { item: field, resolve_id: self.get_child_resolve_id(), inc_resolve_id: self.inc_resolve_id, - variables: self.variables, - registry: self.registry, - data: self.data, - ctx_data: self.ctx_data, - document: self.document, + schema_env: self.schema_env, + query_env: self.query_env, + defer_list: self.defer_list, } } @@ -338,11 +380,9 @@ impl<'a, T> ContextBase<'a, T> { item: selection_set, resolve_id: self.resolve_id, inc_resolve_id: &self.inc_resolve_id, - variables: self.variables, - registry: self.registry, - data: self.data, - ctx_data: self.ctx_data, - document: self.document, + schema_env: self.schema_env, + query_env: self.query_env, + defer_list: self.defer_list, } } @@ -354,21 +394,24 @@ impl<'a, T> ContextBase<'a, T> { /// Gets the global data defined in the `Context` or `Schema`, returns `None` if the specified type data does not exist. pub fn data_opt(&self) -> Option<&D> { - self.ctx_data - .and_then(|ctx_data| ctx_data.0.get(&TypeId::of::())) - .or_else(|| self.data.0.get(&TypeId::of::())) + self.query_env + .ctx_data + .0 + .get(&TypeId::of::()) + .or_else(|| self.schema_env.data.0.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) } fn var_value(&self, name: &str, pos: Pos) -> Result { let def = self + .query_env .document .current_operation() .variable_definitions .iter() .find(|def| def.name.node == name); if let Some(def) = def { - if let Some(var_value) = self.variables.get(def.name.as_str()) { + if let Some(var_value) = self.query_env.variables.get(def.name.as_str()) { return Ok(var_value.clone()); } else if let Some(default) = &def.default_value { return Ok(default.clone_inner()); @@ -444,16 +487,16 @@ impl<'a, T> ContextBase<'a, T> { } .into_error(directive.position())); } - } else { - return Err(QueryError::UnknownDirective { - name: directive.name.to_string(), - } - .into_error(directive.position())); } } Ok(false) } + + #[doc(hidden)] + pub fn is_defer(&self, directives: &[Positioned]) -> bool { + directives.iter().any(|d| d.name.node == "defer") + } } impl<'a> ContextBase<'a, &'a Positioned> { @@ -468,11 +511,9 @@ impl<'a> ContextBase<'a, &'a Positioned> { item: self.item, resolve_id: self.get_child_resolve_id(), inc_resolve_id: self.inc_resolve_id, - variables: self.variables, - registry: self.registry, - data: self.data, - ctx_data: self.ctx_data, - document: self.document, + schema_env: self.schema_env, + query_env: self.query_env, + defer_list: self.defer_list, } } } @@ -559,7 +600,7 @@ impl<'a> ContextBase<'a, &'a Positioned> { /// ``` pub fn look_ahead(&self) -> Lookahead { Lookahead { - document: self.document, + document: &self.query_env.document, field: Some(&self.item.node), } } diff --git a/src/http/mod.rs b/src/http/mod.rs index 68842e6d..a5d7bb2b 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -209,6 +209,7 @@ mod tests { #[test] fn test_response_data() { let resp = GQLResponse(Ok(QueryResponse { + path: None, data: json!({"ok": true}), extensions: None, cache_control: Default::default(), diff --git a/src/lib.rs b/src/lib.rs index 52e3032c..ead117fa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,7 +117,7 @@ pub mod http; pub use base::{ScalarType, Type}; pub use context::{ - Context, ContextBase, Data, Environment, QueryPathNode, QueryPathSegment, Variables, + Context, ContextBase, Data, QueryEnv, QueryPathNode, QueryPathSegment, Variables, }; pub use error::{ Error, ErrorExtensions, FieldError, FieldResult, InputValueError, InputValueResult, @@ -128,14 +128,14 @@ pub use parser::{Pos, Positioned, Value}; pub use query::{IntoQueryBuilder, IntoQueryBuilderOpts, QueryBuilder, QueryResponse}; pub use registry::CacheControl; pub use scalars::{Any, Json, ID}; -pub use schema::{Schema, SchemaBuilder}; +pub use schema::{Schema, SchemaBuilder, SchemaEnv}; pub use subscription::{ SimpleBroker, SubscriptionStream, SubscriptionStreams, SubscriptionTransport, WebSocketTransport, }; pub use types::{ - Connection, Cursor, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, PageInfo, - QueryOperation, Upload, + Connection, Cursor, DataSource, Deferred, EmptyEdgeFields, EmptyMutation, EmptySubscription, + PageInfo, QueryOperation, Upload, }; pub use validation::ValidationMode; diff --git a/src/mutation_resolver.rs b/src/mutation_resolver.rs index f67d1ada..c8361cb1 100644 --- a/src/mutation_resolver.rs +++ b/src/mutation_resolver.rs @@ -56,6 +56,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( path_node: ctx_field.path_node.as_ref().unwrap(), parent_type: &T::type_name(), return_type: match ctx_field + .schema_env .registry .types .get(T::type_name().as_ref()) @@ -98,6 +99,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( } if let Some(fragment) = ctx + .query_env .document .fragments() .get(fragment_spread.fragment_name.as_str()) diff --git a/src/query.rs b/src/query.rs index 5147c21c..d912ac66 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,17 +1,20 @@ -use crate::context::{Data, ResolveId}; +use crate::context::{Data, DeferList, ResolveId}; use crate::error::ParseRequestError; use crate::mutation_resolver::do_mutation_resolve; use crate::parser::parse_query; use crate::registry::CacheControl; use crate::validation::{check_rules, CheckResult}; use crate::{ - do_resolve, ContextBase, Error, ObjectType, Pos, QueryError, Result, Schema, Variables, + do_resolve, ContextBase, Error, ObjectType, Pos, QueryEnv, QueryError, Result, Schema, + SubscriptionType, Variables, }; use async_graphql_parser::query::OperationType; +use futures::{Stream, StreamExt}; use itertools::Itertools; use std::any::Any; use std::fs::File; use std::sync::atomic::AtomicUsize; +use std::sync::Arc; /// IntoQueryBuilder options #[derive(Default, Clone)] @@ -39,6 +42,9 @@ pub trait IntoQueryBuilder: Sized { /// Query response #[derive(Debug)] pub struct QueryResponse { + /// Path for subsequent response + pub path: Option, + /// Data of query result pub data: serde_json::Value, @@ -49,6 +55,38 @@ pub struct QueryResponse { pub cache_control: CacheControl, } +impl QueryResponse { + pub(crate) fn merge(&mut self, resp: QueryResponse) { + if let Some(serde_json::Value::Array(items)) = resp.path { + let mut p = &mut self.data; + for item in items { + match item { + serde_json::Value::String(name) => { + if let serde_json::Value::Object(obj) = p { + if let Some(next) = obj.get_mut(&name) { + p = next; + continue; + } + } + return; + } + serde_json::Value::Number(idx) => { + if let serde_json::Value::Array(array) = p { + if let Some(next) = array.get_mut(idx.as_i64().unwrap() as usize) { + p = next; + continue; + } + } + return; + } + _ => {} + } + } + *p = resp.data; + } + } +} + /// Query builder pub struct QueryBuilder { pub(crate) query_source: String, @@ -107,18 +145,52 @@ impl QueryBuilder { .set_upload(var_path, filename, content_type, content); } - /// Execute the query. - pub async fn execute( + /// Execute the query, returns a stream, the first result being the query result, + /// followed by the incremental result. Only when there are `@defer` and `@stream` directives + /// in the query will there be subsequent incremental results. + pub async fn execute_stream( self, schema: &Schema, - ) -> Result + ) -> impl Stream> where - Query: ObjectType + Send + Sync, - Mutation: ObjectType + Send + Sync, + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, + { + let schema = schema.clone(); + let stream = async_stream::try_stream! { + let (first_resp, defer_list) = self.execute_first(&schema).await?; + yield first_resp; + + let mut current_defer_list = defer_list.into_inner(); + + loop { + let mut new_defer_list = Vec::new(); + for defer in current_defer_list { + let mut res = defer.await?; + new_defer_list.extend((res.1).into_inner()); + yield res.0; + } + if new_defer_list.is_empty() { + break; + } + current_defer_list = new_defer_list; + } + }; + Box::pin(stream) + } + + async fn execute_first<'a, Query, Mutation, Subscription>( + self, + schema: &Schema, + ) -> Result<(QueryResponse, DeferList)> + where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, { // create extension instances let extensions = schema - .0 .extensions .iter() .map(|factory| factory()) @@ -137,17 +209,17 @@ impl QueryBuilder { cache_control, complexity, depth, - } = check_rules(&schema.0.registry, &document, schema.0.validation_mode)?; + } = check_rules(&schema.env.registry, &document, schema.validation_mode)?; extensions.iter().for_each(|e| e.validation_end()); // check limit - if let Some(limit_complexity) = schema.0.complexity { + if let Some(limit_complexity) = schema.complexity { if complexity > limit_complexity { return Err(QueryError::TooComplex.into_error(Pos::default())); } } - if let Some(limit_depth) = schema.0.depth { + if let Some(limit_depth) = schema.depth { if depth > limit_depth { return Err(QueryError::TooDeep.into_error(Pos::default())); } @@ -173,24 +245,28 @@ impl QueryBuilder { }; } + let env = QueryEnv::new( + self.variables, + document, + Arc::new(self.ctx_data.unwrap_or_default()), + ); + let defer_list = DeferList::default(); let ctx = ContextBase { path_node: None, resolve_id: ResolveId::root(), inc_resolve_id: &inc_resolve_id, extensions: &extensions, - item: &document.current_operation().selection_set, - variables: &self.variables, - registry: &schema.0.registry, - data: &schema.0.data, - ctx_data: self.ctx_data.as_ref(), - document: &document, + item: &env.document.current_operation().selection_set, + schema_env: &schema.env, + query_env: &env, + defer_list: Some(&defer_list), }; extensions.iter().for_each(|e| e.execution_start()); - let data = match document.current_operation().ty { - OperationType::Query => do_resolve(&ctx, &schema.0.query).await?, - OperationType::Mutation => do_mutation_resolve(&ctx, &schema.0.mutation).await?, + let data = match &env.document.current_operation().ty { + OperationType::Query => do_resolve(&ctx, &schema.query).await?, + OperationType::Mutation => do_mutation_resolve(&ctx, &schema.mutation).await?, OperationType::Subscription => { return Err(Error::Query { pos: Pos::default(), @@ -203,6 +279,7 @@ impl QueryBuilder { extensions.iter().for_each(|e| e.execution_end()); let res = QueryResponse { + path: None, data, extensions: if !extensions.is_empty() { Some( @@ -222,6 +299,24 @@ impl QueryBuilder { }, cache_control, }; - Ok(res) + Ok((res, defer_list)) + } + + /// Execute the query, always return a complete result. + pub async fn execute( + self, + schema: &Schema, + ) -> Result + where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, + { + let mut stream = self.execute_stream(schema).await; + let mut resp = stream.next().await.unwrap()?; + while let Some(resp_part) = stream.next().await.transpose()? { + resp.merge(resp_part); + } + Ok(resp) } } diff --git a/src/resolver.rs b/src/resolver.rs index bc00437f..e802720d 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -65,6 +65,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( path_node: ctx_field.path_node.as_ref().unwrap(), parent_type: &T::type_name(), return_type: match ctx_field + .schema_env .registry .types .get(T::type_name().as_ref()) @@ -113,6 +114,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( } if let Some(fragment) = ctx + .query_env .document .fragments() .get(fragment_spread.fragment_name.as_str()) diff --git a/src/schema.rs b/src/schema.rs index 6e98b7aa..488e0d04 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -8,8 +8,8 @@ use crate::subscription::{create_connection, create_subscription_stream, Subscri use crate::types::QueryRoot; use crate::validation::{check_rules, ValidationMode}; use crate::{ - Environment, Error, ObjectType, Pos, QueryError, QueryResponse, Result, SubscriptionStream, - SubscriptionType, Type, Variables, + Error, ObjectType, Pos, QueryEnv, QueryError, QueryResponse, Result, SubscriptionStream, + SubscriptionType, Type, Variables, ID, }; use async_graphql_parser::query::OperationType; use bytes::Bytes; @@ -17,48 +17,47 @@ use futures::channel::mpsc; use futures::Stream; use indexmap::map::IndexMap; use std::any::Any; +use std::ops::Deref; use std::sync::atomic::AtomicUsize; use std::sync::Arc; -pub(crate) struct SchemaInner { - pub(crate) validation_mode: ValidationMode, - pub(crate) query: QueryRoot, - pub(crate) mutation: Mutation, - pub(crate) subscription: Subscription, - pub(crate) registry: Registry, - pub(crate) data: Data, - pub(crate) complexity: Option, - pub(crate) depth: Option, - pub(crate) extensions: Vec BoxExtension + Send + Sync>>, -} - /// Schema builder -pub struct SchemaBuilder(SchemaInner); +pub struct SchemaBuilder { + validation_mode: ValidationMode, + query: QueryRoot, + mutation: Mutation, + subscription: Subscription, + registry: Registry, + data: Data, + complexity: Option, + depth: Option, + extensions: Vec BoxExtension + Send + Sync>>, +} impl SchemaBuilder { /// You can use this function to register types that are not directly referenced. pub fn register_type(mut self) -> Self { - T::create_type_info(&mut self.0.registry); + T::create_type_info(&mut self.registry); self } /// Disable introspection query pub fn disable_introspection(mut self) -> Self { - self.0.query.disable_introspection = true; + self.query.disable_introspection = true; self } /// Set limit complexity, Default no limit. pub fn limit_complexity(mut self, complexity: usize) -> Self { - self.0.complexity = Some(complexity); + self.complexity = Some(complexity); self } /// Set limit complexity, Default no limit. pub fn limit_depth(mut self, depth: usize) -> Self { - self.0.depth = Some(depth); + self.depth = Some(depth); self } @@ -67,34 +66,73 @@ impl mut self, extension_factory: F, ) -> Self { - self.0 - .extensions + self.extensions .push(Box::new(move || Box::new(extension_factory()))); self } /// Add a global data that can be accessed in the `Schema`, you access it with `Context::data`. pub fn data(mut self, data: D) -> Self { - self.0.data.insert(data); + self.data.insert(data); self } /// Set the validation mode, default is `ValidationMode::Strict`. pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self { - self.0.validation_mode = validation_mode; + self.validation_mode = validation_mode; self } /// Build schema. pub fn finish(self) -> Schema { - Schema(Arc::new(self.0)) + Schema(Arc::new(SchemaInner { + validation_mode: self.validation_mode, + query: self.query, + mutation: self.mutation, + subscription: self.subscription, + complexity: self.complexity, + depth: self.depth, + extensions: self.extensions, + env: SchemaEnv(Arc::new(SchemaEnvInner { + registry: self.registry, + data: self.data, + })), + })) } } +#[doc(hidden)] +pub struct SchemaEnvInner { + pub registry: Registry, + pub data: Data, +} + +#[doc(hidden)] +#[derive(Clone)] +pub struct SchemaEnv(Arc); + +impl Deref for SchemaEnv { + type Target = SchemaEnvInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +#[doc(hidden)] +pub struct SchemaInner { + pub(crate) validation_mode: ValidationMode, + pub(crate) query: QueryRoot, + pub(crate) mutation: Mutation, + pub(crate) subscription: Subscription, + pub(crate) complexity: Option, + pub(crate) depth: Option, + pub(crate) extensions: Vec BoxExtension + Send + Sync>>, + pub(crate) env: SchemaEnv, +} + /// GraphQL schema -pub struct Schema( - pub(crate) Arc>, -); +pub struct Schema(Arc>); impl Clone for Schema { fn clone(&self) -> Self { @@ -102,6 +140,19 @@ impl Clone for Schema Deref for Schema +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, +{ + type Target = SchemaInner; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Schema where Query: ObjectType + Send + Sync + 'static, @@ -177,11 +228,19 @@ where } }); + registry.add_directive(MetaDirective { + name: "defer", + description: None, + locations: vec![__DirectiveLocation::FIELD], + args: Default::default(), + }); + // register scalars bool::create_type_info(&mut registry); i32::create_type_info(&mut registry); f32::create_type_info(&mut registry); String::create_type_info(&mut registry); + ID::create_type_info(&mut registry); QueryRoot::::create_type_info(&mut registry); if !Mutation::is_empty() { @@ -194,7 +253,7 @@ where // federation registry.create_federation_types(); - SchemaBuilder(SchemaInner { + SchemaBuilder { validation_mode: ValidationMode::Strict, query: QueryRoot { inner: query, @@ -207,7 +266,7 @@ where complexity: None, depth: None, extensions: Default::default(), - }) + } } /// Create a schema @@ -219,16 +278,6 @@ where Self::build(query, mutation, subscription).finish() } - #[doc(hidden)] - pub fn data(&self) -> &Data { - &self.0.data - } - - #[doc(hidden)] - pub fn registry(&self) -> &Registry { - &self.0.registry - } - /// Execute query without create the `QueryBuilder`. pub async fn execute(&self, query_source: &str) -> Result { QueryBuilder::new(query_source).execute(self).await @@ -243,7 +292,7 @@ where ctx_data: Option>, ) -> Result> + Send> { let mut document = parse_query(source).map_err(Into::::into)?; - check_rules(&self.0.registry, &document, self.0.validation_mode)?; + check_rules(&self.env.registry, &document, self.0.validation_mode)?; if !document.retain_operation(operation_name) { return if let Some(name) = operation_name { @@ -261,19 +310,16 @@ where } let resolve_id = AtomicUsize::default(); - let environment = Arc::new(Environment { - variables, - document: Box::new(document), - ctx_data: ctx_data.unwrap_or_default(), - }); - let ctx = environment.create_context( - self, + let env = QueryEnv::new(variables, document, ctx_data.unwrap_or_default()); + let ctx = env.create_context( + &self.env, None, - &environment.document.current_operation().selection_set, + &env.document.current_operation().selection_set, &resolve_id, + None, ); let mut streams = Vec::new(); - create_subscription_stream(self, environment.clone(), &ctx, &mut streams).await?; + create_subscription_stream(self, env.clone(), &ctx, &mut streams).await?; Ok(futures::stream::select_all(streams)) } diff --git a/src/subscription/subscription_type.rs b/src/subscription/subscription_type.rs index 490e132f..49879108 100644 --- a/src/subscription/subscription_type.rs +++ b/src/subscription/subscription_type.rs @@ -1,9 +1,8 @@ -use crate::context::Environment; +use crate::context::QueryEnv; use crate::parser::query::{Selection, TypeCondition}; -use crate::{Context, ContextSelectionSet, ObjectType, Result, Schema, Type}; +use crate::{Context, ContextSelectionSet, ObjectType, Result, Schema, SchemaEnv, Type}; use futures::{Future, Stream}; use std::pin::Pin; -use std::sync::Arc; /// Represents a GraphQL subscription object #[async_trait::async_trait] @@ -15,16 +14,14 @@ pub trait SubscriptionType: Type { } #[doc(hidden)] - async fn create_field_stream( + async fn create_field_stream( &self, idx: usize, ctx: &Context<'_>, - schema: &Schema, - environment: Arc, + schema_env: SchemaEnv, + query_env: QueryEnv, ) -> Result> + Send>>> where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, Self: Send + Sync + 'static + Sized; } @@ -32,7 +29,7 @@ type BoxCreateStreamFuture<'a> = Pin> + Send pub fn create_subscription_stream<'a, Query, Mutation, Subscription>( schema: &'a Schema, - environment: Arc, + environment: QueryEnv, ctx: &'a ContextSelectionSet<'_>, streams: &'a mut Vec> + Send>>>, ) -> BoxCreateStreamFuture<'a> @@ -50,12 +47,11 @@ where } streams.push( schema - .0 .subscription .create_field_stream( idx, &ctx.with_field(field), - schema, + schema.env.clone(), environment.clone(), ) .await?, @@ -67,6 +63,7 @@ where } if let Some(fragment) = ctx + .query_env .document .fragments() .get(fragment_spread.fragment_name.as_str()) diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs index 804501a2..b9f86b26 100644 --- a/src/subscription/ws_transport.rs +++ b/src/subscription/ws_transport.rs @@ -141,6 +141,7 @@ impl SubscriptionTransport for WebSocketTransport { id: Some(id.clone()), payload: Some( serde_json::to_value(GQLResponse(Ok(QueryResponse { + path: None, data: value, extensions: None, cache_control: Default::default(), diff --git a/src/types/deferred.rs b/src/types/deferred.rs new file mode 100644 index 00000000..f52f3e7b --- /dev/null +++ b/src/types/deferred.rs @@ -0,0 +1,61 @@ +use crate::context::DeferList; +use crate::registry::Registry; +use crate::{ContextSelectionSet, OutputValueType, Pos, QueryResponse, Result, Type}; +use std::borrow::Cow; +use std::sync::atomic::AtomicUsize; + +pub struct Deferred(T); + +impl From for Deferred { + fn from(value: T) -> Self { + Self(value) + } +} + +impl Type for Deferred { + fn type_name() -> Cow<'static, str> { + T::type_name() + } + + fn create_type_info(registry: &mut Registry) -> String { + T::create_type_info(registry) + } +} + +#[async_trait::async_trait] +impl OutputValueType for Deferred { + async fn resolve(&self, ctx: &ContextSelectionSet<'_>, pos: Pos) -> Result { + if let Some(defer_list) = ctx.defer_list { + let obj = self.0.clone(); + let schema_env = ctx.schema_env.clone(); + let query_env = ctx.query_env.clone(); + let field_selection_set = ctx.item.clone(); + let path = ctx.path_node.as_ref().map(|path| path.to_json()); + defer_list.append(async move { + let inc_resolve_id = AtomicUsize::default(); + let defer_list = DeferList::default(); + let ctx = query_env.create_context( + &schema_env, + None, + &field_selection_set, + &inc_resolve_id, + Some(&defer_list), + ); + let data = obj.resolve(&ctx, pos).await?; + + Ok(( + QueryResponse { + path, + data, + extensions: None, + cache_control: Default::default(), + }, + defer_list, + )) + }); + Ok(serde_json::Value::Null) + } else { + OutputValueType::resolve(&self.0, ctx, pos).await + } + } +} diff --git a/src/types/empty_subscription.rs b/src/types/empty_subscription.rs index fa03f7c8..cbd47e85 100644 --- a/src/types/empty_subscription.rs +++ b/src/types/empty_subscription.rs @@ -1,12 +1,11 @@ -use crate::context::Environment; +use crate::context::QueryEnv; use crate::{ - registry, Context, ContextSelectionSet, Error, ObjectType, OutputValueType, Pos, QueryError, - Result, Schema, SubscriptionType, Type, + registry, Context, ContextSelectionSet, Error, OutputValueType, Pos, QueryError, Result, + SchemaEnv, SubscriptionType, Type, }; use futures::Stream; use std::borrow::Cow; use std::pin::Pin; -use std::sync::Arc; /// Empty subscription /// @@ -36,16 +35,14 @@ impl SubscriptionType for EmptySubscription { true } - async fn create_field_stream( + async fn create_field_stream( &self, _idx: usize, _ctx: &Context<'_>, - _schema: &Schema, - _environment: Arc, + _schema_env: SchemaEnv, + _query_env: QueryEnv, ) -> Result> + Send>>> where - Query: ObjectType + Send + Sync + 'static, - Mutation: ObjectType + Send + Sync + 'static, Self: Send + Sync + 'static + Sized, { Err(Error::Query { diff --git a/src/types/mod.rs b/src/types/mod.rs index 4d740656..61d75b73 100644 --- a/src/types/mod.rs +++ b/src/types/mod.rs @@ -1,4 +1,5 @@ mod connection; +mod deferred; mod empty_mutation; mod empty_subscription; mod r#enum; @@ -8,6 +9,7 @@ mod query_root; mod upload; pub use connection::{Connection, Cursor, DataSource, EmptyEdgeFields, PageInfo, QueryOperation}; +pub use deferred::Deferred; pub use empty_mutation::EmptyMutation; pub use empty_subscription::EmptySubscription; pub use query_root::QueryRoot; diff --git a/src/types/query_root.rs b/src/types/query_root.rs index 664c07c2..b4889796 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -95,7 +95,7 @@ impl ObjectType for QueryRoot { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve( &__Schema { - registry: &ctx.registry, + registry: &ctx.schema_env.registry, }, &ctx_obj, ctx.position(), @@ -105,10 +105,11 @@ impl ObjectType for QueryRoot { let type_name: String = ctx.param_value("name", || Value::Null)?; let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve( - &ctx.registry + &ctx.schema_env + .registry .types .get(&type_name) - .map(|ty| __Type::new_simple(ctx.registry, ty)), + .map(|ty| __Type::new_simple(&ctx.schema_env.registry, ty)), &ctx_obj, ctx.position(), ) @@ -124,7 +125,7 @@ impl ObjectType for QueryRoot { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve( &Service { - sdl: Some(ctx.registry.create_federation_sdl()), + sdl: Some(ctx.schema_env.registry.create_federation_sdl()), }, &ctx_obj, ctx.position(), diff --git a/src/validation/test_harness.rs b/src/validation/test_harness.rs index a4d74f26..07fd5ef7 100644 --- a/src/validation/test_harness.rs +++ b/src/validation/test_harness.rs @@ -371,10 +371,10 @@ where F: Fn() -> V, { let schema = Schema::new(query, mutation, subscription); - let registry = &schema.0.registry; + let registry = &schema.env.registry; let doc = parse_query(query_source).expect("Parse error"); let mut ctx = VisitorContext::new( - unsafe { ::std::mem::transmute(&schema.0.registry) }, + unsafe { ::std::mem::transmute(&schema.env.registry) }, unsafe { ::std::mem::transmute(&doc) }, ); let mut visitor = factory(); diff --git a/tests/defer.rs b/tests/defer.rs new file mode 100644 index 00000000..7e9ba9f0 --- /dev/null +++ b/tests/defer.rs @@ -0,0 +1,26 @@ +use async_graphql::*; + +#[async_std::test] +pub async fn test_defer() { + struct Query { + value: i32, + } + + #[Object] + impl Query { + async fn value(&self) -> Deferred { + 10.into() + } + } + + let schema = Schema::new(Query { value: 10 }, EmptyMutation, EmptySubscription); + let query = r#"{ + value + }"#; + assert_eq!( + schema.execute(&query).await.unwrap().data, + serde_json::json!({ + "value": 10, + }) + ); +}