From 824356d118457d4e15cb1b1733dae3406beefdd1 Mon Sep 17 00:00:00 2001 From: Sunli Date: Sun, 4 Apr 2021 12:05:54 +0800 Subject: [PATCH] Rework Extension --- CHANGELOG.md | 9 +- Cargo.toml | 4 +- derive/src/merged_subscription.rs | 2 +- derive/src/subscription.rs | 63 +-- src/context.rs | 55 -- src/extensions/analyzer.rs | 65 ++- src/extensions/apollo_persisted_queries.rs | 18 +- src/extensions/apollo_tracing.rs | 147 +++--- src/extensions/logger.rs | 179 +++---- src/extensions/mod.rs | 566 ++++++++++++--------- src/extensions/opentelemetry.rs | 270 ++++------ src/extensions/tracing.rs | 253 ++++----- src/resolver_utils/container.rs | 22 +- src/resolver_utils/list.rs | 71 +-- src/response.rs | 14 +- src/schema.rs | 191 ++++--- src/subscription.rs | 22 +- src/types/empty_subscription.rs | 7 +- src/validation/mod.rs | 1 + tests/extension.rs | 197 +++---- tests/field_features.rs | 5 +- tests/generic_types.rs | 3 +- tests/guard.rs | 1 - tests/merged_object.rs | 6 +- tests/raw_ident.rs | 3 +- tests/rename.rs | 1 - tests/subscription.rs | 28 +- 27 files changed, 1058 insertions(+), 1145 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 747f718e..fa8fab37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,14 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), -nd this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +### Changed + +- Rework `Extension`, now fully supports asynchronous, better to use than before, and can achieve more features. + Because it contains a lot of changes _(if you don't have a custom extension, it will not cause the existing code to fail to compile)_, the major version will be updated to `3.0.0`. ## [2.7.4] 2021-04-02 diff --git a/Cargo.toml b/Cargo.toml index bf66bd23..acb042b1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ apollo_persisted_queries = ["lru", "sha2"] unblock = ["blocking"] string_number = ["num-traits"] dataloader = ["futures-timer", "futures-channel", "lru"] +tracing = ["tracinglib", "tracing-futures"] [dependencies] async-graphql-derive = { path = "derive", version = "=2.7.3" } @@ -47,7 +48,8 @@ bson = { version = "1.2.0", optional = true } chrono = { version = "0.4.19", optional = true } chrono-tz = { version = "0.5.3", optional = true } log = { version = "0.4.14", optional = true } -tracing = { version = "0.1.25", optional = true } +tracinglib = { version = "0.1.25", optional = true, package = "tracing" } +tracing-futures = { version = "0.2.5", optional = true, features = ["std-future", "futures-03"] } opentelemetry = { version = "0.13.0", optional = true } url = { version = "2.2.1", optional = true } uuid = { version = "0.8.2", optional = true, features = ["v4", "serde"] } diff --git a/derive/src/merged_subscription.rs b/derive/src/merged_subscription.rs index 6b912426..2b57dea8 100644 --- a/derive/src/merged_subscription.rs +++ b/derive/src/merged_subscription.rs @@ -82,7 +82,7 @@ pub fn generate(object_args: &args::MergedSubscription) -> GeneratorResult( &'__life self, ctx: &'__life #crate_name::Context<'__life> - ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box> + ::std::marker::Send + '__life>>> { + ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box + ::std::marker::Send + '__life>>> { ::std::option::Option::None #create_field_stream } } diff --git a/derive/src/subscription.rs b/derive/src/subscription.rs index ec97f08a..2896d6d7 100644 --- a/derive/src/subscription.rs +++ b/derive/src/subscription.rs @@ -282,7 +282,7 @@ pub fn generate( quote! { Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| { #(#parse_args)* - Ok(#expr) + ::std::result::Result::Ok(#expr) })) } } @@ -331,7 +331,7 @@ pub fn generate( let stream_fn = quote! { #(#get_params)* #guard - let field_name = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item.node.response_key().node)); + let field_name = ::std::clone::Clone::clone(&ctx.item.node.response_key().node); let field = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item)); let pos = ctx.item.pos; @@ -345,11 +345,6 @@ pub fn generate( let field = ::std::clone::Clone::clone(&field); let field_name = ::std::clone::Clone::clone(&field_name); async move { - let resolve_id = #crate_name::ResolveId { - parent: ::std::option::Option::Some(0), - current: 1, - }; - let inc_resolve_id = ::std::sync::atomic::AtomicUsize::new(1); let ctx_selection_set = query_env.create_context( &schema_env, ::std::option::Option::Some(#crate_name::QueryPathNode { @@ -357,26 +352,31 @@ pub fn generate( segment: #crate_name::QueryPathSegment::Name(&field_name), }), &field.node.selection_set, - resolve_id, - &inc_resolve_id, ); - query_env.extensions.execution_start(); - - #[allow(bare_trait_objects)] - let ri = #crate_name::extensions::ResolveInfo { - resolve_id, - path_node: ctx_selection_set.path_node.as_ref().unwrap(), - parent_type: #gql_typename, - return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(), + let mut execute_fut = async { + #[allow(bare_trait_objects)] + let ri = #crate_name::extensions::ResolveInfo { + path_node: ctx_selection_set.path_node.as_ref().unwrap(), + parent_type: #gql_typename, + return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(), + }; + let resolve_fut = async { + #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field) + .await + .map(::std::option::Option::Some) + }; + #crate_name::futures_util::pin_mut!(resolve_fut); + query_env.extensions.resolve(ri, &mut resolve_fut).await + .map(|value| { + let mut map = ::std::collections::BTreeMap::new(); + map.insert(::std::clone::Clone::clone(&field_name), value.unwrap_or_default()); + #crate_name::Response::new(#crate_name::Value::Object(map)) + }) + .unwrap_or_else(|err| #crate_name::Response::from_errors(::std::vec![err])) }; - - query_env.extensions.resolve_start(&ri); - let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await; - query_env.extensions.resolve_end(&ri); - query_env.extensions.execution_end(); - - res + #crate_name::futures_util::pin_mut!(execute_fut); + ::std::result::Result::Ok(query_env.extensions.execute(&mut execute_fut).await) } } }); @@ -398,11 +398,14 @@ pub fn generate( create_stream.push(quote! { #(#cfg_attrs)* if ctx.item.node.name.node == #field_name { - return ::std::option::Option::Some(::std::boxed::Box::pin( - #crate_name::futures_util::stream::TryStreamExt::try_flatten( - #crate_name::futures_util::stream::once((move || async move { #stream_fn })()) - ) - )); + let stream = #crate_name::futures_util::stream::TryStreamExt::try_flatten( + #crate_name::futures_util::stream::once((move || async move { #stream_fn })()) + ); + let stream = #crate_name::futures_util::StreamExt::map(stream, |res| match res { + ::std::result::Result::Ok(resp) => resp, + ::std::result::Result::Err(err) => #crate_name::Response::from_errors(::std::vec![err]), + }); + return ::std::option::Option::Some(::std::boxed::Box::pin(stream)); } }); @@ -451,7 +454,7 @@ pub fn generate( fn create_field_stream<'__life>( &'__life self, ctx: &'__life #crate_name::Context<'_>, - ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box> + ::std::marker::Send + '__life>>> { + ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box + ::std::marker::Send + '__life>>> { #(#create_stream)* ::std::option::Option::None } diff --git a/src/context.rs b/src/context.rs index 75d1cef5..894edced 100644 --- a/src/context.rs +++ b/src/context.rs @@ -4,7 +4,6 @@ use std::any::{Any, TypeId}; use std::collections::HashMap; use std::fmt::{self, Debug, Display, Formatter}; use std::ops::Deref; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; use async_graphql_value::{Value as InputValue, Variables}; @@ -189,36 +188,6 @@ impl<'a> Iterator for Parents<'a> { impl<'a> std::iter::FusedIterator for Parents<'a> {} -/// The unique id of the current resolution. -#[derive(Debug, Clone, Copy)] -pub struct ResolveId { - /// The unique ID of the parent resolution. - pub parent: Option, - - /// The current unique id. - pub current: usize, -} - -impl ResolveId { - #[doc(hidden)] - pub fn root() -> ResolveId { - ResolveId { - parent: None, - current: 0, - } - } -} - -impl Display for ResolveId { - fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - if let Some(parent) = self.parent { - write!(f, "{}:{}", parent, self.current) - } else { - write!(f, "{}", self.current) - } - } -} - /// Query context. /// /// **This type is not stable and should not be used directly.** @@ -226,8 +195,6 @@ impl Display for ResolveId { pub struct ContextBase<'a, T> { /// The current path node being resolved. pub path_node: Option>, - pub(crate) resolve_id: ResolveId, - pub(crate) inc_resolve_id: &'a AtomicUsize, #[doc(hidden)] pub item: T, #[doc(hidden)] @@ -273,13 +240,9 @@ impl QueryEnv { schema_env: &'a SchemaEnv, path_node: Option>, item: T, - resolve_id: ResolveId, - inc_resolve_id: &'a AtomicUsize, ) -> ContextBase<'a, T> { ContextBase { path_node, - resolve_id, - inc_resolve_id, item, schema_env, query_env: self, @@ -288,18 +251,6 @@ impl QueryEnv { } impl<'a, T> ContextBase<'a, T> { - #[doc(hidden)] - pub fn get_child_resolve_id(&self) -> ResolveId { - let id = self - .inc_resolve_id - .fetch_add(1, std::sync::atomic::Ordering::Relaxed) - + 1; - ResolveId { - parent: Some(self.resolve_id.current), - current: id, - } - } - #[doc(hidden)] pub fn with_field( &'a self, @@ -311,8 +262,6 @@ impl<'a, T> ContextBase<'a, T> { segment: QueryPathSegment::Name(&field.node.response_key().node), }), item: field, - resolve_id: self.get_child_resolve_id(), - inc_resolve_id: self.inc_resolve_id, schema_env: self.schema_env, query_env: self.query_env, } @@ -326,8 +275,6 @@ impl<'a, T> ContextBase<'a, T> { ContextBase { path_node: self.path_node, item: selection_set, - resolve_id: self.resolve_id, - inc_resolve_id: &self.inc_resolve_id, schema_env: self.schema_env, query_env: self.query_env, } @@ -560,8 +507,6 @@ impl<'a> ContextBase<'a, &'a Positioned> { segment: QueryPathSegment::Index(idx), }), item: self.item, - resolve_id: self.get_child_resolve_id(), - inc_resolve_id: self.inc_resolve_id, schema_env: self.schema_env, query_env: self.query_env, } diff --git a/src/extensions/analyzer.rs b/src/extensions/analyzer.rs index 6e2962bc..409a5e47 100644 --- a/src/extensions/analyzer.rs +++ b/src/extensions/analyzer.rs @@ -1,5 +1,9 @@ -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory}; -use crate::{value, ValidationResult, Value}; +use std::sync::Arc; + +use futures_util::lock::Mutex; + +use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension}; +use crate::{value, Response, ServerError, ValidationResult}; /// Analyzer extension /// @@ -7,32 +11,41 @@ use crate::{value, ValidationResult, Value}; pub struct Analyzer; impl ExtensionFactory for Analyzer { - fn create(&self) -> Box { - Box::new(AnalyzerExtension::default()) + fn create(&self) -> Arc { + Arc::new(AnalyzerExtension::default()) } } #[derive(Default)] struct AnalyzerExtension { - complexity: usize, - depth: usize, + validation_result: Mutex>, } +#[async_trait::async_trait] impl Extension for AnalyzerExtension { - fn name(&self) -> Option<&'static str> { - Some("analyzer") + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + let mut resp = next.request(ctx).await; + let validation_result = self.validation_result.lock().await.take(); + if let Some(validation_result) = validation_result { + resp = resp.extension( + "analyzer", + value! ({ + "complexity": validation_result.complexity, + "depth": validation_result.depth, + }), + ); + } + resp } - fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { - self.complexity = result.complexity; - self.depth = result.depth; - } - - fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option { - Some(value! ({ - "complexity": self.complexity, - "depth": self.depth, - })) + async fn validation( + &self, + ctx: &ExtensionContext<'_>, + next: NextExtension<'_>, + ) -> Result> { + let res = next.validation(ctx).await?; + *self.validation_result.lock().await = Some(res); + Ok(res) } } @@ -78,7 +91,7 @@ mod tests { .extension(extensions::Analyzer) .finish(); - let extensions = schema + let res = schema .execute( r#"{ value obj { @@ -93,15 +106,13 @@ mod tests { .into_result() .unwrap() .extensions - .unwrap(); + .remove("analyzer"); assert_eq!( - extensions, - value!({ - "analyzer": { - "complexity": 5 + 10, - "depth": 3, - } - }) + res, + Some(value!({ + "complexity": 5 + 10, + "depth": 3, + })) ); } } diff --git a/src/extensions/apollo_persisted_queries.rs b/src/extensions/apollo_persisted_queries.rs index 00dc3043..98d60c94 100644 --- a/src/extensions/apollo_persisted_queries.rs +++ b/src/extensions/apollo_persisted_queries.rs @@ -6,7 +6,7 @@ use futures_util::lock::Mutex; use serde::Deserialize; use sha2::{Digest, Sha256}; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory}; +use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension}; use crate::{from_value, Request, ServerError, ServerResult}; #[derive(Deserialize)] @@ -64,8 +64,8 @@ impl ApolloPersistedQueries { } impl ExtensionFactory for ApolloPersistedQueries { - fn create(&self) -> Box { - Box::new(ApolloPersistedQueriesExtension { + fn create(&self) -> Arc { + Arc::new(ApolloPersistedQueriesExtension { storage: self.0.clone(), }) } @@ -78,18 +78,19 @@ struct ApolloPersistedQueriesExtension { #[async_trait::async_trait] impl Extension for ApolloPersistedQueriesExtension { async fn prepare_request( - &mut self, - _ctx: &ExtensionContext<'_>, + &self, + ctx: &ExtensionContext<'_>, mut request: Request, + next: NextExtension<'_>, ) -> ServerResult { - if let Some(value) = request.extensions.remove("persistedQuery") { + let res = if let Some(value) = request.extensions.remove("persistedQuery") { let persisted_query: PersistedQuery = from_value(value).map_err(|_| { ServerError::new("Invalid \"PersistedQuery\" extension configuration.") })?; if persisted_query.version != 1 { return Err(ServerError::new( format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), - )); + )); } if request.query.is_empty() { @@ -110,7 +111,8 @@ impl Extension for ApolloPersistedQueriesExtension { } } else { Ok(request) - } + }; + next.prepare_request(ctx, res?).await } } diff --git a/src/extensions/apollo_tracing.rs b/src/extensions/apollo_tracing.rs index d54990b6..cb60fd97 100644 --- a/src/extensions/apollo_tracing.rs +++ b/src/extensions/apollo_tracing.rs @@ -1,36 +1,26 @@ -use std::collections::BTreeMap; -use std::ops::Deref; +use std::sync::Arc; use chrono::{DateTime, Utc}; +use futures_util::lock::Mutex; use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; -use crate::{value, Value}; +use crate::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, +}; +use crate::{value, Response, ServerResult, Value}; -struct PendingResolve { +struct ResolveState { path: Vec, field_name: String, parent_type: String, return_type: String, start_time: DateTime, -} - -struct ResolveStat { - pending_resolve: PendingResolve, end_time: DateTime, start_offset: i64, } -impl Deref for ResolveStat { - type Target = PendingResolve; - - fn deref(&self) -> &Self::Target { - &self.pending_resolve - } -} - -impl Serialize for ResolveStat { +impl Serialize for ResolveState { fn serialize(&self, serializer: S) -> Result { let mut map = serializer.serialize_map(None)?; map.serialize_entry("path", &self.path)?; @@ -57,76 +47,79 @@ impl Serialize for ResolveStat { pub struct ApolloTracing; impl ExtensionFactory for ApolloTracing { - fn create(&self) -> Box { - Box::new(ApolloTracingExtension { - start_time: Utc::now(), - end_time: Utc::now(), - pending_resolves: Default::default(), - resolves: Default::default(), + fn create(&self) -> Arc { + Arc::new(ApolloTracingExtension { + inner: Mutex::new(Inner { + start_time: Utc::now(), + end_time: Utc::now(), + resolves: Default::default(), + }), }) } } -struct ApolloTracingExtension { +struct Inner { start_time: DateTime, end_time: DateTime, - pending_resolves: BTreeMap, - resolves: Vec, + resolves: Vec, } +struct ApolloTracingExtension { + inner: Mutex, +} + +#[async_trait::async_trait] impl Extension for ApolloTracingExtension { - fn name(&self) -> Option<&'static str> { - Some("tracing") - } + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + self.inner.lock().await.start_time = Utc::now(); + let resp = next.execute(ctx).await; - fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { - self.start_time = Utc::now(); - self.pending_resolves.clear(); - self.resolves.clear(); - } - - fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { - self.end_time = Utc::now(); - } - - fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - self.pending_resolves.insert( - info.resolve_id.current, - PendingResolve { - path: info.path_node.to_string_vec(), - field_name: info.path_node.field_name().to_string(), - parent_type: info.parent_type.to_string(), - return_type: info.return_type.to_string(), - start_time: Utc::now(), - }, - ); - } - - fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - if let Some(pending_resolve) = self.pending_resolves.remove(&info.resolve_id.current) { - let start_offset = (pending_resolve.start_time - self.start_time) - .num_nanoseconds() - .unwrap(); - self.resolves.push(ResolveStat { - pending_resolve, - start_offset, - end_time: Utc::now(), - }); - } - } - - fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option { - self.resolves + let mut inner = self.inner.lock().await; + inner.end_time = Utc::now(); + inner + .resolves .sort_by(|a, b| a.start_offset.cmp(&b.start_offset)); + resp.extension( + "tracing", + value!({ + "version": 1, + "startTime": inner.start_time.to_rfc3339(), + "endTime": inner.end_time.to_rfc3339(), + "duration": (inner.end_time - inner.start_time).num_nanoseconds(), + "execution": { + "resolvers": inner.resolves + } + }), + ) + } - Some(value!({ - "version": 1, - "startTime": self.start_time.to_rfc3339(), - "endTime": self.end_time.to_rfc3339(), - "duration": (self.end_time - self.start_time).num_nanoseconds(), - "execution": { - "resolvers": self.resolves - } - })) + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + let path = info.path_node.to_string_vec(); + let field_name = info.path_node.field_name().to_string(); + let parent_type = info.parent_type.to_string(); + let return_type = info.return_type.to_string(); + let start_time = Utc::now(); + let start_offset = (start_time - self.inner.lock().await.start_time) + .num_nanoseconds() + .unwrap(); + + let res = next.resolve(ctx, info).await; + let end_time = Utc::now(); + + self.inner.lock().await.resolves.push(ResolveState { + path, + field_name, + parent_type, + return_type, + start_time, + end_time, + start_offset, + }); + res } } diff --git a/src/extensions/logger.rs b/src/extensions/logger.rs index 7bd3e357..53c12a0d 100644 --- a/src/extensions/logger.rs +++ b/src/extensions/logger.rs @@ -1,117 +1,124 @@ use std::fmt::{self, Display, Formatter}; +use std::sync::Arc; -use log::{error, info, trace}; +use futures_util::lock::Mutex; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; +use crate::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, +}; use crate::parser::types::{ExecutableDocument, OperationType, Selection}; -use crate::{PathSegment, ServerError, Variables}; +use crate::{PathSegment, ServerError, ServerResult, Value, Variables}; /// Logger extension #[cfg_attr(docsrs, doc(cfg(feature = "log")))] pub struct Logger; impl ExtensionFactory for Logger { - fn create(&self) -> Box { - Box::new(LoggerExtension { - enabled: true, - query: String::new(), - variables: Default::default(), + fn create(&self) -> Arc { + Arc::new(LoggerExtension { + inner: Mutex::new(Inner { + enabled: true, + query: String::new(), + variables: Default::default(), + }), }) } } -struct LoggerExtension { +struct Inner { enabled: bool, query: String, variables: Variables, } -impl Extension for LoggerExtension { - fn parse_start( - &mut self, - _ctx: &ExtensionContext<'_>, - query_source: &str, - variables: &Variables, - ) { - self.query = query_source.replace(char::is_whitespace, ""); - self.variables = variables.clone(); - } +struct LoggerExtension { + inner: Mutex, +} - fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { +#[async_trait::async_trait] +impl Extension for LoggerExtension { + async fn parse_query( + &self, + ctx: &ExtensionContext<'_>, + query: &str, + variables: &Variables, + next: NextExtension<'_>, + ) -> ServerResult { + let mut inner = self.inner.lock().await; + inner.query = query.replace(char::is_whitespace, ""); + inner.variables = variables.clone(); + let document = next.parse_query(ctx, query, variables).await?; let is_schema = document .operations .iter() .filter(|(_, operation)| operation.node.ty == OperationType::Query) .any(|(_, operation)| operation.node.selection_set.node.items.iter().any(|selection| matches!(&selection.node, Selection::Field(field) if field.node.name.node == "__schema"))); - - if is_schema { - self.enabled = false; - return; - } - - info!(target: "async-graphql", "[Query] query: \"{}\", variables: {}", &self.query, self.variables); + inner.enabled = !is_schema; + Ok(document) } - fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - if !self.enabled { - return; - } - trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", info.path_node); - } - - fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - if !self.enabled { - return; - } - trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", info.path_node); - } - - fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { - struct DisplayError<'a> { - log: &'a LoggerExtension, - e: &'a ServerError, - } - impl<'a> Display for DisplayError<'a> { - fn fmt(&self, f: &mut Formatter) -> fmt::Result { - write!(f, "[Error] ")?; - - if !self.e.path.is_empty() { - write!(f, "path: ")?; - for (i, segment) in self.e.path.iter().enumerate() { - if i != 0 { - write!(f, ".")?; - } - - match segment { - PathSegment::Field(field) => write!(f, "{}", field), - PathSegment::Index(i) => write!(f, "{}", i), - }?; - } - write!(f, ", ")?; - } - if !self.e.locations.is_empty() { - write!(f, "pos: [")?; - for (i, location) in self.e.locations.iter().enumerate() { - if i != 0 { - write!(f, ", ")?; - } - write!(f, "{}:{}", location.line, location.column)?; - } - write!(f, "], ")?; - } - write!(f, r#"query: "{}", "#, self.log.query)?; - write!(f, "variables: {}", self.log.variables)?; - write!(f, "{}", self.e.message) + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + let enabled = self.inner.lock().await.enabled; + if enabled { + let path = info.path_node.to_string(); + log::trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", path); + let res = next.resolve(ctx, info).await; + if let Err(err) = &res { + let inner = self.inner.lock().await; + log::error!( + target: "async-graphql", + "{}", + DisplayError { query:&inner.query,variables:&inner.variables, e: &err } + ); } + log::trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", path); + res + } else { + next.resolve(ctx, info).await } - - error!( - target: "async-graphql", - "{}", - DisplayError { - log: self, - e: err, - } - ); + } +} + +struct DisplayError<'a> { + query: &'a str, + variables: &'a Variables, + e: &'a ServerError, +} +impl<'a> Display for DisplayError<'a> { + fn fmt(&self, f: &mut Formatter) -> fmt::Result { + write!(f, "[Error] ")?; + + if !self.e.path.is_empty() { + write!(f, "path: ")?; + for (i, segment) in self.e.path.iter().enumerate() { + if i != 0 { + write!(f, ".")?; + } + + match segment { + PathSegment::Field(field) => write!(f, "{}", field), + PathSegment::Index(i) => write!(f, "{}", i), + }?; + } + write!(f, ", ")?; + } + if !self.e.locations.is_empty() { + write!(f, "pos: [")?; + for (i, location) in self.e.locations.iter().enumerate() { + if i != 0 { + write!(f, ", ")?; + } + write!(f, "{}:{}", location.line, location.column)?; + } + write!(f, "], ")?; + } + write!(f, r#"query: "{}", "#, self.query)?; + write!(f, "variables: {}", self.variables)?; + write!(f, "{}", self.e.message) } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index f6799d33..e9ac0cf4 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -12,28 +12,28 @@ mod opentelemetry; #[cfg(feature = "tracing")] mod tracing; -use std::any::{Any, TypeId}; -use std::collections::BTreeMap; -use std::sync::Arc; - -use crate::context::{QueryPathNode, ResolveId}; -use crate::parser::types::ExecutableDocument; -use crate::{ - Data, Request, Result, SchemaEnv, ServerError, ServerResult, ValidationResult, Variables, -}; -use crate::{Error, Name, Value}; - pub use self::analyzer::Analyzer; #[cfg(feature = "apollo_tracing")] pub use self::apollo_tracing::ApolloTracing; #[cfg(feature = "log")] pub use self::logger::Logger; #[cfg(feature = "opentelemetry")] -pub use self::opentelemetry::{OpenTelemetry, OpenTelemetryConfig}; +pub use self::opentelemetry::OpenTelemetry; #[cfg(feature = "tracing")] -pub use self::tracing::{Tracing, TracingConfig}; +pub use self::tracing::Tracing; -pub(crate) type BoxExtension = Box; +use std::any::{Any, TypeId}; +use std::future::Future; +use std::sync::Arc; + +use futures_util::stream::BoxStream; +use futures_util::stream::StreamExt; + +use crate::parser::types::ExecutableDocument; +use crate::{ + Data, Error, QueryPathNode, Request, Response, Result, SchemaEnv, ServerError, ServerResult, + ValidationResult, Value, Variables, +}; /// Context for extension pub struct ExtensionContext<'a> { @@ -86,10 +86,6 @@ impl<'a> ExtensionContext<'a> { /// Parameters for `Extension::resolve_field_start` pub struct ResolveInfo<'a> { - /// Because resolver is concurrent, `Extension::resolve_field_start` and `Extension::resolve_field_end` are - /// not strictly ordered, so each pair is identified by an id. - pub resolve_id: ResolveId, - /// Current path node, You can go through the entire path. pub path_node: &'a QueryPathNode<'a>, @@ -100,120 +96,250 @@ pub struct ResolveInfo<'a> { pub return_type: &'a str, } -/// Represents a GraphQL extension -/// -/// # Call order for query and mutation -/// -/// - start -/// - prepare_request -/// - parse_start -/// - parse_end -/// - validation_start -/// - validation_end -/// - execution_start -/// - resolve_start -/// - resolve_end -/// - result -/// - execution_end -/// - end -/// -/// # Call order for subscription -/// -/// - start -/// - prepare_request -/// - parse_start -/// - parse_end -/// - validation_start -/// - validation_end -/// - execution_start -/// - resolve_start -/// - resolve_end -/// - execution_end -/// - result -/// ``` -#[async_trait::async_trait] -#[allow(unused_variables)] -pub trait Extension: Sync + Send + 'static { - /// If this extension needs to output data to query results, you need to specify a name. - fn name(&self) -> Option<&'static str> { - None +type RequestFut<'a> = &'a mut (dyn Future + Send + Unpin); + +type ParseFut<'a> = &'a mut (dyn Future> + Send + Unpin); + +type ValidationFut<'a> = + &'a mut (dyn Future>> + Send + Unpin); + +type ExecuteFut<'a> = &'a mut (dyn Future + Send + Unpin); + +type ResolveFut<'a> = &'a mut (dyn Future>> + Send + Unpin); + +/// The remainder of a extension chain. +pub struct NextExtension<'a> { + chain: &'a [Arc], + request_fut: Option>, + parse_query_fut: Option>, + validation_fut: Option>, + execute_fut: Option>, + resolve_fut: Option>, +} + +impl<'a> NextExtension<'a> { + #[inline] + pub(crate) fn new(chain: &'a [Arc]) -> Self { + Self { + chain, + request_fut: None, + parse_query_fut: None, + validation_fut: None, + execute_fut: None, + resolve_fut: None, + } } - /// Called at the beginning of query. - fn start(&mut self, ctx: &ExtensionContext<'_>) {} + #[inline] + pub(crate) fn with_chain(self, chain: &'a [Arc]) -> Self { + Self { chain, ..self } + } - /// Called at the beginning of query. - fn end(&mut self, ctx: &ExtensionContext<'_>) {} + #[inline] + pub(crate) fn with_request(self, fut: RequestFut<'a>) -> Self { + Self { + request_fut: Some(fut), + ..self + } + } - /// Called at prepare request. - async fn prepare_request( - &mut self, + #[inline] + pub(crate) fn with_parse_query(self, fut: ParseFut<'a>) -> Self { + Self { + parse_query_fut: Some(fut), + ..self + } + } + + #[inline] + pub(crate) fn with_validation(self, fut: ValidationFut<'a>) -> Self { + Self { + validation_fut: Some(fut), + ..self + } + } + + #[inline] + pub(crate) fn with_execute(self, fut: ExecuteFut<'a>) -> Self { + Self { + execute_fut: Some(fut), + ..self + } + } + + #[inline] + pub(crate) fn with_resolve(self, fut: ResolveFut<'a>) -> Self { + Self { + resolve_fut: Some(fut), + ..self + } + } + + /// Call the [Extension::request] function of next extension. + pub async fn request(mut self, ctx: &ExtensionContext<'_>) -> Response { + if let Some((first, next)) = self.chain.split_first() { + first.request(ctx, self.with_chain(next)).await + } else { + self.request_fut + .take() + .expect("You definitely called the wrong function.") + .await + } + } + + /// Call the [Extension::subscribe] function of next extension. + pub fn subscribe<'s>( + self, + ctx: &ExtensionContext<'_>, + stream: BoxStream<'s, Response>, + ) -> BoxStream<'s, Response> { + if let Some((first, next)) = self.chain.split_first() { + first.subscribe(ctx, stream, self.with_chain(next)) + } else { + stream + } + } + + /// Call the [Extension::prepare_request] function of next extension. + pub async fn prepare_request( + self, ctx: &ExtensionContext<'_>, request: Request, ) -> ServerResult { - Ok(request) + if let Some((first, next)) = self.chain.split_first() { + first + .prepare_request(ctx, request, self.with_chain(next)) + .await + } else { + Ok(request) + } } - /// Called at the beginning of parse query source. - fn parse_start( - &mut self, + /// Call the [Extension::parse_query] function of next extension. + pub async fn parse_query( + mut self, ctx: &ExtensionContext<'_>, - query_source: &str, + query: &str, variables: &Variables, - ) { - } - - /// Called at the end of parse query source. - fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {} - - /// Called at the beginning of the validation. - fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {} - - /// Called at the end of the validation. - fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {} - - /// Called at the beginning of execute a query. - fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {} - - /// Called at the end of execute a query. - fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {} - - /// Called at the beginning of resolve a field. - fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} - - /// Called at the end of resolve a field. - fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} - - /// Called when an error occurs. - fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {} - - /// Get the results. - fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option { - None - } -} - -pub(crate) trait ErrorLogger { - fn log_error(self, extensions: &Extensions) -> Self; -} - -impl ErrorLogger for ServerResult { - fn log_error(self, extensions: &Extensions) -> Self { - if let Err(err) = &self { - extensions.error(err); + ) -> ServerResult { + if let Some((first, next)) = self.chain.split_first() { + first + .parse_query(ctx, query, variables, self.with_chain(next)) + .await + } else { + self.parse_query_fut + .take() + .expect("You definitely called the wrong function.") + .await + } + } + + /// Call the [Extension::validation] function of next extension. + pub async fn validation( + mut self, + ctx: &ExtensionContext<'_>, + ) -> Result> { + if let Some((first, next)) = self.chain.split_first() { + first.validation(ctx, self.with_chain(next)).await + } else { + self.validation_fut + .take() + .expect("You definitely called the wrong function.") + .await + } + } + + /// Call the [Extension::execute] function of next extension. + pub async fn execute(mut self, ctx: &ExtensionContext<'_>) -> Response { + if let Some((first, next)) = self.chain.split_first() { + first.execute(ctx, self.with_chain(next)).await + } else { + self.execute_fut + .take() + .expect("You definitely called the wrong function.") + .await + } + } + + /// Call the [Extension::resolve] function of next extension. + pub async fn resolve( + mut self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + ) -> ServerResult> { + if let Some((first, next)) = self.chain.split_first() { + first.resolve(ctx, info, self.with_chain(next)).await + } else { + self.resolve_fut + .take() + .expect("You definitely called the wrong function.") + .await } - self } } -impl ErrorLogger for Result> { - fn log_error(self, extensions: &Extensions) -> Self { - if let Err(errors) = &self { - for error in errors { - extensions.error(error); - } - } - self +/// Represents a GraphQL extension +#[async_trait::async_trait] +#[allow(unused_variables)] +pub trait Extension: Sync + Send + 'static { + /// Called at start query/mutation request. + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + next.request(ctx).await + } + + /// Called at subscribe request. + fn subscribe<'s>( + &self, + ctx: &ExtensionContext<'_>, + stream: BoxStream<'s, Response>, + next: NextExtension<'_>, + ) -> BoxStream<'s, Response> { + next.subscribe(ctx, stream) + } + + /// Called at prepare request. + async fn prepare_request( + &self, + ctx: &ExtensionContext<'_>, + request: Request, + next: NextExtension<'_>, + ) -> ServerResult { + next.prepare_request(ctx, request).await + } + + /// Called at parse query. + async fn parse_query( + &self, + ctx: &ExtensionContext<'_>, + query: &str, + variables: &Variables, + next: NextExtension<'_>, + ) -> ServerResult { + next.parse_query(ctx, query, variables).await + } + + /// Called at validation query. + async fn validation( + &self, + ctx: &ExtensionContext<'_>, + next: NextExtension<'_>, + ) -> Result> { + next.validation(ctx).await + } + + /// Called at execute query. + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + next.execute(ctx).await + } + + /// Called at resolve field. + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + next.resolve(ctx, info).await } } @@ -222,12 +348,13 @@ impl ErrorLogger for Result> { /// Used to create an extension instance. pub trait ExtensionFactory: Send + Sync + 'static { /// Create an extended instance. - fn create(&self) -> Box; + fn create(&self) -> Arc; } +#[derive(Clone)] #[doc(hidden)] pub struct Extensions { - extensions: Option>>, + extensions: Vec>, schema_env: SchemaEnv, session_data: Arc, query_data: Option>, @@ -235,17 +362,13 @@ pub struct Extensions { #[doc(hidden)] impl Extensions { - pub fn new( - extensions: Vec, + pub(crate) fn new( + extensions: impl IntoIterator>, schema_env: SchemaEnv, session_data: Arc, ) -> Self { Extensions { - extensions: if extensions.is_empty() { - None - } else { - Some(spin::Mutex::new(extensions)) - }, + extensions: extensions.into_iter().collect(), schema_env, session_data, query_data: None, @@ -255,18 +378,14 @@ impl Extensions { pub fn attach_query_data(&mut self, data: Arc) { self.query_data = Some(data); } -} -impl Drop for Extensions { - fn drop(&mut self) { - self.end(); - } -} - -#[doc(hidden)] -impl Extensions { #[inline] - fn context(&self) -> ExtensionContext<'_> { + pub(crate) fn is_empty(&self) -> bool { + self.extensions.is_empty() + } + + #[inline] + fn create_context(&self) -> ExtensionContext { ExtensionContext { schema_data: &self.schema_env.data, session_data: &self.session_data, @@ -274,124 +393,79 @@ impl Extensions { } } - pub fn is_empty(&self) -> bool { - self.extensions.is_none() - } - - pub fn start(&self) { - if let Some(e) = &self.extensions { - e.lock().iter_mut().for_each(|e| e.start(&self.context())); + pub async fn request(&self, request_fut: RequestFut<'_>) -> Response { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions).with_request(request_fut); + next.request(&self.create_context()).await + } else { + request_fut.await } } - pub fn end(&self) { - if let Some(e) = &self.extensions { - e.lock().iter_mut().for_each(|e| e.end(&self.context())); + pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions); + next.subscribe(&self.create_context(), stream) + } else { + stream.boxed() } } pub async fn prepare_request(&self, request: Request) -> ServerResult { - let mut request = request; - if let Some(e) = &self.extensions { - for e in e.lock().iter_mut() { - request = e.prepare_request(&self.context(), request).await?; - } - } - Ok(request) - } - - pub fn parse_start(&self, query_source: &str, variables: &Variables) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.parse_start(&self.context(), query_source, variables)); - } - } - - pub fn parse_end(&self, document: &ExecutableDocument) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.parse_end(&self.context(), document)); - } - } - - pub fn validation_start(&self) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.validation_start(&self.context())); - } - } - - pub fn validation_end(&self, result: &ValidationResult) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.validation_end(&self.context(), result)); - } - } - - pub fn execution_start(&self) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.execution_start(&self.context())); - } - } - - pub fn execution_end(&self) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.execution_end(&self.context())); - } - } - - pub fn resolve_start(&self, info: &ResolveInfo<'_>) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.resolve_start(&self.context(), info)); - } - } - - pub fn resolve_end(&self, resolve_id: &ResolveInfo<'_>) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.resolve_end(&self.context(), resolve_id)); - } - } - - pub fn error(&self, err: &ServerError) { - if let Some(e) = &self.extensions { - e.lock() - .iter_mut() - .for_each(|e| e.error(&self.context(), err)); - } - } - - pub fn result(&self) -> Option { - if let Some(e) = &self.extensions { - let value = e - .lock() - .iter_mut() - .filter_map(|e| { - if let Some(name) = e.name() { - e.result(&self.context()).map(|res| (Name::new(name), res)) - } else { - None - } - }) - .collect::>(); - if value.is_empty() { - None - } else { - Some(Value::Object(value)) - } + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions); + next.prepare_request(&self.create_context(), request).await } else { - None + Ok(request) + } + } + + pub async fn parse_query( + &self, + query: &str, + variables: &Variables, + parse_query_fut: ParseFut<'_>, + ) -> ServerResult { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions).with_parse_query(parse_query_fut); + next.parse_query(&self.create_context(), query, variables) + .await + } else { + parse_query_fut.await + } + } + + pub async fn validation( + &self, + validation_fut: ValidationFut<'_>, + ) -> Result> { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions).with_validation(validation_fut); + next.validation(&self.create_context()).await + } else { + validation_fut.await + } + } + + pub async fn execute(&self, execute_fut: ExecuteFut<'_>) -> Response { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions).with_execute(execute_fut); + next.execute(&self.create_context()).await + } else { + execute_fut.await + } + } + + pub async fn resolve( + &self, + info: ResolveInfo<'_>, + resolve_fut: ResolveFut<'_>, + ) -> ServerResult> { + if !self.extensions.is_empty() { + let next = NextExtension::new(&self.extensions).with_resolve(resolve_fut); + next.resolve(&self.create_context(), info).await + } else { + resolve_fut.await } } } diff --git a/src/extensions/opentelemetry.rs b/src/extensions/opentelemetry.rs index b16c111d..57be9917 100644 --- a/src/extensions/opentelemetry.rs +++ b/src/extensions/opentelemetry.rs @@ -1,49 +1,25 @@ -use std::collections::HashMap; use std::sync::Arc; use async_graphql_parser::types::ExecutableDocument; use async_graphql_value::Variables; -use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer}; +use futures_util::stream::BoxStream; +use futures_util::TryFutureExt; +use opentelemetry::trace::{FutureExt, SpanKind, TraceContextExt, Tracer}; use opentelemetry::{Context as OpenTelemetryContext, Key}; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; -use crate::{ServerError, ValidationResult}; - -const REQUEST_CTX: usize = 0; -const PARSE_CTX: usize = 1; -const VALIDATION_CTX: usize = 2; -const EXECUTE_CTX: usize = 3; - -#[inline] -fn resolve_ctx_id(resolver_id: usize) -> usize { - resolver_id + 10 -} +use crate::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, +}; +use crate::{Response, ServerError, ServerResult, ValidationResult, Value}; const KEY_SOURCE: Key = Key::from_static_str("graphql.source"); const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables"); const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType"); const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType"); -const KEY_RESOLVE_ID: Key = Key::from_static_str("graphql.resolveId"); const KEY_ERROR: Key = Key::from_static_str("graphql.error"); const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity"); const KEY_DEPTH: Key = Key::from_static_str("graphql.depth"); -/// OpenTelemetry extension configuration for each request. -#[derive(Default)] -#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] -pub struct OpenTelemetryConfig { - /// Use a context as the parent node of the entire query. - parent: spin::Mutex>, -} - -impl OpenTelemetryConfig { - /// Use a context as the parent of the entire query. - pub fn parent_context(mut self, cx: OpenTelemetryContext) -> Self { - *self.parent.get_mut() = Some(cx); - self - } -} - /// OpenTelemetry extension #[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] pub struct OpenTelemetry { @@ -63,159 +39,127 @@ impl OpenTelemetry { } impl ExtensionFactory for OpenTelemetry { - fn create(&self) -> Box { - Box::new(OpenTelemetryExtension { + fn create(&self) -> Arc { + Arc::new(OpenTelemetryExtension { tracer: self.tracer.clone(), - contexts: Default::default(), }) } } struct OpenTelemetryExtension { tracer: Arc, - contexts: HashMap, -} - -impl OpenTelemetryExtension { - fn enter_context(&mut self, id: usize, cx: OpenTelemetryContext) { - let _ = cx.clone().attach(); - self.contexts.insert(id, cx); - } - - fn exit_context(&mut self, id: usize) -> Option { - if let Some(cx) = self.contexts.remove(&id) { - let _ = cx.clone().attach(); - Some(cx) - } else { - None - } - } } +#[async_trait::async_trait] impl Extension for OpenTelemetryExtension { - fn start(&mut self, ctx: &ExtensionContext<'_>) { - let request_cx = ctx - .data_opt::() - .and_then(|cfg| cfg.parent.lock().take()) - .unwrap_or_else(|| { - OpenTelemetryContext::current_with_span( + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + next.request(ctx) + .with_context(OpenTelemetryContext::current_with_span( + self.tracer + .span_builder("request") + .with_kind(SpanKind::Server) + .start(&*self.tracer), + )) + .await + } + + fn subscribe<'s>( + &self, + ctx: &ExtensionContext<'_>, + stream: BoxStream<'s, Response>, + next: NextExtension<'_>, + ) -> BoxStream<'s, Response> { + Box::pin( + next.subscribe(ctx, stream) + .with_context(OpenTelemetryContext::current_with_span( self.tracer - .span_builder("request") + .span_builder("subscribe") .with_kind(SpanKind::Server) .start(&*self.tracer), - ) - }); - self.enter_context(REQUEST_CTX, request_cx); + )), + ) } - fn end(&mut self, _ctx: &ExtensionContext<'_>) { - self.exit_context(REQUEST_CTX); - } - - fn parse_start( - &mut self, - _ctx: &ExtensionContext<'_>, - query_source: &str, + async fn parse_query( + &self, + ctx: &ExtensionContext<'_>, + query: &str, variables: &Variables, - ) { - if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() { - let attributes = vec![ - KEY_SOURCE.string(query_source.to_string()), - KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()), - ]; - let parse_span = self - .tracer - .span_builder("parse") - .with_kind(SpanKind::Server) - .with_attributes(attributes) - .with_parent_context(parent_cx) - .start(&*self.tracer); - let parse_cx = OpenTelemetryContext::current_with_span(parse_span); - self.enter_context(PARSE_CTX, parse_cx); - } + next: NextExtension<'_>, + ) -> ServerResult { + let attributes = vec![ + KEY_SOURCE.string(query.to_string()), + KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()), + ]; + let span = self + .tracer + .span_builder("parse") + .with_kind(SpanKind::Server) + .with_attributes(attributes) + .start(&*self.tracer); + next.parse_query(ctx, query, variables) + .with_context(OpenTelemetryContext::current_with_span(span)) + .await } - fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { - self.exit_context(PARSE_CTX); + async fn validation( + &self, + ctx: &ExtensionContext<'_>, + next: NextExtension<'_>, + ) -> Result> { + let span = self + .tracer + .span_builder("validation") + .with_kind(SpanKind::Server) + .start(&*self.tracer); + next.validation(ctx) + .with_context(OpenTelemetryContext::current_with_span(span)) + .map_ok(|res| { + let current_cx = OpenTelemetryContext::current(); + let span = current_cx.span(); + span.set_attribute(KEY_COMPLEXITY.i64(res.complexity as i64)); + span.set_attribute(KEY_DEPTH.i64(res.depth as i64)); + res + }) + .await } - fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) { - if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() { - let span = self - .tracer - .span_builder("validation") - .with_kind(SpanKind::Server) - .with_parent_context(parent_cx) - .start(&*self.tracer); - let validation_cx = OpenTelemetryContext::current_with_span(span); - self.enter_context(VALIDATION_CTX, validation_cx); - } + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + let span = self + .tracer + .span_builder("execute") + .with_kind(SpanKind::Server) + .start(&*self.tracer); + next.execute(ctx) + .with_context(OpenTelemetryContext::current_with_span(span)) + .await } - fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { - if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) { - let span = validation_cx.span(); - span.set_attribute(KEY_COMPLEXITY.i64(result.complexity as i64)); - span.set_attribute(KEY_DEPTH.i64(result.depth as i64)); - } - } - - fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { - let span = match self.contexts.get(&REQUEST_CTX).cloned() { - Some(parent_cx) => self - .tracer - .span_builder("execute") - .with_kind(SpanKind::Server) - .with_parent_context(parent_cx) - .start(&*self.tracer), - None => self - .tracer - .span_builder("execute") - .with_kind(SpanKind::Server) - .start(&*self.tracer), - }; - let execute_cx = OpenTelemetryContext::current_with_span(span); - self.enter_context(EXECUTE_CTX, execute_cx); - } - - fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { - self.exit_context(EXECUTE_CTX); - } - - fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - let parent_cx = match info.resolve_id.parent { - Some(parent_id) if parent_id > 0 => self.contexts.get(&resolve_ctx_id(parent_id)), - _ => self.contexts.get(&EXECUTE_CTX), - } - .cloned(); - - if let Some(parent_cx) = parent_cx { - let attributes = vec![ - KEY_RESOLVE_ID.i64(info.resolve_id.current as i64), - KEY_PARENT_TYPE.string(info.parent_type.to_string()), - KEY_RETURN_TYPE.string(info.return_type.to_string()), - ]; - let span = self - .tracer - .span_builder(&info.path_node.to_string()) - .with_kind(SpanKind::Server) - .with_parent_context(parent_cx) - .with_attributes(attributes) - .start(&*self.tracer); - let resolve_cx = OpenTelemetryContext::current_with_span(span); - self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx); - } - } - - fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - self.exit_context(resolve_ctx_id(info.resolve_id.current)); - } - - fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { - if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() { - parent_cx - .span() - .add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]); - } + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + let attributes = vec![ + KEY_PARENT_TYPE.string(info.parent_type.to_string()), + KEY_RETURN_TYPE.string(info.return_type.to_string()), + ]; + let span = self + .tracer + .span_builder(&info.path_node.to_string()) + .with_kind(SpanKind::Server) + .with_attributes(attributes) + .start(&*self.tracer); + next.resolve(ctx, info) + .with_context(OpenTelemetryContext::current_with_span(span)) + .map_err(|err| { + let current_cx = OpenTelemetryContext::current(); + current_cx + .span() + .add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]); + err + }) + .await } } diff --git a/src/extensions/tracing.rs b/src/extensions/tracing.rs index 0b278dbf..25614704 100644 --- a/src/extensions/tracing.rs +++ b/src/extensions/tracing.rs @@ -1,36 +1,15 @@ -use std::collections::HashMap; +use std::sync::Arc; -use tracing::{span, Level, Span}; +use futures_util::stream::BoxStream; +use tracing_futures::Instrument; +use tracinglib::{span, Level}; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; +use crate::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, +}; +use crate::futures_util::TryFutureExt; use crate::parser::types::ExecutableDocument; -use crate::{ServerError, ValidationResult, Variables}; - -/// Tracing extension configuration for each request. -#[derive(Default)] -#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] -pub struct TracingConfig { - /// Use a span as the parent node of the entire query. - parent: spin::Mutex>, -} - -impl TracingConfig { - /// Use a span as the parent of the entire query. - pub fn parent_span(mut self, span: Span) -> Self { - *self.parent.get_mut() = Some(span); - self - } -} - -const REQUEST_CTX: usize = 0; -const PARSE_CTX: usize = 1; -const VALIDATION_CTX: usize = 2; -const EXECUTE_CTX: usize = 3; - -#[inline] -fn resolve_span_id(resolver_id: usize) -> usize { - resolver_id + 10 -} +use crate::{Response, ServerError, ServerResult, ValidationResult, Value, Variables}; /// Tracing extension /// @@ -42,7 +21,7 @@ fn resolve_span_id(resolver_id: usize) -> usize { /// /// ```no_run /// use async_graphql::*; -/// use async_graphql::extensions::{Tracing, TracingConfig}; +/// use async_graphql::extensions::Tracing; /// use tracing::{span, Level, Instrument}; /// /// #[derive(SimpleObject)] @@ -50,164 +29,112 @@ fn resolve_span_id(resolver_id: usize) -> usize { /// value: i32, /// } /// -/// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription). -/// extension(Tracing::default()) +/// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription) +/// .extension(Tracing) /// .finish(); /// /// tokio::runtime::Runtime::new().unwrap().block_on(async { /// schema.execute(Request::new("{ value }")).await; /// }); -/// -/// // tracing in parent span -/// tokio::runtime::Runtime::new().unwrap().block_on(async { -/// let root_span = span!( -/// parent: None, -/// Level::INFO, -/// "span root" -/// ); -/// schema.execute(Request::new("{ value }")).instrument(root_span).await; -/// }); -/// -/// // replace root span -/// tokio::runtime::Runtime::new().unwrap().block_on(async { -/// let root_span = span!( -/// parent: None, -/// Level::INFO, -/// "span root" -/// ); -/// schema.execute(Request::new("{ value }").data(TracingConfig::default().parent_span(root_span))).await; -/// }); /// ``` -#[derive(Default)] #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] pub struct Tracing; impl ExtensionFactory for Tracing { - fn create(&self) -> Box { - Box::new(TracingExtension::default()) + fn create(&self) -> Arc { + Arc::new(TracingExtension::default()) } } #[derive(Default)] -struct TracingExtension { - spans: HashMap, -} - -impl TracingExtension { - fn enter_span(&mut self, id: usize, span: Span) -> &Span { - let _ = span.enter(); - self.spans.insert(id, span); - self.spans.get(&id).unwrap() - } - - fn exit_span(&mut self, id: usize) { - if let Some(span) = self.spans.remove(&id) { - let _ = span.enter(); - } - } -} +struct TracingExtension; +#[async_trait::async_trait] impl Extension for TracingExtension { - fn start(&mut self, ctx: &ExtensionContext<'_>) { - let request_span = ctx - .data_opt::() - .and_then(|cfg| cfg.parent.lock().take()) - .unwrap_or_else(|| { - span!( - target: "async_graphql::graphql", - Level::INFO, - "request", - ) - }); - self.enter_span(REQUEST_CTX, request_span); + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + next.request(ctx) + .instrument(span!( + target: "async_graphql::graphql", + Level::INFO, + "request", + )) + .await } - fn end(&mut self, _ctx: &ExtensionContext<'_>) { - self.exit_span(REQUEST_CTX); + fn subscribe<'s>( + &self, + ctx: &ExtensionContext<'_>, + stream: BoxStream<'s, Response>, + next: NextExtension<'_>, + ) -> BoxStream<'s, Response> { + Box::pin(next.subscribe(ctx, stream).instrument(span!( + target: "async_graphql::graphql", + Level::INFO, + "subscribe", + ))) } - fn parse_start( - &mut self, - _ctx: &ExtensionContext<'_>, - query_source: &str, + async fn parse_query( + &self, + ctx: &ExtensionContext<'_>, + query: &str, variables: &Variables, - ) { - if let Some(parent) = self.spans.get(&REQUEST_CTX) { - let variables = serde_json::to_string(&variables).unwrap(); - let parse_span = span!( - target: "async_graphql::graphql", - parent: parent, - Level::INFO, - "parse", - source = query_source, - variables = %variables, - ); - self.enter_span(PARSE_CTX, parse_span); - } + next: NextExtension<'_>, + ) -> ServerResult { + let span = span!( + target: "async_graphql::graphql", + Level::INFO, + "parse", + source = query, + variables = %serde_json::to_string(&variables).unwrap(), + ); + next.parse_query(ctx, query, variables) + .instrument(span) + .await } - fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { - self.exit_span(PARSE_CTX); + async fn validation( + &self, + ctx: &ExtensionContext<'_>, + next: NextExtension<'_>, + ) -> Result> { + let span = span!( + target: "async_graphql::graphql", + Level::INFO, + "validation" + ); + next.validation(ctx).instrument(span).await } - fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) { - if let Some(parent) = self.spans.get(&REQUEST_CTX) { - let span = span!( - target: "async_graphql::graphql", - parent: parent, - Level::INFO, - "validation" - ); - self.enter_span(VALIDATION_CTX, span); - } + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + let span = span!( + target: "async_graphql::graphql", + Level::INFO, + "execute" + ); + next.execute(ctx).instrument(span).await } - fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, _result: &ValidationResult) { - self.exit_span(VALIDATION_CTX); - } - - fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { - if let Some(parent) = self.spans.get(&REQUEST_CTX) { - let span = span!( - target: "async_graphql::graphql", - parent: parent, - Level::INFO, - "execute" - ); - self.enter_span(EXECUTE_CTX, span); - }; - } - - fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { - self.exit_span(EXECUTE_CTX); - } - - fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - let parent = match info.resolve_id.parent { - Some(parent_id) if parent_id > 0 => self.spans.get(&resolve_span_id(parent_id)), - _ => self.spans.get(&EXECUTE_CTX), - }; - - if let Some(parent) = parent { - let span = span!( - target: "async_graphql::graphql", - parent: parent, - Level::INFO, - "field", - id = %info.resolve_id.current, - path = %info.path_node, - parent_type = %info.parent_type, - return_type = %info.return_type, - ); - self.enter_span(resolve_span_id(info.resolve_id.current), span); - } - } - - fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - self.exit_span(resolve_span_id(info.resolve_id.current)); - } - - fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { - tracing::error!(target: "async_graphql::graphql", error = %err.message); + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + let span = span!( + target: "async_graphql::graphql", + Level::INFO, + "field", + path = %info.path_node, + parent_type = %info.parent_type, + return_type = %info.return_type, + ); + next.resolve(ctx, info) + .instrument(span) + .map_err(|err| { + tracinglib::error!(target: "async_graphql::graphql", error = %err.message); + err + }) + .await } } diff --git a/src/resolver_utils/container.rs b/src/resolver_utils/container.rs index e670f6b9..f7c4ed1a 100644 --- a/src/resolver_utils/container.rs +++ b/src/resolver_utils/container.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::future::Future; use std::pin::Pin; -use crate::extensions::{ErrorLogger, ResolveInfo}; +use crate::extensions::ResolveInfo; use crate::parser::types::Selection; use crate::registry::MetaType; use crate::{ @@ -174,18 +174,18 @@ impl<'a> Fields<'a> { async move { let ctx_field = ctx.with_field(field); let field_name = ctx_field.item.node.response_key().node.clone(); + let extensions = &ctx.query_env.extensions; - let res = if ctx_field.query_env.extensions.is_empty() { + if extensions.is_empty() { match root.resolve_field(&ctx_field).await { Ok(value) => Ok((field_name, value.unwrap_or_default())), Err(e) => { Err(e.path(PathSegment::Field(field_name.to_string()))) } - }? + } } else { let type_name = T::type_name(); let resolve_info = ResolveInfo { - resolve_id: ctx_field.resolve_id, path_node: ctx_field.path_node.as_ref().unwrap(), parent_type: &type_name, return_type: match ctx_field @@ -210,20 +210,16 @@ impl<'a> Fields<'a> { }, }; - ctx_field.query_env.extensions.resolve_start(&resolve_info); - - let res = match root.resolve_field(&ctx_field).await { + let resolve_fut = async { root.resolve_field(&ctx_field).await }; + futures_util::pin_mut!(resolve_fut); + let res = extensions.resolve(resolve_info, &mut resolve_fut).await; + match res { Ok(value) => Ok((field_name, value.unwrap_or_default())), Err(e) => { Err(e.path(PathSegment::Field(field_name.to_string()))) } } - .log_error(&ctx_field.query_env.extensions)?; - ctx_field.query_env.extensions.resolve_end(&resolve_info); - res - }; - - Ok(res) + } } })); } diff --git a/src/resolver_utils/list.rs b/src/resolver_utils/list.rs index ced30ea9..92f6d8f4 100644 --- a/src/resolver_utils/list.rs +++ b/src/resolver_utils/list.rs @@ -1,4 +1,4 @@ -use crate::extensions::{ErrorLogger, ResolveInfo}; +use crate::extensions::ResolveInfo; use crate::parser::types::Field; use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value}; @@ -9,38 +9,49 @@ pub async fn resolve_list<'a, T: OutputType + 'a>( iter: impl IntoIterator, len: Option, ) -> ServerResult { - let mut futures = len.map(Vec::with_capacity).unwrap_or_default(); + let extensions = &ctx.query_env.extensions; - for (idx, item) in iter.into_iter().enumerate() { - let ctx_idx = ctx.with_index(idx); - futures.push(async move { - if ctx_idx.query_env.extensions.is_empty() { + if extensions.is_empty() { + let mut futures = len.map(Vec::with_capacity).unwrap_or_default(); + for (idx, item) in iter.into_iter().enumerate() { + futures.push({ + let ctx = ctx.clone(); + async move { + let ctx_idx = ctx.with_index(idx); + let resolve_info = ResolveInfo { + path_node: ctx_idx.path_node.as_ref().unwrap(), + parent_type: &Vec::::type_name(), + return_type: &T::qualified_type_name(), + }; + let resolve_fut = async { + OutputType::resolve(&item, &ctx_idx, field) + .await + .map(Option::Some) + .map_err(|e| e.path(PathSegment::Index(idx))) + }; + futures_util::pin_mut!(resolve_fut); + extensions + .resolve(resolve_info, &mut resolve_fut) + .await + .map(|value| value.expect("You definitely encountered a bug!")) + } + }); + } + Ok(Value::List( + futures_util::future::try_join_all(futures).await?, + )) + } else { + let mut futures = len.map(Vec::with_capacity).unwrap_or_default(); + for (idx, item) in iter.into_iter().enumerate() { + let ctx_idx = ctx.with_index(idx); + futures.push(async move { OutputType::resolve(&item, &ctx_idx, field) .await .map_err(|e| e.path(PathSegment::Index(idx))) - .log_error(&ctx_idx.query_env.extensions) - } else { - let resolve_info = ResolveInfo { - resolve_id: ctx_idx.resolve_id, - path_node: ctx_idx.path_node.as_ref().unwrap(), - parent_type: &Vec::::type_name(), - return_type: &T::qualified_type_name(), - }; - - ctx_idx.query_env.extensions.resolve_start(&resolve_info); - - let res = OutputType::resolve(&item, &ctx_idx, field) - .await - .map_err(|e| e.path(PathSegment::Index(idx))) - .log_error(&ctx_idx.query_env.extensions)?; - - ctx_idx.query_env.extensions.resolve_end(&resolve_info); - Ok(res) - } - }); + }); + } + Ok(Value::List( + futures_util::future::try_join_all(futures).await?, + )) } - - Ok(Value::List( - futures_util::future::try_join_all(futures).await?, - )) } diff --git a/src/response.rs b/src/response.rs index a16014ae..acf23180 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,5 +1,6 @@ -use http::header::HeaderMap; +use std::collections::BTreeMap; +use http::header::HeaderMap; use serde::{Deserialize, Serialize}; use crate::{CacheControl, Result, ServerError, Value}; @@ -12,8 +13,8 @@ pub struct Response { pub data: Value, /// Extensions result - #[serde(skip_serializing_if = "Option::is_none", default)] - pub extensions: Option, + #[serde(skip_serializing_if = "BTreeMap::is_empty", default)] + pub extensions: BTreeMap, /// Cache control value #[serde(skip)] @@ -47,10 +48,11 @@ impl Response { } } - /// Set the extensions result of the response. + /// Set the extension result of the response. #[must_use] - pub fn extensions(self, extensions: Option) -> Self { - Self { extensions, ..self } + pub fn extension(mut self, name: impl Into, value: Value) -> Self { + self.extensions.insert(name.into(), value); + self } /// Set the http headers of the response. diff --git a/src/schema.rs b/src/schema.rs index d24f381d..21f5da91 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,14 +1,12 @@ use std::any::Any; -use std::collections::BTreeMap; use std::ops::Deref; -use std::sync::atomic::AtomicUsize; use std::sync::Arc; use futures_util::stream::{self, Stream, StreamExt}; use indexmap::map::IndexMap; -use crate::context::{Data, QueryEnvInner, ResolveId}; -use crate::extensions::{ErrorLogger, ExtensionFactory, Extensions}; +use crate::context::{Data, QueryEnvInner}; +use crate::extensions::{ExtensionFactory, Extensions}; use crate::model::__DirectiveLocation; use crate::parser::parse_query; use crate::parser::types::{DocumentOperations, OperationType}; @@ -19,7 +17,7 @@ use crate::types::QueryRoot; use crate::validation::{check_rules, ValidationMode}; use crate::{ BatchRequest, BatchResponse, CacheControl, ContextBase, ObjectType, QueryEnv, Request, - Response, ServerError, SubscriptionType, Type, Value, ID, + Response, ServerError, SubscriptionType, Type, ID, }; /// Schema builder @@ -354,17 +352,11 @@ where } fn create_extensions(&self, session_data: Arc) -> Extensions { - let extensions = Extensions::new( - self.0 - .extensions - .iter() - .map(|factory| factory.create()) - .collect::>(), + Extensions::new( + self.extensions.iter().map(|f| f.create()), self.env.clone(), session_data, - ); - extensions.start(); - extensions + ) } async fn prepare_request( @@ -376,36 +368,41 @@ where let mut request = request; let query_data = Arc::new(std::mem::take(&mut request.data)); extensions.attach_query_data(query_data.clone()); - let request = extensions.prepare_request(request).await?; - extensions.parse_start(&request.query, &request.variables); - let document = parse_query(&request.query) - .map_err(Into::::into) - .log_error(&extensions)?; - extensions.parse_end(&document); + let request = extensions.prepare_request(request).await?; + let document = { + let query = &request.query; + let fut_parse = async { parse_query(&query).map_err(Into::::into) }; + futures_util::pin_mut!(fut_parse); + extensions + .parse_query(&query, &request.variables, &mut fut_parse) + .await? + }; // check rules - extensions.validation_start(); - let validation_result = check_rules( - &self.env.registry, - &document, - Some(&request.variables), - self.validation_mode, - ) - .log_error(&extensions)?; - extensions.validation_end(&validation_result); + let validation_result = { + let validation_fut = async { + check_rules( + &self.env.registry, + &document, + Some(&request.variables), + self.validation_mode, + ) + }; + futures_util::pin_mut!(validation_fut); + extensions.validation(&mut validation_fut).await? + }; // check limit if let Some(limit_complexity) = self.complexity { if validation_result.complexity > limit_complexity { - return Err(vec![ServerError::new("Query is too complex.")]).log_error(&extensions); + return Err(vec![ServerError::new("Query is too complex.")]); } } if let Some(limit_depth) = self.depth { if validation_result.depth > limit_depth { - return Err(vec![ServerError::new("Query is nested too deep.")]) - .log_error(&extensions); + return Err(vec![ServerError::new("Query is nested too deep.")]); } } @@ -432,10 +429,7 @@ where }; let operation = match operation { Ok(operation) => operation, - Err(e) => { - extensions.error(&e); - return Err(vec![e]); - } + Err(e) => return Err(vec![e]), }; let env = QueryEnvInner { @@ -454,49 +448,58 @@ where async fn execute_once(&self, env: QueryEnv) -> Response { // execute - let inc_resolve_id = AtomicUsize::default(); let ctx = ContextBase { path_node: None, - resolve_id: ResolveId::root(), - inc_resolve_id: &inc_resolve_id, item: &env.operation.node.selection_set, schema_env: &self.env, query_env: &env, }; - env.extensions.execution_start(); - let data = match &env.operation.node.ty { + let res = match &env.operation.node.ty { OperationType::Query => resolve_container(&ctx, &self.query).await, OperationType::Mutation => resolve_container_serial(&ctx, &self.mutation).await, OperationType::Subscription => { return Response::from_errors(vec![ServerError::new( "Subscriptions are not supported on this transport.", - )]) + )]); } }; - env.extensions.execution_end(); - let extensions = env.extensions.result(); - - match data { - Ok(data) => Response::new(data), - Err(e) => Response::from_errors(vec![e]), + match res { + Ok(data) => { + let resp = Response::new(data); + resp.http_headers(std::mem::take(&mut *env.http_headers.lock())) + } + Err(err) => Response::from_errors(vec![err]), } - .extensions(extensions) - .http_headers(std::mem::take(&mut *env.http_headers.lock())) } /// Execute a GraphQL query. pub async fn execute(&self, request: impl Into) -> Response { let request = request.into(); let extensions = self.create_extensions(Default::default()); - match self - .prepare_request(extensions, request, Default::default()) - .await - { - Ok((env, cache_control)) => self.execute_once(env).await.cache_control(cache_control), - Err(errors) => Response::from_errors(errors), - } + let request_fut = { + let extensions = extensions.clone(); + async move { + match self + .prepare_request(extensions, request, Default::default()) + .await + { + Ok((env, cache_control)) => { + let fut = async { + self.execute_once(env.clone()) + .await + .cache_control(cache_control) + }; + futures_util::pin_mut!(fut); + env.extensions.execute(&mut fut).await + } + Err(errors) => Response::from_errors(errors), + } + } + }; + futures_util::pin_mut!(request_fut); + extensions.request(&mut request_fut).await } /// Execute a GraphQL batch query. @@ -518,68 +521,52 @@ where &self, request: impl Into + Send, session_data: Arc, - ) -> impl Stream + Send { + ) -> impl Stream + Send + Unpin { let schema = self.clone(); let request = request.into(); let extensions = self.create_extensions(session_data.clone()); - async_stream::stream! { - let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await { - Ok(res) => res, - Err(errors) => { - yield Response::from_errors(errors); + let stream = futures_util::stream::StreamExt::boxed({ + let extensions = extensions.clone(); + async_stream::stream! { + let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await { + Ok(res) => res, + Err(errors) => { + yield Response::from_errors(errors); + return; + } + }; + + if env.operation.node.ty != OperationType::Subscription { + yield schema.execute_once(env).await.cache_control(cache_control); return; } - }; - if env.operation.node.ty != OperationType::Subscription { - yield schema - .execute_once(env) - .await - .cache_control(cache_control); - return; - } + let ctx = env.create_context( + &schema.env, + None, + &env.operation.node.selection_set, + ); - let resolve_id = AtomicUsize::default(); - let ctx = env.create_context( - &schema.env, - None, - &env.operation.node.selection_set, - ResolveId::root(), - &resolve_id, - ); + let mut streams = Vec::new(); + if let Err(err) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { + yield Response::from_errors(vec![err]); + } - let mut streams = Vec::new(); - if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { - env.extensions.execution_end(); - yield Response::from_errors(vec![e]); - return; - } - - let mut stream = stream::select_all(streams); - while let Some(data) = stream.next().await { - let is_err = data.is_err(); - let extensions = env.extensions.result(); - yield match data { - Ok((name, value)) => { - let mut map = BTreeMap::new(); - map.insert(name, value); - Response::new(Value::Object(map)) - }, - Err(e) => Response::from_errors(vec![e]), - }.extensions(extensions); - if is_err { - break; + let mut stream = stream::select_all(streams); + while let Some(resp) = stream.next().await { + yield resp; } } - } + }); + extensions.subscribe(stream) } /// Execute a GraphQL subscription. pub fn execute_stream( &self, request: impl Into, - ) -> impl Stream + Send { + ) -> impl Stream + Send + Unpin { self.execute_stream_with_session_data(request.into(), Default::default()) } } diff --git a/src/subscription.rs b/src/subscription.rs index 4cafd4da..a25d0ad6 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -3,9 +3,7 @@ use std::pin::Pin; use futures_util::stream::{Stream, StreamExt}; use crate::parser::types::{Selection, TypeCondition}; -use crate::{ - Context, ContextSelectionSet, Name, PathSegment, ServerError, ServerResult, Type, Value, -}; +use crate::{Context, ContextSelectionSet, PathSegment, Response, ServerError, ServerResult, Type}; /// A GraphQL subscription object pub trait SubscriptionType: Type + Send + Sync { @@ -19,10 +17,10 @@ pub trait SubscriptionType: Type + Send + Sync { fn create_field_stream<'a>( &'a self, ctx: &'a Context<'_>, - ) -> Option> + Send + 'a>>>; + ) -> Option + Send + 'a>>>; } -type BoxFieldStream<'a> = Pin> + 'a + Send>>; +type BoxFieldStream<'a> = Pin + 'a + Send>>; pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>( ctx: &ContextSelectionSet<'a>, @@ -41,16 +39,14 @@ pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>( let field_name = ctx.item.node.response_key().node.clone(); let stream = root.create_field_stream(&ctx); if let Some(mut stream) = stream { - while let Some(item) = stream.next().await { - yield match item { - Ok(value) => Ok((field_name.to_owned(), value)), - Err(e) => Err(e.path(PathSegment::Field(field_name.to_string()))), - }; + while let Some(resp) = stream.next().await { + yield resp; } } else { - yield Err(ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name())) + let err = ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name())) .at(ctx.item.pos) - .path(PathSegment::Field(field_name.to_string()))); + .path(PathSegment::Field(field_name.to_string())); + yield Response::from_errors(vec![err]); } } })), @@ -98,7 +94,7 @@ impl SubscriptionType for &T { fn create_field_stream<'a>( &'a self, ctx: &'a Context<'_>, - ) -> Option> + Send + 'a>>> { + ) -> Option + Send + 'a>>> { T::create_field_stream(*self, ctx) } } diff --git a/src/types/empty_subscription.rs b/src/types/empty_subscription.rs index 66e31e2f..aa4f5a3c 100644 --- a/src/types/empty_subscription.rs +++ b/src/types/empty_subscription.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use futures_util::stream::{self, Stream}; -use crate::{registry, Context, ServerError, ServerResult, SubscriptionType, Type, Value}; +use crate::{registry, Context, Response, ServerError, SubscriptionType, Type}; /// Empty subscription /// @@ -37,12 +37,13 @@ impl SubscriptionType for EmptySubscription { fn create_field_stream<'a>( &'a self, ctx: &'a Context<'_>, - ) -> Option> + Send + 'a>>> + ) -> Option + Send + 'a>>> where Self: Send + Sync + 'static + Sized, { Some(Box::pin(stream::once(async move { - Err(ServerError::new("Schema is not configured for mutations.").at(ctx.item.pos)) + let err = ServerError::new("Schema is not configured for mutations.").at(ctx.item.pos); + Response::from_errors(vec![err]) }))) } } diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 5b4e561c..81658447 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -16,6 +16,7 @@ pub use visitor::VisitorContext; use visitor::{visit, VisitorNil}; /// Validation results. +#[derive(Debug, Copy, Clone)] pub struct ValidationResult { /// Cache control pub cache_control: CacheControl, diff --git a/tests/extension.rs b/tests/extension.rs index 6f9c263a..93007045 100644 --- a/tests/extension.rs +++ b/tests/extension.rs @@ -1,12 +1,15 @@ use std::sync::Arc; -use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; +use async_graphql::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, +}; +use async_graphql::futures_util::stream::BoxStream; use async_graphql::parser::types::ExecutableDocument; use async_graphql::*; use async_graphql_value::ConstValue; +use futures_util::lock::Mutex; use futures_util::stream::Stream; use futures_util::StreamExt; -use spin::Mutex; #[tokio::test] pub async fn test_extension_ctx() { @@ -18,7 +21,7 @@ pub async fn test_extension_ctx() { #[Object] impl Query { async fn value(&self, ctx: &Context<'_>) -> i32 { - *ctx.data_unchecked::().0.lock() + *ctx.data_unchecked::().0.lock().await } } @@ -27,7 +30,7 @@ pub async fn test_extension_ctx() { #[Subscription] impl Subscription { async fn value(&self, ctx: &Context<'_>) -> impl Stream { - let data = *ctx.data_unchecked::().0.lock(); + let data = *ctx.data_unchecked::().0.lock().await; futures_util::stream::once(async move { data }) } } @@ -36,23 +39,25 @@ pub async fn test_extension_ctx() { #[async_trait::async_trait] impl Extension for MyExtensionImpl { - fn parse_start( - &mut self, + async fn parse_query( + &self, ctx: &ExtensionContext<'_>, - _query_source: &str, - _variables: &Variables, - ) { + query: &str, + variables: &Variables, + next: NextExtension<'_>, + ) -> ServerResult { if let Ok(data) = ctx.data::() { - *data.0.lock() = 100; + *data.0.lock().await = 100; } + next.parse_query(ctx, query, variables).await } } struct MyExtension; impl ExtensionFactory for MyExtension { - fn create(&self) -> Box { - Box::new(MyExtensionImpl) + fn create(&self) -> Arc { + Arc::new(MyExtensionImpl) } } @@ -104,12 +109,10 @@ pub async fn test_extension_ctx() { let mut data = Data::default(); data.insert(MyData::default()); - let mut stream = schema - .execute_stream_with_session_data( - Request::new("subscription { value }"), - Arc::new(data), - ) - .boxed(); + let mut stream = schema.execute_stream_with_session_data( + Request::new("subscription { value }"), + Arc::new(data), + ); assert_eq!( stream.next().await.unwrap().into_result().unwrap().data, value! ({ @@ -128,67 +131,83 @@ pub async fn test_extension_call_order() { #[async_trait::async_trait] #[allow(unused_variables)] impl Extension for MyExtensionImpl { - fn name(&self) -> Option<&'static str> { - Some("test") + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + self.calls.lock().await.push("request_start"); + let res = next.request(ctx).await; + self.calls.lock().await.push("request_end"); + res } - fn start(&mut self, ctx: &ExtensionContext<'_>) { - self.calls.lock().push("start"); - } - - fn end(&mut self, ctx: &ExtensionContext<'_>) { - self.calls.lock().push("end"); + fn subscribe<'s>( + &self, + ctx: &ExtensionContext<'_>, + mut stream: BoxStream<'s, Response>, + next: NextExtension<'_>, + ) -> BoxStream<'s, Response> { + let calls = self.calls.clone(); + let stream = async_stream::stream! { + calls.lock().await.push("subscribe_start"); + while let Some(item) = stream.next().await { + yield item; + } + calls.lock().await.push("subscribe_end"); + }; + Box::pin(stream) } async fn prepare_request( - &mut self, + &self, ctx: &ExtensionContext<'_>, request: Request, + next: NextExtension<'_>, ) -> ServerResult { - self.calls.lock().push("prepare_request"); - Ok(request) + self.calls.lock().await.push("prepare_request_start"); + let res = next.prepare_request(ctx, request).await; + self.calls.lock().await.push("prepare_request_end"); + res } - fn parse_start( - &mut self, + async fn parse_query( + &self, ctx: &ExtensionContext<'_>, - query_source: &str, + query: &str, variables: &Variables, - ) { - self.calls.lock().push("parse_start"); + next: NextExtension<'_>, + ) -> ServerResult { + self.calls.lock().await.push("parse_query_start"); + let res = next.parse_query(ctx, query, variables).await; + self.calls.lock().await.push("parse_query_end"); + res } - fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { - self.calls.lock().push("parse_end"); + async fn validation( + &self, + ctx: &ExtensionContext<'_>, + next: NextExtension<'_>, + ) -> Result> { + self.calls.lock().await.push("validation_start"); + let res = next.validation(ctx).await; + self.calls.lock().await.push("validation_end"); + res } - fn validation_start(&mut self, ctx: &ExtensionContext<'_>) { - self.calls.lock().push("validation_start"); + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + self.calls.lock().await.push("execute_start"); + let res = next.execute(ctx).await; + self.calls.lock().await.push("execute_end"); + res } - fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) { - self.calls.lock().push("validation_end"); - } - - fn execution_start(&mut self, ctx: &ExtensionContext<'_>) { - self.calls.lock().push("execution_start"); - } - - fn execution_end(&mut self, ctx: &ExtensionContext<'_>) { - self.calls.lock().push("execution_end"); - } - - fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - self.calls.lock().push("resolve_start"); - } - - fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - self.calls.lock().push("resolve_end"); - } - - fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option { - self.calls.lock().push("result"); - None + async fn resolve( + &self, + ctx: &ExtensionContext<'_>, + info: ResolveInfo<'_>, + next: NextExtension<'_>, + ) -> ServerResult> { + self.calls.lock().await.push("resolve_start"); + let res = next.resolve(ctx, info).await; + self.calls.lock().await.push("resolve_end"); + res } } @@ -197,8 +216,8 @@ pub async fn test_extension_call_order() { } impl ExtensionFactory for MyExtension { - fn create(&self) -> Box { - Box::new(MyExtensionImpl { + fn create(&self) -> Arc { + Arc::new(MyExtensionImpl { calls: self.calls.clone(), }) } @@ -238,24 +257,24 @@ pub async fn test_extension_call_order() { .await .into_result() .unwrap(); - let calls = calls.lock(); + let calls = calls.lock().await; assert_eq!( &*calls, &vec![ - "start", - "prepare_request", - "parse_start", - "parse_end", + "request_start", + "prepare_request_start", + "prepare_request_end", + "parse_query_start", + "parse_query_end", "validation_start", "validation_end", - "execution_start", + "execute_start", "resolve_start", "resolve_end", "resolve_start", "resolve_end", - "execution_end", - "result", - "end", + "execute_end", + "request_end", ] ); } @@ -267,34 +286,36 @@ pub async fn test_extension_call_order() { calls: calls.clone(), }) .finish(); - let mut stream = schema.execute_stream("subscription { value }").boxed(); + let mut stream = schema.execute_stream("subscription { value }"); while let Some(_) = stream.next().await {} - let calls = calls.lock(); + let calls = calls.lock().await; assert_eq!( &*calls, &vec![ - "start", - "prepare_request", - "parse_start", - "parse_end", + "subscribe_start", + "prepare_request_start", + "prepare_request_end", + "parse_query_start", + "parse_query_end", "validation_start", "validation_end", - "execution_start", + // push 1 + "execute_start", "resolve_start", "resolve_end", - "execution_end", - "result", - "execution_start", + "execute_end", + // push 2 + "execute_start", "resolve_start", "resolve_end", - "execution_end", - "result", - "execution_start", + "execute_end", + // push 3 + "execute_start", "resolve_start", "resolve_end", - "execution_end", - "result", - "end", + "execute_end", + // end + "subscribe_end", ] ); } diff --git a/tests/field_features.rs b/tests/field_features.rs index 3c67c91b..8bb3d13e 100644 --- a/tests/field_features.rs +++ b/tests/field_features.rs @@ -119,7 +119,7 @@ pub async fn test_field_features() { }] ); - let mut stream = schema.execute_stream("subscription { values }").boxed(); + let mut stream = schema.execute_stream("subscription { values }"); assert_eq!( stream .next() @@ -131,7 +131,7 @@ pub async fn test_field_features() { }) ); - let mut stream = schema.execute_stream("subscription { valuesBson }").boxed(); + let mut stream = schema.execute_stream("subscription { valuesBson }"); assert_eq!( stream.next().await.map(|resp| resp.data).unwrap(), value!({ @@ -142,7 +142,6 @@ pub async fn test_field_features() { assert_eq!( schema .execute_stream("subscription { valuesAbc }") - .boxed() .next() .await .unwrap() diff --git a/tests/generic_types.rs b/tests/generic_types.rs index 6befa0dd..48d6533a 100644 --- a/tests/generic_types.rs +++ b/tests/generic_types.rs @@ -284,8 +284,7 @@ pub async fn test_generic_subscription() { { let mut stream = schema .execute_stream("subscription { values }") - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); for i in 1..=2 { assert_eq!(value!({ "values": i }), stream.next().await.unwrap()); } diff --git a/tests/guard.rs b/tests/guard.rs index 1b898016..edc0877a 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -117,7 +117,6 @@ pub async fn test_guard_simple_rule() { assert_eq!( schema .execute_stream(Request::new("subscription { values }").data(Role::Guest)) - .boxed() .next() .await .unwrap() diff --git a/tests/merged_object.rs b/tests/merged_object.rs index 4ec13cbb..77ccf663 100644 --- a/tests/merged_object.rs +++ b/tests/merged_object.rs @@ -196,8 +196,7 @@ pub async fn test_merged_subscription() { { let mut stream = schema .execute_stream("subscription { events1 }") - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); for i in 0i32..10 { assert_eq!( value!({ @@ -212,8 +211,7 @@ pub async fn test_merged_subscription() { { let mut stream = schema .execute_stream("subscription { events2 }") - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); for i in 10i32..20 { assert_eq!( value!({ diff --git a/tests/raw_ident.rs b/tests/raw_ident.rs index 8360732d..663a5798 100644 --- a/tests/raw_ident.rs +++ b/tests/raw_ident.rs @@ -66,8 +66,7 @@ pub async fn test_input_value_custom_error() { let mut stream = schema .execute_stream("subscription { type }") .map(|resp| resp.into_result()) - .map_ok(|resp| resp.data) - .boxed(); + .map_ok(|resp| resp.data); for i in 0..10 { assert_eq!(value!({ "type": i }), stream.next().await.unwrap().unwrap()); } diff --git a/tests/rename.rs b/tests/rename.rs index eb1bb809..1d35116a 100644 --- a/tests/rename.rs +++ b/tests/rename.rs @@ -127,7 +127,6 @@ pub async fn test_subscription() { assert_eq!( Schema::new(Query, EmptyMutation, Subscription) .execute_stream("subscription { CREATE_OBJECT(objectid: 100) }") - .boxed() .next() .await .unwrap() diff --git a/tests/subscription.rs b/tests/subscription.rs index 7bb77504..1477c9a1 100644 --- a/tests/subscription.rs +++ b/tests/subscription.rs @@ -36,8 +36,7 @@ pub async fn test_subscription() { { let mut stream = schema .execute_stream("subscription { values(start: 10, end: 20) }") - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); for i in 10..20 { assert_eq!(value!({ "values": i }), stream.next().await.unwrap()); } @@ -47,8 +46,7 @@ pub async fn test_subscription() { { let mut stream = schema .execute_stream("subscription { events(start: 10, end: 20) { a b } }") - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); for i in 10..20 { assert_eq!( value!({ "events": {"a": i, "b": i * 10} }), @@ -98,8 +96,7 @@ pub async fn test_subscription_with_ctx_data() { { let mut stream = schema .execute_stream(Request::new("subscription { values objects { value } }").data(100i32)) - .map(|resp| resp.data) - .boxed(); + .map(|resp| resp.data); assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap()); assert_eq!( value!({ "objects": { "value": 100 } }), @@ -141,8 +138,7 @@ pub async fn test_subscription_with_token() { .execute_stream( Request::new("subscription { values }").data(Token("123456".to_string())), ) - .map(|resp| resp.into_result().unwrap().data) - .boxed(); + .map(|resp| resp.into_result().unwrap().data); assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap()); assert!(stream.next().await.is_none()); } @@ -152,7 +148,6 @@ pub async fn test_subscription_with_token() { .execute_stream( Request::new("subscription { values }").data(Token("654321".to_string())) ) - .boxed() .next() .await .unwrap() @@ -200,8 +195,7 @@ pub async fn test_subscription_inline_fragment() { } "#, ) - .map(|resp| resp.data) - .boxed(); + .map(|resp| resp.data); for i in 10..20 { assert_eq!( value!({ "events": {"a": i, "b": i * 10} }), @@ -250,8 +244,7 @@ pub async fn test_subscription_fragment() { } "#, ) - .map(|resp| resp.data) - .boxed(); + .map(|resp| resp.data); for i in 10i32..20 { assert_eq!( value!({ "events": {"a": i, "b": i * 10} }), @@ -301,8 +294,7 @@ pub async fn test_subscription_fragment2() { } "#, ) - .map(|resp| resp.data) - .boxed(); + .map(|resp| resp.data); for i in 10..20 { assert_eq!( value!({ "events": {"a": i, "b": i * 10} }), @@ -342,8 +334,7 @@ pub async fn test_subscription_error() { let mut stream = schema .execute_stream("subscription { events { value } }") .map(|resp| resp.into_result()) - .map_ok(|resp| resp.data) - .boxed(); + .map_ok(|resp| resp.data); for i in 0i32..5 { assert_eq!( value!({ "events": { "value": i } }), @@ -388,8 +379,7 @@ pub async fn test_subscription_fieldresult() { let mut stream = schema .execute_stream("subscription { values }") .map(|resp| resp.into_result()) - .map_ok(|resp| resp.data) - .boxed(); + .map_ok(|resp| resp.data); for i in 0i32..5 { assert_eq!( value!({ "values": i }),