Rework Extension

This commit is contained in:
Sunli 2021-04-04 12:05:54 +08:00
parent 13298b8d61
commit 824356d118
27 changed files with 1058 additions and 1145 deletions

View File

@ -2,7 +2,14 @@
All notable changes to this project will be documented in this file. All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
nd this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- Rework `Extension`, now fully supports asynchronous, better to use than before, and can achieve more features.
Because it contains a lot of changes _(if you don't have a custom extension, it will not cause the existing code to fail to compile)_, the major version will be updated to `3.0.0`.
## [2.7.4] 2021-04-02 ## [2.7.4] 2021-04-02

View File

@ -19,6 +19,7 @@ apollo_persisted_queries = ["lru", "sha2"]
unblock = ["blocking"] unblock = ["blocking"]
string_number = ["num-traits"] string_number = ["num-traits"]
dataloader = ["futures-timer", "futures-channel", "lru"] dataloader = ["futures-timer", "futures-channel", "lru"]
tracing = ["tracinglib", "tracing-futures"]
[dependencies] [dependencies]
async-graphql-derive = { path = "derive", version = "=2.7.3" } async-graphql-derive = { path = "derive", version = "=2.7.3" }
@ -47,7 +48,8 @@ bson = { version = "1.2.0", optional = true }
chrono = { version = "0.4.19", optional = true } chrono = { version = "0.4.19", optional = true }
chrono-tz = { version = "0.5.3", optional = true } chrono-tz = { version = "0.5.3", optional = true }
log = { version = "0.4.14", optional = true } log = { version = "0.4.14", optional = true }
tracing = { version = "0.1.25", optional = true } tracinglib = { version = "0.1.25", optional = true, package = "tracing" }
tracing-futures = { version = "0.2.5", optional = true, features = ["std-future", "futures-03"] }
opentelemetry = { version = "0.13.0", optional = true } opentelemetry = { version = "0.13.0", optional = true }
url = { version = "2.2.1", optional = true } url = { version = "2.2.1", optional = true }
uuid = { version = "0.8.2", optional = true, features = ["v4", "serde"] } uuid = { version = "0.8.2", optional = true, features = ["v4", "serde"] }

View File

@ -82,7 +82,7 @@ pub fn generate(object_args: &args::MergedSubscription) -> GeneratorResult<Token
fn create_field_stream<'__life>( fn create_field_stream<'__life>(
&'__life self, &'__life self,
ctx: &'__life #crate_name::Context<'__life> ctx: &'__life #crate_name::Context<'__life>
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::ServerResult<#crate_name::Value>> + ::std::marker::Send + '__life>>> { ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::Response> + ::std::marker::Send + '__life>>> {
::std::option::Option::None #create_field_stream ::std::option::Option::None #create_field_stream
} }
} }

View File

@ -282,7 +282,7 @@ pub fn generate(
quote! { quote! {
Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| { Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| {
#(#parse_args)* #(#parse_args)*
Ok(#expr) ::std::result::Result::Ok(#expr)
})) }))
} }
} }
@ -331,7 +331,7 @@ pub fn generate(
let stream_fn = quote! { let stream_fn = quote! {
#(#get_params)* #(#get_params)*
#guard #guard
let field_name = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item.node.response_key().node)); let field_name = ::std::clone::Clone::clone(&ctx.item.node.response_key().node);
let field = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item)); let field = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item));
let pos = ctx.item.pos; let pos = ctx.item.pos;
@ -345,11 +345,6 @@ pub fn generate(
let field = ::std::clone::Clone::clone(&field); let field = ::std::clone::Clone::clone(&field);
let field_name = ::std::clone::Clone::clone(&field_name); let field_name = ::std::clone::Clone::clone(&field_name);
async move { async move {
let resolve_id = #crate_name::ResolveId {
parent: ::std::option::Option::Some(0),
current: 1,
};
let inc_resolve_id = ::std::sync::atomic::AtomicUsize::new(1);
let ctx_selection_set = query_env.create_context( let ctx_selection_set = query_env.create_context(
&schema_env, &schema_env,
::std::option::Option::Some(#crate_name::QueryPathNode { ::std::option::Option::Some(#crate_name::QueryPathNode {
@ -357,26 +352,31 @@ pub fn generate(
segment: #crate_name::QueryPathSegment::Name(&field_name), segment: #crate_name::QueryPathSegment::Name(&field_name),
}), }),
&field.node.selection_set, &field.node.selection_set,
resolve_id,
&inc_resolve_id,
); );
query_env.extensions.execution_start(); let mut execute_fut = async {
#[allow(bare_trait_objects)]
#[allow(bare_trait_objects)] let ri = #crate_name::extensions::ResolveInfo {
let ri = #crate_name::extensions::ResolveInfo { path_node: ctx_selection_set.path_node.as_ref().unwrap(),
resolve_id, parent_type: #gql_typename,
path_node: ctx_selection_set.path_node.as_ref().unwrap(), return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(),
parent_type: #gql_typename, };
return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(), let resolve_fut = async {
#crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field)
.await
.map(::std::option::Option::Some)
};
#crate_name::futures_util::pin_mut!(resolve_fut);
query_env.extensions.resolve(ri, &mut resolve_fut).await
.map(|value| {
let mut map = ::std::collections::BTreeMap::new();
map.insert(::std::clone::Clone::clone(&field_name), value.unwrap_or_default());
#crate_name::Response::new(#crate_name::Value::Object(map))
})
.unwrap_or_else(|err| #crate_name::Response::from_errors(::std::vec![err]))
}; };
#crate_name::futures_util::pin_mut!(execute_fut);
query_env.extensions.resolve_start(&ri); ::std::result::Result::Ok(query_env.extensions.execute(&mut execute_fut).await)
let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await;
query_env.extensions.resolve_end(&ri);
query_env.extensions.execution_end();
res
} }
} }
}); });
@ -398,11 +398,14 @@ pub fn generate(
create_stream.push(quote! { create_stream.push(quote! {
#(#cfg_attrs)* #(#cfg_attrs)*
if ctx.item.node.name.node == #field_name { if ctx.item.node.name.node == #field_name {
return ::std::option::Option::Some(::std::boxed::Box::pin( let stream = #crate_name::futures_util::stream::TryStreamExt::try_flatten(
#crate_name::futures_util::stream::TryStreamExt::try_flatten( #crate_name::futures_util::stream::once((move || async move { #stream_fn })())
#crate_name::futures_util::stream::once((move || async move { #stream_fn })()) );
) let stream = #crate_name::futures_util::StreamExt::map(stream, |res| match res {
)); ::std::result::Result::Ok(resp) => resp,
::std::result::Result::Err(err) => #crate_name::Response::from_errors(::std::vec![err]),
});
return ::std::option::Option::Some(::std::boxed::Box::pin(stream));
} }
}); });
@ -451,7 +454,7 @@ pub fn generate(
fn create_field_stream<'__life>( fn create_field_stream<'__life>(
&'__life self, &'__life self,
ctx: &'__life #crate_name::Context<'_>, ctx: &'__life #crate_name::Context<'_>,
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::ServerResult<#crate_name::Value>> + ::std::marker::Send + '__life>>> { ) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::Response> + ::std::marker::Send + '__life>>> {
#(#create_stream)* #(#create_stream)*
::std::option::Option::None ::std::option::Option::None
} }

View File

@ -4,7 +4,6 @@ use std::any::{Any, TypeId};
use std::collections::HashMap; use std::collections::HashMap;
use std::fmt::{self, Debug, Display, Formatter}; use std::fmt::{self, Debug, Display, Formatter};
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use async_graphql_value::{Value as InputValue, Variables}; use async_graphql_value::{Value as InputValue, Variables};
@ -189,36 +188,6 @@ impl<'a> Iterator for Parents<'a> {
impl<'a> std::iter::FusedIterator for Parents<'a> {} impl<'a> std::iter::FusedIterator for Parents<'a> {}
/// The unique id of the current resolution.
#[derive(Debug, Clone, Copy)]
pub struct ResolveId {
/// The unique ID of the parent resolution.
pub parent: Option<usize>,
/// The current unique id.
pub current: usize,
}
impl ResolveId {
#[doc(hidden)]
pub fn root() -> ResolveId {
ResolveId {
parent: None,
current: 0,
}
}
}
impl Display for ResolveId {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
if let Some(parent) = self.parent {
write!(f, "{}:{}", parent, self.current)
} else {
write!(f, "{}", self.current)
}
}
}
/// Query context. /// Query context.
/// ///
/// **This type is not stable and should not be used directly.** /// **This type is not stable and should not be used directly.**
@ -226,8 +195,6 @@ impl Display for ResolveId {
pub struct ContextBase<'a, T> { pub struct ContextBase<'a, T> {
/// The current path node being resolved. /// The current path node being resolved.
pub path_node: Option<QueryPathNode<'a>>, pub path_node: Option<QueryPathNode<'a>>,
pub(crate) resolve_id: ResolveId,
pub(crate) inc_resolve_id: &'a AtomicUsize,
#[doc(hidden)] #[doc(hidden)]
pub item: T, pub item: T,
#[doc(hidden)] #[doc(hidden)]
@ -273,13 +240,9 @@ impl QueryEnv {
schema_env: &'a SchemaEnv, schema_env: &'a SchemaEnv,
path_node: Option<QueryPathNode<'a>>, path_node: Option<QueryPathNode<'a>>,
item: T, item: T,
resolve_id: ResolveId,
inc_resolve_id: &'a AtomicUsize,
) -> ContextBase<'a, T> { ) -> ContextBase<'a, T> {
ContextBase { ContextBase {
path_node, path_node,
resolve_id,
inc_resolve_id,
item, item,
schema_env, schema_env,
query_env: self, query_env: self,
@ -288,18 +251,6 @@ impl QueryEnv {
} }
impl<'a, T> ContextBase<'a, T> { impl<'a, T> ContextBase<'a, T> {
#[doc(hidden)]
pub fn get_child_resolve_id(&self) -> ResolveId {
let id = self
.inc_resolve_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
ResolveId {
parent: Some(self.resolve_id.current),
current: id,
}
}
#[doc(hidden)] #[doc(hidden)]
pub fn with_field( pub fn with_field(
&'a self, &'a self,
@ -311,8 +262,6 @@ impl<'a, T> ContextBase<'a, T> {
segment: QueryPathSegment::Name(&field.node.response_key().node), segment: QueryPathSegment::Name(&field.node.response_key().node),
}), }),
item: field, item: field,
resolve_id: self.get_child_resolve_id(),
inc_resolve_id: self.inc_resolve_id,
schema_env: self.schema_env, schema_env: self.schema_env,
query_env: self.query_env, query_env: self.query_env,
} }
@ -326,8 +275,6 @@ impl<'a, T> ContextBase<'a, T> {
ContextBase { ContextBase {
path_node: self.path_node, path_node: self.path_node,
item: selection_set, item: selection_set,
resolve_id: self.resolve_id,
inc_resolve_id: &self.inc_resolve_id,
schema_env: self.schema_env, schema_env: self.schema_env,
query_env: self.query_env, query_env: self.query_env,
} }
@ -560,8 +507,6 @@ impl<'a> ContextBase<'a, &'a Positioned<SelectionSet>> {
segment: QueryPathSegment::Index(idx), segment: QueryPathSegment::Index(idx),
}), }),
item: self.item, item: self.item,
resolve_id: self.get_child_resolve_id(),
inc_resolve_id: self.inc_resolve_id,
schema_env: self.schema_env, schema_env: self.schema_env,
query_env: self.query_env, query_env: self.query_env,
} }

View File

@ -1,5 +1,9 @@
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory}; use std::sync::Arc;
use crate::{value, ValidationResult, Value};
use futures_util::lock::Mutex;
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension};
use crate::{value, Response, ServerError, ValidationResult};
/// Analyzer extension /// Analyzer extension
/// ///
@ -7,32 +11,41 @@ use crate::{value, ValidationResult, Value};
pub struct Analyzer; pub struct Analyzer;
impl ExtensionFactory for Analyzer { impl ExtensionFactory for Analyzer {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(AnalyzerExtension::default()) Arc::new(AnalyzerExtension::default())
} }
} }
#[derive(Default)] #[derive(Default)]
struct AnalyzerExtension { struct AnalyzerExtension {
complexity: usize, validation_result: Mutex<Option<ValidationResult>>,
depth: usize,
} }
#[async_trait::async_trait]
impl Extension for AnalyzerExtension { impl Extension for AnalyzerExtension {
fn name(&self) -> Option<&'static str> { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
Some("analyzer") let mut resp = next.request(ctx).await;
let validation_result = self.validation_result.lock().await.take();
if let Some(validation_result) = validation_result {
resp = resp.extension(
"analyzer",
value! ({
"complexity": validation_result.complexity,
"depth": validation_result.depth,
}),
);
}
resp
} }
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { async fn validation(
self.complexity = result.complexity; &self,
self.depth = result.depth; ctx: &ExtensionContext<'_>,
} next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option<Value> { let res = next.validation(ctx).await?;
Some(value! ({ *self.validation_result.lock().await = Some(res);
"complexity": self.complexity, Ok(res)
"depth": self.depth,
}))
} }
} }
@ -78,7 +91,7 @@ mod tests {
.extension(extensions::Analyzer) .extension(extensions::Analyzer)
.finish(); .finish();
let extensions = schema let res = schema
.execute( .execute(
r#"{ r#"{
value obj { value obj {
@ -93,15 +106,13 @@ mod tests {
.into_result() .into_result()
.unwrap() .unwrap()
.extensions .extensions
.unwrap(); .remove("analyzer");
assert_eq!( assert_eq!(
extensions, res,
value!({ Some(value!({
"analyzer": { "complexity": 5 + 10,
"complexity": 5 + 10, "depth": 3,
"depth": 3, }))
}
})
); );
} }
} }

View File

@ -6,7 +6,7 @@ use futures_util::lock::Mutex;
use serde::Deserialize; use serde::Deserialize;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension};
use crate::{from_value, Request, ServerError, ServerResult}; use crate::{from_value, Request, ServerError, ServerResult};
#[derive(Deserialize)] #[derive(Deserialize)]
@ -64,8 +64,8 @@ impl<T: CacheStorage> ApolloPersistedQueries<T> {
} }
impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> { impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(ApolloPersistedQueriesExtension { Arc::new(ApolloPersistedQueriesExtension {
storage: self.0.clone(), storage: self.0.clone(),
}) })
} }
@ -78,18 +78,19 @@ struct ApolloPersistedQueriesExtension<T> {
#[async_trait::async_trait] #[async_trait::async_trait]
impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> { impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
async fn prepare_request( async fn prepare_request(
&mut self, &self,
_ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
mut request: Request, mut request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> { ) -> ServerResult<Request> {
if let Some(value) = request.extensions.remove("persistedQuery") { let res = if let Some(value) = request.extensions.remove("persistedQuery") {
let persisted_query: PersistedQuery = from_value(value).map_err(|_| { let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
ServerError::new("Invalid \"PersistedQuery\" extension configuration.") ServerError::new("Invalid \"PersistedQuery\" extension configuration.")
})?; })?;
if persisted_query.version != 1 { if persisted_query.version != 1 {
return Err(ServerError::new( return Err(ServerError::new(
format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version),
)); ));
} }
if request.query.is_empty() { if request.query.is_empty() {
@ -110,7 +111,8 @@ impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
} }
} else { } else {
Ok(request) Ok(request)
} };
next.prepare_request(ctx, res?).await
} }
} }

View File

@ -1,36 +1,26 @@
use std::collections::BTreeMap; use std::sync::Arc;
use std::ops::Deref;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use futures_util::lock::Mutex;
use serde::ser::SerializeMap; use serde::ser::SerializeMap;
use serde::{Serialize, Serializer}; use serde::{Serialize, Serializer};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{
use crate::{value, Value}; Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::{value, Response, ServerResult, Value};
struct PendingResolve { struct ResolveState {
path: Vec<String>, path: Vec<String>,
field_name: String, field_name: String,
parent_type: String, parent_type: String,
return_type: String, return_type: String,
start_time: DateTime<Utc>, start_time: DateTime<Utc>,
}
struct ResolveStat {
pending_resolve: PendingResolve,
end_time: DateTime<Utc>, end_time: DateTime<Utc>,
start_offset: i64, start_offset: i64,
} }
impl Deref for ResolveStat { impl Serialize for ResolveState {
type Target = PendingResolve;
fn deref(&self) -> &Self::Target {
&self.pending_resolve
}
}
impl Serialize for ResolveStat {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> { fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?; let mut map = serializer.serialize_map(None)?;
map.serialize_entry("path", &self.path)?; map.serialize_entry("path", &self.path)?;
@ -57,76 +47,79 @@ impl Serialize for ResolveStat {
pub struct ApolloTracing; pub struct ApolloTracing;
impl ExtensionFactory for ApolloTracing { impl ExtensionFactory for ApolloTracing {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(ApolloTracingExtension { Arc::new(ApolloTracingExtension {
start_time: Utc::now(), inner: Mutex::new(Inner {
end_time: Utc::now(), start_time: Utc::now(),
pending_resolves: Default::default(), end_time: Utc::now(),
resolves: Default::default(), resolves: Default::default(),
}),
}) })
} }
} }
struct ApolloTracingExtension { struct Inner {
start_time: DateTime<Utc>, start_time: DateTime<Utc>,
end_time: DateTime<Utc>, end_time: DateTime<Utc>,
pending_resolves: BTreeMap<usize, PendingResolve>, resolves: Vec<ResolveState>,
resolves: Vec<ResolveStat>,
} }
struct ApolloTracingExtension {
inner: Mutex<Inner>,
}
#[async_trait::async_trait]
impl Extension for ApolloTracingExtension { impl Extension for ApolloTracingExtension {
fn name(&self) -> Option<&'static str> { async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
Some("tracing") self.inner.lock().await.start_time = Utc::now();
} let resp = next.execute(ctx).await;
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { let mut inner = self.inner.lock().await;
self.start_time = Utc::now(); inner.end_time = Utc::now();
self.pending_resolves.clear(); inner
self.resolves.clear(); .resolves
}
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.end_time = Utc::now();
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.pending_resolves.insert(
info.resolve_id.current,
PendingResolve {
path: info.path_node.to_string_vec(),
field_name: info.path_node.field_name().to_string(),
parent_type: info.parent_type.to_string(),
return_type: info.return_type.to_string(),
start_time: Utc::now(),
},
);
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
if let Some(pending_resolve) = self.pending_resolves.remove(&info.resolve_id.current) {
let start_offset = (pending_resolve.start_time - self.start_time)
.num_nanoseconds()
.unwrap();
self.resolves.push(ResolveStat {
pending_resolve,
start_offset,
end_time: Utc::now(),
});
}
}
fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option<Value> {
self.resolves
.sort_by(|a, b| a.start_offset.cmp(&b.start_offset)); .sort_by(|a, b| a.start_offset.cmp(&b.start_offset));
resp.extension(
"tracing",
value!({
"version": 1,
"startTime": inner.start_time.to_rfc3339(),
"endTime": inner.end_time.to_rfc3339(),
"duration": (inner.end_time - inner.start_time).num_nanoseconds(),
"execution": {
"resolvers": inner.resolves
}
}),
)
}
Some(value!({ async fn resolve(
"version": 1, &self,
"startTime": self.start_time.to_rfc3339(), ctx: &ExtensionContext<'_>,
"endTime": self.end_time.to_rfc3339(), info: ResolveInfo<'_>,
"duration": (self.end_time - self.start_time).num_nanoseconds(), next: NextExtension<'_>,
"execution": { ) -> ServerResult<Option<Value>> {
"resolvers": self.resolves let path = info.path_node.to_string_vec();
} let field_name = info.path_node.field_name().to_string();
})) let parent_type = info.parent_type.to_string();
let return_type = info.return_type.to_string();
let start_time = Utc::now();
let start_offset = (start_time - self.inner.lock().await.start_time)
.num_nanoseconds()
.unwrap();
let res = next.resolve(ctx, info).await;
let end_time = Utc::now();
self.inner.lock().await.resolves.push(ResolveState {
path,
field_name,
parent_type,
return_type,
start_time,
end_time,
start_offset,
});
res
} }
} }

View File

@ -1,117 +1,124 @@
use std::fmt::{self, Display, Formatter}; use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use log::{error, info, trace}; use futures_util::lock::Mutex;
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::parser::types::{ExecutableDocument, OperationType, Selection}; use crate::parser::types::{ExecutableDocument, OperationType, Selection};
use crate::{PathSegment, ServerError, Variables}; use crate::{PathSegment, ServerError, ServerResult, Value, Variables};
/// Logger extension /// Logger extension
#[cfg_attr(docsrs, doc(cfg(feature = "log")))] #[cfg_attr(docsrs, doc(cfg(feature = "log")))]
pub struct Logger; pub struct Logger;
impl ExtensionFactory for Logger { impl ExtensionFactory for Logger {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(LoggerExtension { Arc::new(LoggerExtension {
enabled: true, inner: Mutex::new(Inner {
query: String::new(), enabled: true,
variables: Default::default(), query: String::new(),
variables: Default::default(),
}),
}) })
} }
} }
struct LoggerExtension { struct Inner {
enabled: bool, enabled: bool,
query: String, query: String,
variables: Variables, variables: Variables,
} }
impl Extension for LoggerExtension { struct LoggerExtension {
fn parse_start( inner: Mutex<Inner>,
&mut self, }
_ctx: &ExtensionContext<'_>,
query_source: &str,
variables: &Variables,
) {
self.query = query_source.replace(char::is_whitespace, "");
self.variables = variables.clone();
}
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { #[async_trait::async_trait]
impl Extension for LoggerExtension {
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
let mut inner = self.inner.lock().await;
inner.query = query.replace(char::is_whitespace, "");
inner.variables = variables.clone();
let document = next.parse_query(ctx, query, variables).await?;
let is_schema = document let is_schema = document
.operations .operations
.iter() .iter()
.filter(|(_, operation)| operation.node.ty == OperationType::Query) .filter(|(_, operation)| operation.node.ty == OperationType::Query)
.any(|(_, operation)| operation.node.selection_set.node.items.iter().any(|selection| matches!(&selection.node, Selection::Field(field) if field.node.name.node == "__schema"))); .any(|(_, operation)| operation.node.selection_set.node.items.iter().any(|selection| matches!(&selection.node, Selection::Field(field) if field.node.name.node == "__schema")));
inner.enabled = !is_schema;
if is_schema { Ok(document)
self.enabled = false;
return;
}
info!(target: "async-graphql", "[Query] query: \"{}\", variables: {}", &self.query, self.variables);
} }
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { async fn resolve(
if !self.enabled { &self,
return; ctx: &ExtensionContext<'_>,
} info: ResolveInfo<'_>,
trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", info.path_node); next: NextExtension<'_>,
} ) -> ServerResult<Option<Value>> {
let enabled = self.inner.lock().await.enabled;
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { if enabled {
if !self.enabled { let path = info.path_node.to_string();
return; log::trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", path);
} let res = next.resolve(ctx, info).await;
trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", info.path_node); if let Err(err) = &res {
} let inner = self.inner.lock().await;
log::error!(
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { target: "async-graphql",
struct DisplayError<'a> { "{}",
log: &'a LoggerExtension, DisplayError { query:&inner.query,variables:&inner.variables, e: &err }
e: &'a ServerError, );
}
impl<'a> Display for DisplayError<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "[Error] ")?;
if !self.e.path.is_empty() {
write!(f, "path: ")?;
for (i, segment) in self.e.path.iter().enumerate() {
if i != 0 {
write!(f, ".")?;
}
match segment {
PathSegment::Field(field) => write!(f, "{}", field),
PathSegment::Index(i) => write!(f, "{}", i),
}?;
}
write!(f, ", ")?;
}
if !self.e.locations.is_empty() {
write!(f, "pos: [")?;
for (i, location) in self.e.locations.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}:{}", location.line, location.column)?;
}
write!(f, "], ")?;
}
write!(f, r#"query: "{}", "#, self.log.query)?;
write!(f, "variables: {}", self.log.variables)?;
write!(f, "{}", self.e.message)
} }
log::trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", path);
res
} else {
next.resolve(ctx, info).await
} }
}
error!( }
target: "async-graphql",
"{}", struct DisplayError<'a> {
DisplayError { query: &'a str,
log: self, variables: &'a Variables,
e: err, e: &'a ServerError,
} }
); impl<'a> Display for DisplayError<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "[Error] ")?;
if !self.e.path.is_empty() {
write!(f, "path: ")?;
for (i, segment) in self.e.path.iter().enumerate() {
if i != 0 {
write!(f, ".")?;
}
match segment {
PathSegment::Field(field) => write!(f, "{}", field),
PathSegment::Index(i) => write!(f, "{}", i),
}?;
}
write!(f, ", ")?;
}
if !self.e.locations.is_empty() {
write!(f, "pos: [")?;
for (i, location) in self.e.locations.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}:{}", location.line, location.column)?;
}
write!(f, "], ")?;
}
write!(f, r#"query: "{}", "#, self.query)?;
write!(f, "variables: {}", self.variables)?;
write!(f, "{}", self.e.message)
} }
} }

View File

@ -12,28 +12,28 @@ mod opentelemetry;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
mod tracing; mod tracing;
use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::sync::Arc;
use crate::context::{QueryPathNode, ResolveId};
use crate::parser::types::ExecutableDocument;
use crate::{
Data, Request, Result, SchemaEnv, ServerError, ServerResult, ValidationResult, Variables,
};
use crate::{Error, Name, Value};
pub use self::analyzer::Analyzer; pub use self::analyzer::Analyzer;
#[cfg(feature = "apollo_tracing")] #[cfg(feature = "apollo_tracing")]
pub use self::apollo_tracing::ApolloTracing; pub use self::apollo_tracing::ApolloTracing;
#[cfg(feature = "log")] #[cfg(feature = "log")]
pub use self::logger::Logger; pub use self::logger::Logger;
#[cfg(feature = "opentelemetry")] #[cfg(feature = "opentelemetry")]
pub use self::opentelemetry::{OpenTelemetry, OpenTelemetryConfig}; pub use self::opentelemetry::OpenTelemetry;
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
pub use self::tracing::{Tracing, TracingConfig}; pub use self::tracing::Tracing;
pub(crate) type BoxExtension = Box<dyn Extension>; use std::any::{Any, TypeId};
use std::future::Future;
use std::sync::Arc;
use futures_util::stream::BoxStream;
use futures_util::stream::StreamExt;
use crate::parser::types::ExecutableDocument;
use crate::{
Data, Error, QueryPathNode, Request, Response, Result, SchemaEnv, ServerError, ServerResult,
ValidationResult, Value, Variables,
};
/// Context for extension /// Context for extension
pub struct ExtensionContext<'a> { pub struct ExtensionContext<'a> {
@ -86,10 +86,6 @@ impl<'a> ExtensionContext<'a> {
/// Parameters for `Extension::resolve_field_start` /// Parameters for `Extension::resolve_field_start`
pub struct ResolveInfo<'a> { pub struct ResolveInfo<'a> {
/// Because resolver is concurrent, `Extension::resolve_field_start` and `Extension::resolve_field_end` are
/// not strictly ordered, so each pair is identified by an id.
pub resolve_id: ResolveId,
/// Current path node, You can go through the entire path. /// Current path node, You can go through the entire path.
pub path_node: &'a QueryPathNode<'a>, pub path_node: &'a QueryPathNode<'a>,
@ -100,120 +96,250 @@ pub struct ResolveInfo<'a> {
pub return_type: &'a str, pub return_type: &'a str,
} }
/// Represents a GraphQL extension type RequestFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);
///
/// # Call order for query and mutation type ParseFut<'a> = &'a mut (dyn Future<Output = ServerResult<ExecutableDocument>> + Send + Unpin);
///
/// - start type ValidationFut<'a> =
/// - prepare_request &'a mut (dyn Future<Output = Result<ValidationResult, Vec<ServerError>>> + Send + Unpin);
/// - parse_start
/// - parse_end type ExecuteFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);
/// - validation_start
/// - validation_end type ResolveFut<'a> = &'a mut (dyn Future<Output = ServerResult<Option<Value>>> + Send + Unpin);
/// - execution_start
/// - resolve_start /// The remainder of a extension chain.
/// - resolve_end pub struct NextExtension<'a> {
/// - result chain: &'a [Arc<dyn Extension>],
/// - execution_end request_fut: Option<RequestFut<'a>>,
/// - end parse_query_fut: Option<ParseFut<'a>>,
/// validation_fut: Option<ValidationFut<'a>>,
/// # Call order for subscription execute_fut: Option<ExecuteFut<'a>>,
/// resolve_fut: Option<ResolveFut<'a>>,
/// - start }
/// - prepare_request
/// - parse_start impl<'a> NextExtension<'a> {
/// - parse_end #[inline]
/// - validation_start pub(crate) fn new(chain: &'a [Arc<dyn Extension>]) -> Self {
/// - validation_end Self {
/// - execution_start chain,
/// - resolve_start request_fut: None,
/// - resolve_end parse_query_fut: None,
/// - execution_end validation_fut: None,
/// - result execute_fut: None,
/// ``` resolve_fut: None,
#[async_trait::async_trait] }
#[allow(unused_variables)]
pub trait Extension: Sync + Send + 'static {
/// If this extension needs to output data to query results, you need to specify a name.
fn name(&self) -> Option<&'static str> {
None
} }
/// Called at the beginning of query. #[inline]
fn start(&mut self, ctx: &ExtensionContext<'_>) {} pub(crate) fn with_chain(self, chain: &'a [Arc<dyn Extension>]) -> Self {
Self { chain, ..self }
}
/// Called at the beginning of query. #[inline]
fn end(&mut self, ctx: &ExtensionContext<'_>) {} pub(crate) fn with_request(self, fut: RequestFut<'a>) -> Self {
Self {
request_fut: Some(fut),
..self
}
}
/// Called at prepare request. #[inline]
async fn prepare_request( pub(crate) fn with_parse_query(self, fut: ParseFut<'a>) -> Self {
&mut self, Self {
parse_query_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_validation(self, fut: ValidationFut<'a>) -> Self {
Self {
validation_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_execute(self, fut: ExecuteFut<'a>) -> Self {
Self {
execute_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_resolve(self, fut: ResolveFut<'a>) -> Self {
Self {
resolve_fut: Some(fut),
..self
}
}
/// Call the [Extension::request] function of next extension.
pub async fn request(mut self, ctx: &ExtensionContext<'_>) -> Response {
if let Some((first, next)) = self.chain.split_first() {
first.request(ctx, self.with_chain(next)).await
} else {
self.request_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
}
/// Call the [Extension::subscribe] function of next extension.
pub fn subscribe<'s>(
self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
) -> BoxStream<'s, Response> {
if let Some((first, next)) = self.chain.split_first() {
first.subscribe(ctx, stream, self.with_chain(next))
} else {
stream
}
}
/// Call the [Extension::prepare_request] function of next extension.
pub async fn prepare_request(
self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
request: Request, request: Request,
) -> ServerResult<Request> { ) -> ServerResult<Request> {
Ok(request) if let Some((first, next)) = self.chain.split_first() {
first
.prepare_request(ctx, request, self.with_chain(next))
.await
} else {
Ok(request)
}
} }
/// Called at the beginning of parse query source. /// Call the [Extension::parse_query] function of next extension.
fn parse_start( pub async fn parse_query(
&mut self, mut self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
query_source: &str, query: &str,
variables: &Variables, variables: &Variables,
) { ) -> ServerResult<ExecutableDocument> {
} if let Some((first, next)) = self.chain.split_first() {
first
/// Called at the end of parse query source. .parse_query(ctx, query, variables, self.with_chain(next))
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {} .await
} else {
/// Called at the beginning of the validation. self.parse_query_fut
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {} .take()
.expect("You definitely called the wrong function.")
/// Called at the end of the validation. .await
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {} }
}
/// Called at the beginning of execute a query.
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {} /// Call the [Extension::validation] function of next extension.
pub async fn validation(
/// Called at the end of execute a query. mut self,
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {} ctx: &ExtensionContext<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
/// Called at the beginning of resolve a field. if let Some((first, next)) = self.chain.split_first() {
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} first.validation(ctx, self.with_chain(next)).await
} else {
/// Called at the end of resolve a field. self.validation_fut
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} .take()
.expect("You definitely called the wrong function.")
/// Called when an error occurs. .await
fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {} }
}
/// Get the results.
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<Value> { /// Call the [Extension::execute] function of next extension.
None pub async fn execute(mut self, ctx: &ExtensionContext<'_>) -> Response {
} if let Some((first, next)) = self.chain.split_first() {
} first.execute(ctx, self.with_chain(next)).await
} else {
pub(crate) trait ErrorLogger { self.execute_fut
fn log_error(self, extensions: &Extensions) -> Self; .take()
} .expect("You definitely called the wrong function.")
.await
impl<T> ErrorLogger for ServerResult<T> { }
fn log_error(self, extensions: &Extensions) -> Self { }
if let Err(err) = &self {
extensions.error(err); /// Call the [Extension::resolve] function of next extension.
pub async fn resolve(
mut self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
) -> ServerResult<Option<Value>> {
if let Some((first, next)) = self.chain.split_first() {
first.resolve(ctx, info, self.with_chain(next)).await
} else {
self.resolve_fut
.take()
.expect("You definitely called the wrong function.")
.await
} }
self
} }
} }
impl<T> ErrorLogger for Result<T, Vec<ServerError>> { /// Represents a GraphQL extension
fn log_error(self, extensions: &Extensions) -> Self { #[async_trait::async_trait]
if let Err(errors) = &self { #[allow(unused_variables)]
for error in errors { pub trait Extension: Sync + Send + 'static {
extensions.error(error); /// Called at start query/mutation request.
} async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
} next.request(ctx).await
self }
/// Called at subscribe request.
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
next.subscribe(ctx, stream)
}
/// Called at prepare request.
async fn prepare_request(
&self,
ctx: &ExtensionContext<'_>,
request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> {
next.prepare_request(ctx, request).await
}
/// Called at parse query.
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
next.parse_query(ctx, query, variables).await
}
/// Called at validation query.
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
next.validation(ctx).await
}
/// Called at execute query.
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
next.execute(ctx).await
}
/// Called at resolve field.
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
next.resolve(ctx, info).await
} }
} }
@ -222,12 +348,13 @@ impl<T> ErrorLogger for Result<T, Vec<ServerError>> {
/// Used to create an extension instance. /// Used to create an extension instance.
pub trait ExtensionFactory: Send + Sync + 'static { pub trait ExtensionFactory: Send + Sync + 'static {
/// Create an extended instance. /// Create an extended instance.
fn create(&self) -> Box<dyn Extension>; fn create(&self) -> Arc<dyn Extension>;
} }
#[derive(Clone)]
#[doc(hidden)] #[doc(hidden)]
pub struct Extensions { pub struct Extensions {
extensions: Option<spin::Mutex<Vec<BoxExtension>>>, extensions: Vec<Arc<dyn Extension>>,
schema_env: SchemaEnv, schema_env: SchemaEnv,
session_data: Arc<Data>, session_data: Arc<Data>,
query_data: Option<Arc<Data>>, query_data: Option<Arc<Data>>,
@ -235,17 +362,13 @@ pub struct Extensions {
#[doc(hidden)] #[doc(hidden)]
impl Extensions { impl Extensions {
pub fn new( pub(crate) fn new(
extensions: Vec<BoxExtension>, extensions: impl IntoIterator<Item = Arc<dyn Extension>>,
schema_env: SchemaEnv, schema_env: SchemaEnv,
session_data: Arc<Data>, session_data: Arc<Data>,
) -> Self { ) -> Self {
Extensions { Extensions {
extensions: if extensions.is_empty() { extensions: extensions.into_iter().collect(),
None
} else {
Some(spin::Mutex::new(extensions))
},
schema_env, schema_env,
session_data, session_data,
query_data: None, query_data: None,
@ -255,18 +378,14 @@ impl Extensions {
pub fn attach_query_data(&mut self, data: Arc<Data>) { pub fn attach_query_data(&mut self, data: Arc<Data>) {
self.query_data = Some(data); self.query_data = Some(data);
} }
}
impl Drop for Extensions {
fn drop(&mut self) {
self.end();
}
}
#[doc(hidden)]
impl Extensions {
#[inline] #[inline]
fn context(&self) -> ExtensionContext<'_> { pub(crate) fn is_empty(&self) -> bool {
self.extensions.is_empty()
}
#[inline]
fn create_context(&self) -> ExtensionContext {
ExtensionContext { ExtensionContext {
schema_data: &self.schema_env.data, schema_data: &self.schema_env.data,
session_data: &self.session_data, session_data: &self.session_data,
@ -274,124 +393,79 @@ impl Extensions {
} }
} }
pub fn is_empty(&self) -> bool { pub async fn request(&self, request_fut: RequestFut<'_>) -> Response {
self.extensions.is_none() if !self.extensions.is_empty() {
} let next = NextExtension::new(&self.extensions).with_request(request_fut);
next.request(&self.create_context()).await
pub fn start(&self) { } else {
if let Some(e) = &self.extensions { request_fut.await
e.lock().iter_mut().for_each(|e| e.start(&self.context()));
} }
} }
pub fn end(&self) { pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> {
if let Some(e) = &self.extensions { if !self.extensions.is_empty() {
e.lock().iter_mut().for_each(|e| e.end(&self.context())); let next = NextExtension::new(&self.extensions);
next.subscribe(&self.create_context(), stream)
} else {
stream.boxed()
} }
} }
pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> { pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> {
let mut request = request; if !self.extensions.is_empty() {
if let Some(e) = &self.extensions { let next = NextExtension::new(&self.extensions);
for e in e.lock().iter_mut() { next.prepare_request(&self.create_context(), request).await
request = e.prepare_request(&self.context(), request).await?;
}
}
Ok(request)
}
pub fn parse_start(&self, query_source: &str, variables: &Variables) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.parse_start(&self.context(), query_source, variables));
}
}
pub fn parse_end(&self, document: &ExecutableDocument) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.parse_end(&self.context(), document));
}
}
pub fn validation_start(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_start(&self.context()));
}
}
pub fn validation_end(&self, result: &ValidationResult) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_end(&self.context(), result));
}
}
pub fn execution_start(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_start(&self.context()));
}
}
pub fn execution_end(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_end(&self.context()));
}
}
pub fn resolve_start(&self, info: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_start(&self.context(), info));
}
}
pub fn resolve_end(&self, resolve_id: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_end(&self.context(), resolve_id));
}
}
pub fn error(&self, err: &ServerError) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.error(&self.context(), err));
}
}
pub fn result(&self) -> Option<Value> {
if let Some(e) = &self.extensions {
let value = e
.lock()
.iter_mut()
.filter_map(|e| {
if let Some(name) = e.name() {
e.result(&self.context()).map(|res| (Name::new(name), res))
} else {
None
}
})
.collect::<BTreeMap<_, _>>();
if value.is_empty() {
None
} else {
Some(Value::Object(value))
}
} else { } else {
None Ok(request)
}
}
pub async fn parse_query(
&self,
query: &str,
variables: &Variables,
parse_query_fut: ParseFut<'_>,
) -> ServerResult<ExecutableDocument> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_parse_query(parse_query_fut);
next.parse_query(&self.create_context(), query, variables)
.await
} else {
parse_query_fut.await
}
}
pub async fn validation(
&self,
validation_fut: ValidationFut<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_validation(validation_fut);
next.validation(&self.create_context()).await
} else {
validation_fut.await
}
}
pub async fn execute(&self, execute_fut: ExecuteFut<'_>) -> Response {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_execute(execute_fut);
next.execute(&self.create_context()).await
} else {
execute_fut.await
}
}
pub async fn resolve(
&self,
info: ResolveInfo<'_>,
resolve_fut: ResolveFut<'_>,
) -> ServerResult<Option<Value>> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_resolve(resolve_fut);
next.resolve(&self.create_context(), info).await
} else {
resolve_fut.await
} }
} }
} }

View File

@ -1,49 +1,25 @@
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use async_graphql_parser::types::ExecutableDocument; use async_graphql_parser::types::ExecutableDocument;
use async_graphql_value::Variables; use async_graphql_value::Variables;
use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer}; use futures_util::stream::BoxStream;
use futures_util::TryFutureExt;
use opentelemetry::trace::{FutureExt, SpanKind, TraceContextExt, Tracer};
use opentelemetry::{Context as OpenTelemetryContext, Key}; use opentelemetry::{Context as OpenTelemetryContext, Key};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{
use crate::{ServerError, ValidationResult}; Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
const REQUEST_CTX: usize = 0; use crate::{Response, ServerError, ServerResult, ValidationResult, Value};
const PARSE_CTX: usize = 1;
const VALIDATION_CTX: usize = 2;
const EXECUTE_CTX: usize = 3;
#[inline]
fn resolve_ctx_id(resolver_id: usize) -> usize {
resolver_id + 10
}
const KEY_SOURCE: Key = Key::from_static_str("graphql.source"); const KEY_SOURCE: Key = Key::from_static_str("graphql.source");
const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables"); const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables");
const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType"); const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType");
const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType"); const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType");
const KEY_RESOLVE_ID: Key = Key::from_static_str("graphql.resolveId");
const KEY_ERROR: Key = Key::from_static_str("graphql.error"); const KEY_ERROR: Key = Key::from_static_str("graphql.error");
const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity"); const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity");
const KEY_DEPTH: Key = Key::from_static_str("graphql.depth"); const KEY_DEPTH: Key = Key::from_static_str("graphql.depth");
/// OpenTelemetry extension configuration for each request.
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetryConfig {
/// Use a context as the parent node of the entire query.
parent: spin::Mutex<Option<OpenTelemetryContext>>,
}
impl OpenTelemetryConfig {
/// Use a context as the parent of the entire query.
pub fn parent_context(mut self, cx: OpenTelemetryContext) -> Self {
*self.parent.get_mut() = Some(cx);
self
}
}
/// OpenTelemetry extension /// OpenTelemetry extension
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] #[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T> { pub struct OpenTelemetry<T> {
@ -63,159 +39,127 @@ impl<T> OpenTelemetry<T> {
} }
impl<T: Tracer + Send + Sync> ExtensionFactory for OpenTelemetry<T> { impl<T: Tracer + Send + Sync> ExtensionFactory for OpenTelemetry<T> {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(OpenTelemetryExtension { Arc::new(OpenTelemetryExtension {
tracer: self.tracer.clone(), tracer: self.tracer.clone(),
contexts: Default::default(),
}) })
} }
} }
struct OpenTelemetryExtension<T> { struct OpenTelemetryExtension<T> {
tracer: Arc<T>, tracer: Arc<T>,
contexts: HashMap<usize, OpenTelemetryContext>,
}
impl<T> OpenTelemetryExtension<T> {
fn enter_context(&mut self, id: usize, cx: OpenTelemetryContext) {
let _ = cx.clone().attach();
self.contexts.insert(id, cx);
}
fn exit_context(&mut self, id: usize) -> Option<OpenTelemetryContext> {
if let Some(cx) = self.contexts.remove(&id) {
let _ = cx.clone().attach();
Some(cx)
} else {
None
}
}
} }
#[async_trait::async_trait]
impl<T: Tracer + Send + Sync> Extension for OpenTelemetryExtension<T> { impl<T: Tracer + Send + Sync> Extension for OpenTelemetryExtension<T> {
fn start(&mut self, ctx: &ExtensionContext<'_>) { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
let request_cx = ctx next.request(ctx)
.data_opt::<OpenTelemetryConfig>() .with_context(OpenTelemetryContext::current_with_span(
.and_then(|cfg| cfg.parent.lock().take()) self.tracer
.unwrap_or_else(|| { .span_builder("request")
OpenTelemetryContext::current_with_span( .with_kind(SpanKind::Server)
.start(&*self.tracer),
))
.await
}
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
Box::pin(
next.subscribe(ctx, stream)
.with_context(OpenTelemetryContext::current_with_span(
self.tracer self.tracer
.span_builder("request") .span_builder("subscribe")
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.start(&*self.tracer), .start(&*self.tracer),
) )),
}); )
self.enter_context(REQUEST_CTX, request_cx);
} }
fn end(&mut self, _ctx: &ExtensionContext<'_>) { async fn parse_query(
self.exit_context(REQUEST_CTX); &self,
} ctx: &ExtensionContext<'_>,
query: &str,
fn parse_start(
&mut self,
_ctx: &ExtensionContext<'_>,
query_source: &str,
variables: &Variables, variables: &Variables,
) { next: NextExtension<'_>,
if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() { ) -> ServerResult<ExecutableDocument> {
let attributes = vec![ let attributes = vec![
KEY_SOURCE.string(query_source.to_string()), KEY_SOURCE.string(query.to_string()),
KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()), KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()),
]; ];
let parse_span = self let span = self
.tracer .tracer
.span_builder("parse") .span_builder("parse")
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.with_attributes(attributes) .with_attributes(attributes)
.with_parent_context(parent_cx) .start(&*self.tracer);
.start(&*self.tracer); next.parse_query(ctx, query, variables)
let parse_cx = OpenTelemetryContext::current_with_span(parse_span); .with_context(OpenTelemetryContext::current_with_span(span))
self.enter_context(PARSE_CTX, parse_cx); .await
}
} }
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { async fn validation(
self.exit_context(PARSE_CTX); &self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
let span = self
.tracer
.span_builder("validation")
.with_kind(SpanKind::Server)
.start(&*self.tracer);
next.validation(ctx)
.with_context(OpenTelemetryContext::current_with_span(span))
.map_ok(|res| {
let current_cx = OpenTelemetryContext::current();
let span = current_cx.span();
span.set_attribute(KEY_COMPLEXITY.i64(res.complexity as i64));
span.set_attribute(KEY_DEPTH.i64(res.depth as i64));
res
})
.await
} }
fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) { async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() { let span = self
let span = self .tracer
.tracer .span_builder("execute")
.span_builder("validation") .with_kind(SpanKind::Server)
.with_kind(SpanKind::Server) .start(&*self.tracer);
.with_parent_context(parent_cx) next.execute(ctx)
.start(&*self.tracer); .with_context(OpenTelemetryContext::current_with_span(span))
let validation_cx = OpenTelemetryContext::current_with_span(span); .await
self.enter_context(VALIDATION_CTX, validation_cx);
}
} }
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { async fn resolve(
if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) { &self,
let span = validation_cx.span(); ctx: &ExtensionContext<'_>,
span.set_attribute(KEY_COMPLEXITY.i64(result.complexity as i64)); info: ResolveInfo<'_>,
span.set_attribute(KEY_DEPTH.i64(result.depth as i64)); next: NextExtension<'_>,
} ) -> ServerResult<Option<Value>> {
} let attributes = vec![
KEY_PARENT_TYPE.string(info.parent_type.to_string()),
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { KEY_RETURN_TYPE.string(info.return_type.to_string()),
let span = match self.contexts.get(&REQUEST_CTX).cloned() { ];
Some(parent_cx) => self let span = self
.tracer .tracer
.span_builder("execute") .span_builder(&info.path_node.to_string())
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.with_parent_context(parent_cx) .with_attributes(attributes)
.start(&*self.tracer), .start(&*self.tracer);
None => self next.resolve(ctx, info)
.tracer .with_context(OpenTelemetryContext::current_with_span(span))
.span_builder("execute") .map_err(|err| {
.with_kind(SpanKind::Server) let current_cx = OpenTelemetryContext::current();
.start(&*self.tracer), current_cx
}; .span()
let execute_cx = OpenTelemetryContext::current_with_span(span); .add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]);
self.enter_context(EXECUTE_CTX, execute_cx); err
} })
.await
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_context(EXECUTE_CTX);
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
let parent_cx = match info.resolve_id.parent {
Some(parent_id) if parent_id > 0 => self.contexts.get(&resolve_ctx_id(parent_id)),
_ => self.contexts.get(&EXECUTE_CTX),
}
.cloned();
if let Some(parent_cx) = parent_cx {
let attributes = vec![
KEY_RESOLVE_ID.i64(info.resolve_id.current as i64),
KEY_PARENT_TYPE.string(info.parent_type.to_string()),
KEY_RETURN_TYPE.string(info.return_type.to_string()),
];
let span = self
.tracer
.span_builder(&info.path_node.to_string())
.with_kind(SpanKind::Server)
.with_parent_context(parent_cx)
.with_attributes(attributes)
.start(&*self.tracer);
let resolve_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx);
}
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.exit_context(resolve_ctx_id(info.resolve_id.current));
}
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() {
parent_cx
.span()
.add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]);
}
} }
} }

View File

@ -1,36 +1,15 @@
use std::collections::HashMap; use std::sync::Arc;
use tracing::{span, Level, Span}; use futures_util::stream::BoxStream;
use tracing_futures::Instrument;
use tracinglib::{span, Level};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::futures_util::TryFutureExt;
use crate::parser::types::ExecutableDocument; use crate::parser::types::ExecutableDocument;
use crate::{ServerError, ValidationResult, Variables}; use crate::{Response, ServerError, ServerResult, ValidationResult, Value, Variables};
/// Tracing extension configuration for each request.
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub struct TracingConfig {
/// Use a span as the parent node of the entire query.
parent: spin::Mutex<Option<Span>>,
}
impl TracingConfig {
/// Use a span as the parent of the entire query.
pub fn parent_span(mut self, span: Span) -> Self {
*self.parent.get_mut() = Some(span);
self
}
}
const REQUEST_CTX: usize = 0;
const PARSE_CTX: usize = 1;
const VALIDATION_CTX: usize = 2;
const EXECUTE_CTX: usize = 3;
#[inline]
fn resolve_span_id(resolver_id: usize) -> usize {
resolver_id + 10
}
/// Tracing extension /// Tracing extension
/// ///
@ -42,7 +21,7 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// ///
/// ```no_run /// ```no_run
/// use async_graphql::*; /// use async_graphql::*;
/// use async_graphql::extensions::{Tracing, TracingConfig}; /// use async_graphql::extensions::Tracing;
/// use tracing::{span, Level, Instrument}; /// use tracing::{span, Level, Instrument};
/// ///
/// #[derive(SimpleObject)] /// #[derive(SimpleObject)]
@ -50,164 +29,112 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// value: i32, /// value: i32,
/// } /// }
/// ///
/// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription). /// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription)
/// extension(Tracing::default()) /// .extension(Tracing)
/// .finish(); /// .finish();
/// ///
/// tokio::runtime::Runtime::new().unwrap().block_on(async { /// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// schema.execute(Request::new("{ value }")).await; /// schema.execute(Request::new("{ value }")).await;
/// }); /// });
///
/// // tracing in parent span
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!(
/// parent: None,
/// Level::INFO,
/// "span root"
/// );
/// schema.execute(Request::new("{ value }")).instrument(root_span).await;
/// });
///
/// // replace root span
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!(
/// parent: None,
/// Level::INFO,
/// "span root"
/// );
/// schema.execute(Request::new("{ value }").data(TracingConfig::default().parent_span(root_span))).await;
/// });
/// ``` /// ```
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub struct Tracing; pub struct Tracing;
impl ExtensionFactory for Tracing { impl ExtensionFactory for Tracing {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(TracingExtension::default()) Arc::new(TracingExtension::default())
} }
} }
#[derive(Default)] #[derive(Default)]
struct TracingExtension { struct TracingExtension;
spans: HashMap<usize, Span>,
}
impl TracingExtension {
fn enter_span(&mut self, id: usize, span: Span) -> &Span {
let _ = span.enter();
self.spans.insert(id, span);
self.spans.get(&id).unwrap()
}
fn exit_span(&mut self, id: usize) {
if let Some(span) = self.spans.remove(&id) {
let _ = span.enter();
}
}
}
#[async_trait::async_trait]
impl Extension for TracingExtension { impl Extension for TracingExtension {
fn start(&mut self, ctx: &ExtensionContext<'_>) { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
let request_span = ctx next.request(ctx)
.data_opt::<TracingConfig>() .instrument(span!(
.and_then(|cfg| cfg.parent.lock().take()) target: "async_graphql::graphql",
.unwrap_or_else(|| { Level::INFO,
span!( "request",
target: "async_graphql::graphql", ))
Level::INFO, .await
"request",
)
});
self.enter_span(REQUEST_CTX, request_span);
} }
fn end(&mut self, _ctx: &ExtensionContext<'_>) { fn subscribe<'s>(
self.exit_span(REQUEST_CTX); &self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
Box::pin(next.subscribe(ctx, stream).instrument(span!(
target: "async_graphql::graphql",
Level::INFO,
"subscribe",
)))
} }
fn parse_start( async fn parse_query(
&mut self, &self,
_ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
query_source: &str, query: &str,
variables: &Variables, variables: &Variables,
) { next: NextExtension<'_>,
if let Some(parent) = self.spans.get(&REQUEST_CTX) { ) -> ServerResult<ExecutableDocument> {
let variables = serde_json::to_string(&variables).unwrap(); let span = span!(
let parse_span = span!( target: "async_graphql::graphql",
target: "async_graphql::graphql", Level::INFO,
parent: parent, "parse",
Level::INFO, source = query,
"parse", variables = %serde_json::to_string(&variables).unwrap(),
source = query_source, );
variables = %variables, next.parse_query(ctx, query, variables)
); .instrument(span)
self.enter_span(PARSE_CTX, parse_span); .await
}
} }
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { async fn validation(
self.exit_span(PARSE_CTX); &self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
let span = span!(
target: "async_graphql::graphql",
Level::INFO,
"validation"
);
next.validation(ctx).instrument(span).await
} }
fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) { async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
if let Some(parent) = self.spans.get(&REQUEST_CTX) { let span = span!(
let span = span!( target: "async_graphql::graphql",
target: "async_graphql::graphql", Level::INFO,
parent: parent, "execute"
Level::INFO, );
"validation" next.execute(ctx).instrument(span).await
);
self.enter_span(VALIDATION_CTX, span);
}
} }
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, _result: &ValidationResult) { async fn resolve(
self.exit_span(VALIDATION_CTX); &self,
} ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { next: NextExtension<'_>,
if let Some(parent) = self.spans.get(&REQUEST_CTX) { ) -> ServerResult<Option<Value>> {
let span = span!( let span = span!(
target: "async_graphql::graphql", target: "async_graphql::graphql",
parent: parent, Level::INFO,
Level::INFO, "field",
"execute" path = %info.path_node,
); parent_type = %info.parent_type,
self.enter_span(EXECUTE_CTX, span); return_type = %info.return_type,
}; );
} next.resolve(ctx, info)
.instrument(span)
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { .map_err(|err| {
self.exit_span(EXECUTE_CTX); tracinglib::error!(target: "async_graphql::graphql", error = %err.message);
} err
})
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { .await
let parent = match info.resolve_id.parent {
Some(parent_id) if parent_id > 0 => self.spans.get(&resolve_span_id(parent_id)),
_ => self.spans.get(&EXECUTE_CTX),
};
if let Some(parent) = parent {
let span = span!(
target: "async_graphql::graphql",
parent: parent,
Level::INFO,
"field",
id = %info.resolve_id.current,
path = %info.path_node,
parent_type = %info.parent_type,
return_type = %info.return_type,
);
self.enter_span(resolve_span_id(info.resolve_id.current), span);
}
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.exit_span(resolve_span_id(info.resolve_id.current));
}
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
tracing::error!(target: "async_graphql::graphql", error = %err.message);
} }
} }

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use crate::extensions::{ErrorLogger, ResolveInfo}; use crate::extensions::ResolveInfo;
use crate::parser::types::Selection; use crate::parser::types::Selection;
use crate::registry::MetaType; use crate::registry::MetaType;
use crate::{ use crate::{
@ -174,18 +174,18 @@ impl<'a> Fields<'a> {
async move { async move {
let ctx_field = ctx.with_field(field); let ctx_field = ctx.with_field(field);
let field_name = ctx_field.item.node.response_key().node.clone(); let field_name = ctx_field.item.node.response_key().node.clone();
let extensions = &ctx.query_env.extensions;
let res = if ctx_field.query_env.extensions.is_empty() { if extensions.is_empty() {
match root.resolve_field(&ctx_field).await { match root.resolve_field(&ctx_field).await {
Ok(value) => Ok((field_name, value.unwrap_or_default())), Ok(value) => Ok((field_name, value.unwrap_or_default())),
Err(e) => { Err(e) => {
Err(e.path(PathSegment::Field(field_name.to_string()))) Err(e.path(PathSegment::Field(field_name.to_string())))
} }
}? }
} else { } else {
let type_name = T::type_name(); let type_name = T::type_name();
let resolve_info = ResolveInfo { let resolve_info = ResolveInfo {
resolve_id: ctx_field.resolve_id,
path_node: ctx_field.path_node.as_ref().unwrap(), path_node: ctx_field.path_node.as_ref().unwrap(),
parent_type: &type_name, parent_type: &type_name,
return_type: match ctx_field return_type: match ctx_field
@ -210,20 +210,16 @@ impl<'a> Fields<'a> {
}, },
}; };
ctx_field.query_env.extensions.resolve_start(&resolve_info); let resolve_fut = async { root.resolve_field(&ctx_field).await };
futures_util::pin_mut!(resolve_fut);
let res = match root.resolve_field(&ctx_field).await { let res = extensions.resolve(resolve_info, &mut resolve_fut).await;
match res {
Ok(value) => Ok((field_name, value.unwrap_or_default())), Ok(value) => Ok((field_name, value.unwrap_or_default())),
Err(e) => { Err(e) => {
Err(e.path(PathSegment::Field(field_name.to_string()))) Err(e.path(PathSegment::Field(field_name.to_string())))
} }
} }
.log_error(&ctx_field.query_env.extensions)?; }
ctx_field.query_env.extensions.resolve_end(&resolve_info);
res
};
Ok(res)
} }
})); }));
} }

View File

@ -1,4 +1,4 @@
use crate::extensions::{ErrorLogger, ResolveInfo}; use crate::extensions::ResolveInfo;
use crate::parser::types::Field; use crate::parser::types::Field;
use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value}; use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value};
@ -9,38 +9,49 @@ pub async fn resolve_list<'a, T: OutputType + 'a>(
iter: impl IntoIterator<Item = T>, iter: impl IntoIterator<Item = T>,
len: Option<usize>, len: Option<usize>,
) -> ServerResult<Value> { ) -> ServerResult<Value> {
let mut futures = len.map(Vec::with_capacity).unwrap_or_default(); let extensions = &ctx.query_env.extensions;
for (idx, item) in iter.into_iter().enumerate() { if extensions.is_empty() {
let ctx_idx = ctx.with_index(idx); let mut futures = len.map(Vec::with_capacity).unwrap_or_default();
futures.push(async move { for (idx, item) in iter.into_iter().enumerate() {
if ctx_idx.query_env.extensions.is_empty() { futures.push({
let ctx = ctx.clone();
async move {
let ctx_idx = ctx.with_index(idx);
let resolve_info = ResolveInfo {
path_node: ctx_idx.path_node.as_ref().unwrap(),
parent_type: &Vec::<T>::type_name(),
return_type: &T::qualified_type_name(),
};
let resolve_fut = async {
OutputType::resolve(&item, &ctx_idx, field)
.await
.map(Option::Some)
.map_err(|e| e.path(PathSegment::Index(idx)))
};
futures_util::pin_mut!(resolve_fut);
extensions
.resolve(resolve_info, &mut resolve_fut)
.await
.map(|value| value.expect("You definitely encountered a bug!"))
}
});
}
Ok(Value::List(
futures_util::future::try_join_all(futures).await?,
))
} else {
let mut futures = len.map(Vec::with_capacity).unwrap_or_default();
for (idx, item) in iter.into_iter().enumerate() {
let ctx_idx = ctx.with_index(idx);
futures.push(async move {
OutputType::resolve(&item, &ctx_idx, field) OutputType::resolve(&item, &ctx_idx, field)
.await .await
.map_err(|e| e.path(PathSegment::Index(idx))) .map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_idx.query_env.extensions) });
} else { }
let resolve_info = ResolveInfo { Ok(Value::List(
resolve_id: ctx_idx.resolve_id, futures_util::future::try_join_all(futures).await?,
path_node: ctx_idx.path_node.as_ref().unwrap(), ))
parent_type: &Vec::<T>::type_name(),
return_type: &T::qualified_type_name(),
};
ctx_idx.query_env.extensions.resolve_start(&resolve_info);
let res = OutputType::resolve(&item, &ctx_idx, field)
.await
.map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_idx.query_env.extensions)?;
ctx_idx.query_env.extensions.resolve_end(&resolve_info);
Ok(res)
}
});
} }
Ok(Value::List(
futures_util::future::try_join_all(futures).await?,
))
} }

View File

@ -1,5 +1,6 @@
use http::header::HeaderMap; use std::collections::BTreeMap;
use http::header::HeaderMap;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{CacheControl, Result, ServerError, Value}; use crate::{CacheControl, Result, ServerError, Value};
@ -12,8 +13,8 @@ pub struct Response {
pub data: Value, pub data: Value,
/// Extensions result /// Extensions result
#[serde(skip_serializing_if = "Option::is_none", default)] #[serde(skip_serializing_if = "BTreeMap::is_empty", default)]
pub extensions: Option<Value>, pub extensions: BTreeMap<String, Value>,
/// Cache control value /// Cache control value
#[serde(skip)] #[serde(skip)]
@ -47,10 +48,11 @@ impl Response {
} }
} }
/// Set the extensions result of the response. /// Set the extension result of the response.
#[must_use] #[must_use]
pub fn extensions(self, extensions: Option<Value>) -> Self { pub fn extension(mut self, name: impl Into<String>, value: Value) -> Self {
Self { extensions, ..self } self.extensions.insert(name.into(), value);
self
} }
/// Set the http headers of the response. /// Set the http headers of the response.

View File

@ -1,14 +1,12 @@
use std::any::Any; use std::any::Any;
use std::collections::BTreeMap;
use std::ops::Deref; use std::ops::Deref;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc; use std::sync::Arc;
use futures_util::stream::{self, Stream, StreamExt}; use futures_util::stream::{self, Stream, StreamExt};
use indexmap::map::IndexMap; use indexmap::map::IndexMap;
use crate::context::{Data, QueryEnvInner, ResolveId}; use crate::context::{Data, QueryEnvInner};
use crate::extensions::{ErrorLogger, ExtensionFactory, Extensions}; use crate::extensions::{ExtensionFactory, Extensions};
use crate::model::__DirectiveLocation; use crate::model::__DirectiveLocation;
use crate::parser::parse_query; use crate::parser::parse_query;
use crate::parser::types::{DocumentOperations, OperationType}; use crate::parser::types::{DocumentOperations, OperationType};
@ -19,7 +17,7 @@ use crate::types::QueryRoot;
use crate::validation::{check_rules, ValidationMode}; use crate::validation::{check_rules, ValidationMode};
use crate::{ use crate::{
BatchRequest, BatchResponse, CacheControl, ContextBase, ObjectType, QueryEnv, Request, BatchRequest, BatchResponse, CacheControl, ContextBase, ObjectType, QueryEnv, Request,
Response, ServerError, SubscriptionType, Type, Value, ID, Response, ServerError, SubscriptionType, Type, ID,
}; };
/// Schema builder /// Schema builder
@ -354,17 +352,11 @@ where
} }
fn create_extensions(&self, session_data: Arc<Data>) -> Extensions { fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
let extensions = Extensions::new( Extensions::new(
self.0 self.extensions.iter().map(|f| f.create()),
.extensions
.iter()
.map(|factory| factory.create())
.collect::<Vec<_>>(),
self.env.clone(), self.env.clone(),
session_data, session_data,
); )
extensions.start();
extensions
} }
async fn prepare_request( async fn prepare_request(
@ -376,36 +368,41 @@ where
let mut request = request; let mut request = request;
let query_data = Arc::new(std::mem::take(&mut request.data)); let query_data = Arc::new(std::mem::take(&mut request.data));
extensions.attach_query_data(query_data.clone()); extensions.attach_query_data(query_data.clone());
let request = extensions.prepare_request(request).await?;
extensions.parse_start(&request.query, &request.variables); let request = extensions.prepare_request(request).await?;
let document = parse_query(&request.query) let document = {
.map_err(Into::<ServerError>::into) let query = &request.query;
.log_error(&extensions)?; let fut_parse = async { parse_query(&query).map_err(Into::<ServerError>::into) };
extensions.parse_end(&document); futures_util::pin_mut!(fut_parse);
extensions
.parse_query(&query, &request.variables, &mut fut_parse)
.await?
};
// check rules // check rules
extensions.validation_start(); let validation_result = {
let validation_result = check_rules( let validation_fut = async {
&self.env.registry, check_rules(
&document, &self.env.registry,
Some(&request.variables), &document,
self.validation_mode, Some(&request.variables),
) self.validation_mode,
.log_error(&extensions)?; )
extensions.validation_end(&validation_result); };
futures_util::pin_mut!(validation_fut);
extensions.validation(&mut validation_fut).await?
};
// check limit // check limit
if let Some(limit_complexity) = self.complexity { if let Some(limit_complexity) = self.complexity {
if validation_result.complexity > limit_complexity { if validation_result.complexity > limit_complexity {
return Err(vec![ServerError::new("Query is too complex.")]).log_error(&extensions); return Err(vec![ServerError::new("Query is too complex.")]);
} }
} }
if let Some(limit_depth) = self.depth { if let Some(limit_depth) = self.depth {
if validation_result.depth > limit_depth { if validation_result.depth > limit_depth {
return Err(vec![ServerError::new("Query is nested too deep.")]) return Err(vec![ServerError::new("Query is nested too deep.")]);
.log_error(&extensions);
} }
} }
@ -432,10 +429,7 @@ where
}; };
let operation = match operation { let operation = match operation {
Ok(operation) => operation, Ok(operation) => operation,
Err(e) => { Err(e) => return Err(vec![e]),
extensions.error(&e);
return Err(vec![e]);
}
}; };
let env = QueryEnvInner { let env = QueryEnvInner {
@ -454,49 +448,58 @@ where
async fn execute_once(&self, env: QueryEnv) -> Response { async fn execute_once(&self, env: QueryEnv) -> Response {
// execute // execute
let inc_resolve_id = AtomicUsize::default();
let ctx = ContextBase { let ctx = ContextBase {
path_node: None, path_node: None,
resolve_id: ResolveId::root(),
inc_resolve_id: &inc_resolve_id,
item: &env.operation.node.selection_set, item: &env.operation.node.selection_set,
schema_env: &self.env, schema_env: &self.env,
query_env: &env, query_env: &env,
}; };
env.extensions.execution_start();
let data = match &env.operation.node.ty { let res = match &env.operation.node.ty {
OperationType::Query => resolve_container(&ctx, &self.query).await, OperationType::Query => resolve_container(&ctx, &self.query).await,
OperationType::Mutation => resolve_container_serial(&ctx, &self.mutation).await, OperationType::Mutation => resolve_container_serial(&ctx, &self.mutation).await,
OperationType::Subscription => { OperationType::Subscription => {
return Response::from_errors(vec![ServerError::new( return Response::from_errors(vec![ServerError::new(
"Subscriptions are not supported on this transport.", "Subscriptions are not supported on this transport.",
)]) )]);
} }
}; };
env.extensions.execution_end(); match res {
let extensions = env.extensions.result(); Ok(data) => {
let resp = Response::new(data);
match data { resp.http_headers(std::mem::take(&mut *env.http_headers.lock()))
Ok(data) => Response::new(data), }
Err(e) => Response::from_errors(vec![e]), Err(err) => Response::from_errors(vec![err]),
} }
.extensions(extensions)
.http_headers(std::mem::take(&mut *env.http_headers.lock()))
} }
/// Execute a GraphQL query. /// Execute a GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response { pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into(); let request = request.into();
let extensions = self.create_extensions(Default::default()); let extensions = self.create_extensions(Default::default());
match self let request_fut = {
.prepare_request(extensions, request, Default::default()) let extensions = extensions.clone();
.await async move {
{ match self
Ok((env, cache_control)) => self.execute_once(env).await.cache_control(cache_control), .prepare_request(extensions, request, Default::default())
Err(errors) => Response::from_errors(errors), .await
} {
Ok((env, cache_control)) => {
let fut = async {
self.execute_once(env.clone())
.await
.cache_control(cache_control)
};
futures_util::pin_mut!(fut);
env.extensions.execute(&mut fut).await
}
Err(errors) => Response::from_errors(errors),
}
}
};
futures_util::pin_mut!(request_fut);
extensions.request(&mut request_fut).await
} }
/// Execute a GraphQL batch query. /// Execute a GraphQL batch query.
@ -518,68 +521,52 @@ where
&self, &self,
request: impl Into<Request> + Send, request: impl Into<Request> + Send,
session_data: Arc<Data>, session_data: Arc<Data>,
) -> impl Stream<Item = Response> + Send { ) -> impl Stream<Item = Response> + Send + Unpin {
let schema = self.clone(); let schema = self.clone();
let request = request.into(); let request = request.into();
let extensions = self.create_extensions(session_data.clone()); let extensions = self.create_extensions(session_data.clone());
async_stream::stream! { let stream = futures_util::stream::StreamExt::boxed({
let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await { let extensions = extensions.clone();
Ok(res) => res, async_stream::stream! {
Err(errors) => { let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await {
yield Response::from_errors(errors); Ok(res) => res,
Err(errors) => {
yield Response::from_errors(errors);
return;
}
};
if env.operation.node.ty != OperationType::Subscription {
yield schema.execute_once(env).await.cache_control(cache_control);
return; return;
} }
};
if env.operation.node.ty != OperationType::Subscription { let ctx = env.create_context(
yield schema &schema.env,
.execute_once(env) None,
.await &env.operation.node.selection_set,
.cache_control(cache_control); );
return;
}
let resolve_id = AtomicUsize::default(); let mut streams = Vec::new();
let ctx = env.create_context( if let Err(err) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
&schema.env, yield Response::from_errors(vec![err]);
None, }
&env.operation.node.selection_set,
ResolveId::root(),
&resolve_id,
);
let mut streams = Vec::new(); let mut stream = stream::select_all(streams);
if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { while let Some(resp) = stream.next().await {
env.extensions.execution_end(); yield resp;
yield Response::from_errors(vec![e]);
return;
}
let mut stream = stream::select_all(streams);
while let Some(data) = stream.next().await {
let is_err = data.is_err();
let extensions = env.extensions.result();
yield match data {
Ok((name, value)) => {
let mut map = BTreeMap::new();
map.insert(name, value);
Response::new(Value::Object(map))
},
Err(e) => Response::from_errors(vec![e]),
}.extensions(extensions);
if is_err {
break;
} }
} }
} });
extensions.subscribe(stream)
} }
/// Execute a GraphQL subscription. /// Execute a GraphQL subscription.
pub fn execute_stream( pub fn execute_stream(
&self, &self,
request: impl Into<Request>, request: impl Into<Request>,
) -> impl Stream<Item = Response> + Send { ) -> impl Stream<Item = Response> + Send + Unpin {
self.execute_stream_with_session_data(request.into(), Default::default()) self.execute_stream_with_session_data(request.into(), Default::default())
} }
} }

View File

@ -3,9 +3,7 @@ use std::pin::Pin;
use futures_util::stream::{Stream, StreamExt}; use futures_util::stream::{Stream, StreamExt};
use crate::parser::types::{Selection, TypeCondition}; use crate::parser::types::{Selection, TypeCondition};
use crate::{ use crate::{Context, ContextSelectionSet, PathSegment, Response, ServerError, ServerResult, Type};
Context, ContextSelectionSet, Name, PathSegment, ServerError, ServerResult, Type, Value,
};
/// A GraphQL subscription object /// A GraphQL subscription object
pub trait SubscriptionType: Type + Send + Sync { pub trait SubscriptionType: Type + Send + Sync {
@ -19,10 +17,10 @@ pub trait SubscriptionType: Type + Send + Sync {
fn create_field_stream<'a>( fn create_field_stream<'a>(
&'a self, &'a self,
ctx: &'a Context<'_>, ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>>; ) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>>;
} }
type BoxFieldStream<'a> = Pin<Box<dyn Stream<Item = ServerResult<(Name, Value)>> + 'a + Send>>; type BoxFieldStream<'a> = Pin<Box<dyn Stream<Item = Response> + 'a + Send>>;
pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>( pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>(
ctx: &ContextSelectionSet<'a>, ctx: &ContextSelectionSet<'a>,
@ -41,16 +39,14 @@ pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>(
let field_name = ctx.item.node.response_key().node.clone(); let field_name = ctx.item.node.response_key().node.clone();
let stream = root.create_field_stream(&ctx); let stream = root.create_field_stream(&ctx);
if let Some(mut stream) = stream { if let Some(mut stream) = stream {
while let Some(item) = stream.next().await { while let Some(resp) = stream.next().await {
yield match item { yield resp;
Ok(value) => Ok((field_name.to_owned(), value)),
Err(e) => Err(e.path(PathSegment::Field(field_name.to_string()))),
};
} }
} else { } else {
yield Err(ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name())) let err = ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name()))
.at(ctx.item.pos) .at(ctx.item.pos)
.path(PathSegment::Field(field_name.to_string()))); .path(PathSegment::Field(field_name.to_string()));
yield Response::from_errors(vec![err]);
} }
} }
})), })),
@ -98,7 +94,7 @@ impl<T: SubscriptionType> SubscriptionType for &T {
fn create_field_stream<'a>( fn create_field_stream<'a>(
&'a self, &'a self,
ctx: &'a Context<'_>, ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>> { ) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>> {
T::create_field_stream(*self, ctx) T::create_field_stream(*self, ctx)
} }
} }

View File

@ -3,7 +3,7 @@ use std::pin::Pin;
use futures_util::stream::{self, Stream}; use futures_util::stream::{self, Stream};
use crate::{registry, Context, ServerError, ServerResult, SubscriptionType, Type, Value}; use crate::{registry, Context, Response, ServerError, SubscriptionType, Type};
/// Empty subscription /// Empty subscription
/// ///
@ -37,12 +37,13 @@ impl SubscriptionType for EmptySubscription {
fn create_field_stream<'a>( fn create_field_stream<'a>(
&'a self, &'a self,
ctx: &'a Context<'_>, ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>> ) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>>
where where
Self: Send + Sync + 'static + Sized, Self: Send + Sync + 'static + Sized,
{ {
Some(Box::pin(stream::once(async move { Some(Box::pin(stream::once(async move {
Err(ServerError::new("Schema is not configured for mutations.").at(ctx.item.pos)) let err = ServerError::new("Schema is not configured for mutations.").at(ctx.item.pos);
Response::from_errors(vec![err])
}))) })))
} }
} }

View File

@ -16,6 +16,7 @@ pub use visitor::VisitorContext;
use visitor::{visit, VisitorNil}; use visitor::{visit, VisitorNil};
/// Validation results. /// Validation results.
#[derive(Debug, Copy, Clone)]
pub struct ValidationResult { pub struct ValidationResult {
/// Cache control /// Cache control
pub cache_control: CacheControl, pub cache_control: CacheControl,

View File

@ -1,12 +1,15 @@
use std::sync::Arc; use std::sync::Arc;
use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use async_graphql::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use async_graphql::futures_util::stream::BoxStream;
use async_graphql::parser::types::ExecutableDocument; use async_graphql::parser::types::ExecutableDocument;
use async_graphql::*; use async_graphql::*;
use async_graphql_value::ConstValue; use async_graphql_value::ConstValue;
use futures_util::lock::Mutex;
use futures_util::stream::Stream; use futures_util::stream::Stream;
use futures_util::StreamExt; use futures_util::StreamExt;
use spin::Mutex;
#[tokio::test] #[tokio::test]
pub async fn test_extension_ctx() { pub async fn test_extension_ctx() {
@ -18,7 +21,7 @@ pub async fn test_extension_ctx() {
#[Object] #[Object]
impl Query { impl Query {
async fn value(&self, ctx: &Context<'_>) -> i32 { async fn value(&self, ctx: &Context<'_>) -> i32 {
*ctx.data_unchecked::<MyData>().0.lock() *ctx.data_unchecked::<MyData>().0.lock().await
} }
} }
@ -27,7 +30,7 @@ pub async fn test_extension_ctx() {
#[Subscription] #[Subscription]
impl Subscription { impl Subscription {
async fn value(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> { async fn value(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let data = *ctx.data_unchecked::<MyData>().0.lock(); let data = *ctx.data_unchecked::<MyData>().0.lock().await;
futures_util::stream::once(async move { data }) futures_util::stream::once(async move { data })
} }
} }
@ -36,23 +39,25 @@ pub async fn test_extension_ctx() {
#[async_trait::async_trait] #[async_trait::async_trait]
impl Extension for MyExtensionImpl { impl Extension for MyExtensionImpl {
fn parse_start( async fn parse_query(
&mut self, &self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
_query_source: &str, query: &str,
_variables: &Variables, variables: &Variables,
) { next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
if let Ok(data) = ctx.data::<MyData>() { if let Ok(data) = ctx.data::<MyData>() {
*data.0.lock() = 100; *data.0.lock().await = 100;
} }
next.parse_query(ctx, query, variables).await
} }
} }
struct MyExtension; struct MyExtension;
impl ExtensionFactory for MyExtension { impl ExtensionFactory for MyExtension {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(MyExtensionImpl) Arc::new(MyExtensionImpl)
} }
} }
@ -104,12 +109,10 @@ pub async fn test_extension_ctx() {
let mut data = Data::default(); let mut data = Data::default();
data.insert(MyData::default()); data.insert(MyData::default());
let mut stream = schema let mut stream = schema.execute_stream_with_session_data(
.execute_stream_with_session_data( Request::new("subscription { value }"),
Request::new("subscription { value }"), Arc::new(data),
Arc::new(data), );
)
.boxed();
assert_eq!( assert_eq!(
stream.next().await.unwrap().into_result().unwrap().data, stream.next().await.unwrap().into_result().unwrap().data,
value! ({ value! ({
@ -128,67 +131,83 @@ pub async fn test_extension_call_order() {
#[async_trait::async_trait] #[async_trait::async_trait]
#[allow(unused_variables)] #[allow(unused_variables)]
impl Extension for MyExtensionImpl { impl Extension for MyExtensionImpl {
fn name(&self) -> Option<&'static str> { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
Some("test") self.calls.lock().await.push("request_start");
let res = next.request(ctx).await;
self.calls.lock().await.push("request_end");
res
} }
fn start(&mut self, ctx: &ExtensionContext<'_>) { fn subscribe<'s>(
self.calls.lock().push("start"); &self,
} ctx: &ExtensionContext<'_>,
mut stream: BoxStream<'s, Response>,
fn end(&mut self, ctx: &ExtensionContext<'_>) { next: NextExtension<'_>,
self.calls.lock().push("end"); ) -> BoxStream<'s, Response> {
let calls = self.calls.clone();
let stream = async_stream::stream! {
calls.lock().await.push("subscribe_start");
while let Some(item) = stream.next().await {
yield item;
}
calls.lock().await.push("subscribe_end");
};
Box::pin(stream)
} }
async fn prepare_request( async fn prepare_request(
&mut self, &self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
request: Request, request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> { ) -> ServerResult<Request> {
self.calls.lock().push("prepare_request"); self.calls.lock().await.push("prepare_request_start");
Ok(request) let res = next.prepare_request(ctx, request).await;
self.calls.lock().await.push("prepare_request_end");
res
} }
fn parse_start( async fn parse_query(
&mut self, &self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
query_source: &str, query: &str,
variables: &Variables, variables: &Variables,
) { next: NextExtension<'_>,
self.calls.lock().push("parse_start"); ) -> ServerResult<ExecutableDocument> {
self.calls.lock().await.push("parse_query_start");
let res = next.parse_query(ctx, query, variables).await;
self.calls.lock().await.push("parse_query_end");
res
} }
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { async fn validation(
self.calls.lock().push("parse_end"); &self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
self.calls.lock().await.push("validation_start");
let res = next.validation(ctx).await;
self.calls.lock().await.push("validation_end");
res
} }
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) { async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
self.calls.lock().push("validation_start"); self.calls.lock().await.push("execute_start");
let res = next.execute(ctx).await;
self.calls.lock().await.push("execute_end");
res
} }
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) { async fn resolve(
self.calls.lock().push("validation_end"); &self,
} ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) { next: NextExtension<'_>,
self.calls.lock().push("execution_start"); ) -> ServerResult<Option<ConstValue>> {
} self.calls.lock().await.push("resolve_start");
let res = next.resolve(ctx, info).await;
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) { self.calls.lock().await.push("resolve_end");
self.calls.lock().push("execution_end"); res
}
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_start");
}
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_end");
}
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<ConstValue> {
self.calls.lock().push("result");
None
} }
} }
@ -197,8 +216,8 @@ pub async fn test_extension_call_order() {
} }
impl ExtensionFactory for MyExtension { impl ExtensionFactory for MyExtension {
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Arc<dyn Extension> {
Box::new(MyExtensionImpl { Arc::new(MyExtensionImpl {
calls: self.calls.clone(), calls: self.calls.clone(),
}) })
} }
@ -238,24 +257,24 @@ pub async fn test_extension_call_order() {
.await .await
.into_result() .into_result()
.unwrap(); .unwrap();
let calls = calls.lock(); let calls = calls.lock().await;
assert_eq!( assert_eq!(
&*calls, &*calls,
&vec![ &vec![
"start", "request_start",
"prepare_request", "prepare_request_start",
"parse_start", "prepare_request_end",
"parse_end", "parse_query_start",
"parse_query_end",
"validation_start", "validation_start",
"validation_end", "validation_end",
"execution_start", "execute_start",
"resolve_start", "resolve_start",
"resolve_end", "resolve_end",
"resolve_start", "resolve_start",
"resolve_end", "resolve_end",
"execution_end", "execute_end",
"result", "request_end",
"end",
] ]
); );
} }
@ -267,34 +286,36 @@ pub async fn test_extension_call_order() {
calls: calls.clone(), calls: calls.clone(),
}) })
.finish(); .finish();
let mut stream = schema.execute_stream("subscription { value }").boxed(); let mut stream = schema.execute_stream("subscription { value }");
while let Some(_) = stream.next().await {} while let Some(_) = stream.next().await {}
let calls = calls.lock(); let calls = calls.lock().await;
assert_eq!( assert_eq!(
&*calls, &*calls,
&vec![ &vec![
"start", "subscribe_start",
"prepare_request", "prepare_request_start",
"parse_start", "prepare_request_end",
"parse_end", "parse_query_start",
"parse_query_end",
"validation_start", "validation_start",
"validation_end", "validation_end",
"execution_start", // push 1
"execute_start",
"resolve_start", "resolve_start",
"resolve_end", "resolve_end",
"execution_end", "execute_end",
"result", // push 2
"execution_start", "execute_start",
"resolve_start", "resolve_start",
"resolve_end", "resolve_end",
"execution_end", "execute_end",
"result", // push 3
"execution_start", "execute_start",
"resolve_start", "resolve_start",
"resolve_end", "resolve_end",
"execution_end", "execute_end",
"result", // end
"end", "subscribe_end",
] ]
); );
} }

View File

@ -119,7 +119,7 @@ pub async fn test_field_features() {
}] }]
); );
let mut stream = schema.execute_stream("subscription { values }").boxed(); let mut stream = schema.execute_stream("subscription { values }");
assert_eq!( assert_eq!(
stream stream
.next() .next()
@ -131,7 +131,7 @@ pub async fn test_field_features() {
}) })
); );
let mut stream = schema.execute_stream("subscription { valuesBson }").boxed(); let mut stream = schema.execute_stream("subscription { valuesBson }");
assert_eq!( assert_eq!(
stream.next().await.map(|resp| resp.data).unwrap(), stream.next().await.map(|resp| resp.data).unwrap(),
value!({ value!({
@ -142,7 +142,6 @@ pub async fn test_field_features() {
assert_eq!( assert_eq!(
schema schema
.execute_stream("subscription { valuesAbc }") .execute_stream("subscription { valuesAbc }")
.boxed()
.next() .next()
.await .await
.unwrap() .unwrap()

View File

@ -284,8 +284,7 @@ pub async fn test_generic_subscription() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { values }") .execute_stream("subscription { values }")
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
for i in 1..=2 { for i in 1..=2 {
assert_eq!(value!({ "values": i }), stream.next().await.unwrap()); assert_eq!(value!({ "values": i }), stream.next().await.unwrap());
} }

View File

@ -117,7 +117,6 @@ pub async fn test_guard_simple_rule() {
assert_eq!( assert_eq!(
schema schema
.execute_stream(Request::new("subscription { values }").data(Role::Guest)) .execute_stream(Request::new("subscription { values }").data(Role::Guest))
.boxed()
.next() .next()
.await .await
.unwrap() .unwrap()

View File

@ -196,8 +196,7 @@ pub async fn test_merged_subscription() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { events1 }") .execute_stream("subscription { events1 }")
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
for i in 0i32..10 { for i in 0i32..10 {
assert_eq!( assert_eq!(
value!({ value!({
@ -212,8 +211,7 @@ pub async fn test_merged_subscription() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { events2 }") .execute_stream("subscription { events2 }")
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
for i in 10i32..20 { for i in 10i32..20 {
assert_eq!( assert_eq!(
value!({ value!({

View File

@ -66,8 +66,7 @@ pub async fn test_input_value_custom_error() {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { type }") .execute_stream("subscription { type }")
.map(|resp| resp.into_result()) .map(|resp| resp.into_result())
.map_ok(|resp| resp.data) .map_ok(|resp| resp.data);
.boxed();
for i in 0..10 { for i in 0..10 {
assert_eq!(value!({ "type": i }), stream.next().await.unwrap().unwrap()); assert_eq!(value!({ "type": i }), stream.next().await.unwrap().unwrap());
} }

View File

@ -127,7 +127,6 @@ pub async fn test_subscription() {
assert_eq!( assert_eq!(
Schema::new(Query, EmptyMutation, Subscription) Schema::new(Query, EmptyMutation, Subscription)
.execute_stream("subscription { CREATE_OBJECT(objectid: 100) }") .execute_stream("subscription { CREATE_OBJECT(objectid: 100) }")
.boxed()
.next() .next()
.await .await
.unwrap() .unwrap()

View File

@ -36,8 +36,7 @@ pub async fn test_subscription() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { values(start: 10, end: 20) }") .execute_stream("subscription { values(start: 10, end: 20) }")
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
for i in 10..20 { for i in 10..20 {
assert_eq!(value!({ "values": i }), stream.next().await.unwrap()); assert_eq!(value!({ "values": i }), stream.next().await.unwrap());
} }
@ -47,8 +46,7 @@ pub async fn test_subscription() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { events(start: 10, end: 20) { a b } }") .execute_stream("subscription { events(start: 10, end: 20) { a b } }")
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
for i in 10..20 { for i in 10..20 {
assert_eq!( assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }), value!({ "events": {"a": i, "b": i * 10} }),
@ -98,8 +96,7 @@ pub async fn test_subscription_with_ctx_data() {
{ {
let mut stream = schema let mut stream = schema
.execute_stream(Request::new("subscription { values objects { value } }").data(100i32)) .execute_stream(Request::new("subscription { values objects { value } }").data(100i32))
.map(|resp| resp.data) .map(|resp| resp.data);
.boxed();
assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap()); assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap());
assert_eq!( assert_eq!(
value!({ "objects": { "value": 100 } }), value!({ "objects": { "value": 100 } }),
@ -141,8 +138,7 @@ pub async fn test_subscription_with_token() {
.execute_stream( .execute_stream(
Request::new("subscription { values }").data(Token("123456".to_string())), Request::new("subscription { values }").data(Token("123456".to_string())),
) )
.map(|resp| resp.into_result().unwrap().data) .map(|resp| resp.into_result().unwrap().data);
.boxed();
assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap()); assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap());
assert!(stream.next().await.is_none()); assert!(stream.next().await.is_none());
} }
@ -152,7 +148,6 @@ pub async fn test_subscription_with_token() {
.execute_stream( .execute_stream(
Request::new("subscription { values }").data(Token("654321".to_string())) Request::new("subscription { values }").data(Token("654321".to_string()))
) )
.boxed()
.next() .next()
.await .await
.unwrap() .unwrap()
@ -200,8 +195,7 @@ pub async fn test_subscription_inline_fragment() {
} }
"#, "#,
) )
.map(|resp| resp.data) .map(|resp| resp.data);
.boxed();
for i in 10..20 { for i in 10..20 {
assert_eq!( assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }), value!({ "events": {"a": i, "b": i * 10} }),
@ -250,8 +244,7 @@ pub async fn test_subscription_fragment() {
} }
"#, "#,
) )
.map(|resp| resp.data) .map(|resp| resp.data);
.boxed();
for i in 10i32..20 { for i in 10i32..20 {
assert_eq!( assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }), value!({ "events": {"a": i, "b": i * 10} }),
@ -301,8 +294,7 @@ pub async fn test_subscription_fragment2() {
} }
"#, "#,
) )
.map(|resp| resp.data) .map(|resp| resp.data);
.boxed();
for i in 10..20 { for i in 10..20 {
assert_eq!( assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }), value!({ "events": {"a": i, "b": i * 10} }),
@ -342,8 +334,7 @@ pub async fn test_subscription_error() {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { events { value } }") .execute_stream("subscription { events { value } }")
.map(|resp| resp.into_result()) .map(|resp| resp.into_result())
.map_ok(|resp| resp.data) .map_ok(|resp| resp.data);
.boxed();
for i in 0i32..5 { for i in 0i32..5 {
assert_eq!( assert_eq!(
value!({ "events": { "value": i } }), value!({ "events": { "value": i } }),
@ -388,8 +379,7 @@ pub async fn test_subscription_fieldresult() {
let mut stream = schema let mut stream = schema
.execute_stream("subscription { values }") .execute_stream("subscription { values }")
.map(|resp| resp.into_result()) .map(|resp| resp.into_result())
.map_ok(|resp| resp.data) .map_ok(|resp| resp.data);
.boxed();
for i in 0i32..5 { for i in 0i32..5 {
assert_eq!( assert_eq!(
value!({ "values": i }), value!({ "values": i }),