Rework Extension

Sunli 2021-04-04 12:05:54 +08:00
parent 13298b8d61
commit 824356d118
27 changed files with 1058 additions and 1145 deletions

View File

@ -2,7 +2,14 @@
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
nd this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Changed
- Rework `Extension`: it is now fully asynchronous, easier to use, and enables more features.
  Because this is a large change _(if you don't have a custom extension, existing code will still compile)_, the major version will be bumped to `3.0.0`.
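For orientation, a minimal sketch of a custom extension under the reworked API shown in this diff: each hook takes `&self` plus a `NextExtension` handle and delegates down the chain, and results are attached with the new `Response::extension`. The `RequestTimer` name, the `requestTimer` key, and the timing logic are illustrative only; the `async-trait` crate is assumed as a dependency, since the trait itself is declared with `#[async_trait::async_trait]`.

```rust
use std::sync::Arc;

use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension};
use async_graphql::{value, Response};

/// Hypothetical factory; it would be registered on the schema builder,
/// e.g. `Schema::build(query, mutation, subscription).extension(RequestTimer)`.
struct RequestTimer;

impl ExtensionFactory for RequestTimer {
    fn create(&self) -> Arc<dyn Extension> {
        Arc::new(RequestTimerExtension)
    }
}

struct RequestTimerExtension;

#[async_trait::async_trait]
impl Extension for RequestTimerExtension {
    // `request` wraps a whole query/mutation; `next.request(ctx)` runs the
    // remaining extensions in the chain and finally the executor itself.
    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
        let start = std::time::Instant::now();
        let resp = next.request(ctx).await;
        let elapsed_ms = start.elapsed().as_millis() as i64;
        // Attach a named entry to the response's `extensions` map
        // (the "requestTimer" key is made up for this sketch).
        resp.extension("requestTimer", value!({ "elapsedMs": elapsed_ms }))
    }
}
```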
## [2.7.4] 2021-04-02

View File

@ -19,6 +19,7 @@ apollo_persisted_queries = ["lru", "sha2"]
unblock = ["blocking"]
string_number = ["num-traits"]
dataloader = ["futures-timer", "futures-channel", "lru"]
tracing = ["tracinglib", "tracing-futures"]
[dependencies]
async-graphql-derive = { path = "derive", version = "=2.7.3" }
@ -47,7 +48,8 @@ bson = { version = "1.2.0", optional = true }
chrono = { version = "0.4.19", optional = true }
chrono-tz = { version = "0.5.3", optional = true }
log = { version = "0.4.14", optional = true }
tracing = { version = "0.1.25", optional = true }
tracinglib = { version = "0.1.25", optional = true, package = "tracing" }
tracing-futures = { version = "0.2.5", optional = true, features = ["std-future", "futures-03"] }
opentelemetry = { version = "0.13.0", optional = true }
url = { version = "2.2.1", optional = true }
uuid = { version = "0.8.2", optional = true, features = ["v4", "serde"] }

View File

@ -82,7 +82,7 @@ pub fn generate(object_args: &args::MergedSubscription) -> GeneratorResult<Token
fn create_field_stream<'__life>(
&'__life self,
ctx: &'__life #crate_name::Context<'__life>
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::ServerResult<#crate_name::Value>> + ::std::marker::Send + '__life>>> {
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::Response> + ::std::marker::Send + '__life>>> {
::std::option::Option::None #create_field_stream
}
}

View File

@ -282,7 +282,7 @@ pub fn generate(
quote! {
Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| {
#(#parse_args)*
Ok(#expr)
::std::result::Result::Ok(#expr)
}))
}
}
@ -331,7 +331,7 @@ pub fn generate(
let stream_fn = quote! {
#(#get_params)*
#guard
let field_name = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item.node.response_key().node));
let field_name = ::std::clone::Clone::clone(&ctx.item.node.response_key().node);
let field = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item));
let pos = ctx.item.pos;
@ -345,11 +345,6 @@ pub fn generate(
let field = ::std::clone::Clone::clone(&field);
let field_name = ::std::clone::Clone::clone(&field_name);
async move {
let resolve_id = #crate_name::ResolveId {
parent: ::std::option::Option::Some(0),
current: 1,
};
let inc_resolve_id = ::std::sync::atomic::AtomicUsize::new(1);
let ctx_selection_set = query_env.create_context(
&schema_env,
::std::option::Option::Some(#crate_name::QueryPathNode {
@ -357,26 +352,31 @@ pub fn generate(
segment: #crate_name::QueryPathSegment::Name(&field_name),
}),
&field.node.selection_set,
resolve_id,
&inc_resolve_id,
);
query_env.extensions.execution_start();
#[allow(bare_trait_objects)]
let ri = #crate_name::extensions::ResolveInfo {
resolve_id,
path_node: ctx_selection_set.path_node.as_ref().unwrap(),
parent_type: #gql_typename,
return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(),
let mut execute_fut = async {
#[allow(bare_trait_objects)]
let ri = #crate_name::extensions::ResolveInfo {
path_node: ctx_selection_set.path_node.as_ref().unwrap(),
parent_type: #gql_typename,
return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(),
};
let resolve_fut = async {
#crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field)
.await
.map(::std::option::Option::Some)
};
#crate_name::futures_util::pin_mut!(resolve_fut);
query_env.extensions.resolve(ri, &mut resolve_fut).await
.map(|value| {
let mut map = ::std::collections::BTreeMap::new();
map.insert(::std::clone::Clone::clone(&field_name), value.unwrap_or_default());
#crate_name::Response::new(#crate_name::Value::Object(map))
})
.unwrap_or_else(|err| #crate_name::Response::from_errors(::std::vec![err]))
};
query_env.extensions.resolve_start(&ri);
let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await;
query_env.extensions.resolve_end(&ri);
query_env.extensions.execution_end();
res
#crate_name::futures_util::pin_mut!(execute_fut);
::std::result::Result::Ok(query_env.extensions.execute(&mut execute_fut).await)
}
}
});
@ -398,11 +398,14 @@ pub fn generate(
create_stream.push(quote! {
#(#cfg_attrs)*
if ctx.item.node.name.node == #field_name {
return ::std::option::Option::Some(::std::boxed::Box::pin(
#crate_name::futures_util::stream::TryStreamExt::try_flatten(
#crate_name::futures_util::stream::once((move || async move { #stream_fn })())
)
));
let stream = #crate_name::futures_util::stream::TryStreamExt::try_flatten(
#crate_name::futures_util::stream::once((move || async move { #stream_fn })())
);
let stream = #crate_name::futures_util::StreamExt::map(stream, |res| match res {
::std::result::Result::Ok(resp) => resp,
::std::result::Result::Err(err) => #crate_name::Response::from_errors(::std::vec![err]),
});
return ::std::option::Option::Some(::std::boxed::Box::pin(stream));
}
});
@ -451,7 +454,7 @@ pub fn generate(
fn create_field_stream<'__life>(
&'__life self,
ctx: &'__life #crate_name::Context<'_>,
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::ServerResult<#crate_name::Value>> + ::std::marker::Send + '__life>>> {
) -> ::std::option::Option<::std::pin::Pin<::std::boxed::Box<dyn #crate_name::futures_util::stream::Stream<Item = #crate_name::Response> + ::std::marker::Send + '__life>>> {
#(#create_stream)*
::std::option::Option::None
}

View File

@ -4,7 +4,6 @@ use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::fmt::{self, Debug, Display, Formatter};
use std::ops::Deref;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use async_graphql_value::{Value as InputValue, Variables};
@ -189,36 +188,6 @@ impl<'a> Iterator for Parents<'a> {
impl<'a> std::iter::FusedIterator for Parents<'a> {}
/// The unique id of the current resolution.
#[derive(Debug, Clone, Copy)]
pub struct ResolveId {
/// The unique ID of the parent resolution.
pub parent: Option<usize>,
/// The current unique id.
pub current: usize,
}
impl ResolveId {
#[doc(hidden)]
pub fn root() -> ResolveId {
ResolveId {
parent: None,
current: 0,
}
}
}
impl Display for ResolveId {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
if let Some(parent) = self.parent {
write!(f, "{}:{}", parent, self.current)
} else {
write!(f, "{}", self.current)
}
}
}
/// Query context.
///
/// **This type is not stable and should not be used directly.**
@ -226,8 +195,6 @@ impl Display for ResolveId {
pub struct ContextBase<'a, T> {
/// The current path node being resolved.
pub path_node: Option<QueryPathNode<'a>>,
pub(crate) resolve_id: ResolveId,
pub(crate) inc_resolve_id: &'a AtomicUsize,
#[doc(hidden)]
pub item: T,
#[doc(hidden)]
@ -273,13 +240,9 @@ impl QueryEnv {
schema_env: &'a SchemaEnv,
path_node: Option<QueryPathNode<'a>>,
item: T,
resolve_id: ResolveId,
inc_resolve_id: &'a AtomicUsize,
) -> ContextBase<'a, T> {
ContextBase {
path_node,
resolve_id,
inc_resolve_id,
item,
schema_env,
query_env: self,
@ -288,18 +251,6 @@ impl QueryEnv {
}
impl<'a, T> ContextBase<'a, T> {
#[doc(hidden)]
pub fn get_child_resolve_id(&self) -> ResolveId {
let id = self
.inc_resolve_id
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
+ 1;
ResolveId {
parent: Some(self.resolve_id.current),
current: id,
}
}
#[doc(hidden)]
pub fn with_field(
&'a self,
@ -311,8 +262,6 @@ impl<'a, T> ContextBase<'a, T> {
segment: QueryPathSegment::Name(&field.node.response_key().node),
}),
item: field,
resolve_id: self.get_child_resolve_id(),
inc_resolve_id: self.inc_resolve_id,
schema_env: self.schema_env,
query_env: self.query_env,
}
@ -326,8 +275,6 @@ impl<'a, T> ContextBase<'a, T> {
ContextBase {
path_node: self.path_node,
item: selection_set,
resolve_id: self.resolve_id,
inc_resolve_id: &self.inc_resolve_id,
schema_env: self.schema_env,
query_env: self.query_env,
}
@ -560,8 +507,6 @@ impl<'a> ContextBase<'a, &'a Positioned<SelectionSet>> {
segment: QueryPathSegment::Index(idx),
}),
item: self.item,
resolve_id: self.get_child_resolve_id(),
inc_resolve_id: self.inc_resolve_id,
schema_env: self.schema_env,
query_env: self.query_env,
}

View File

@ -1,5 +1,9 @@
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory};
use crate::{value, ValidationResult, Value};
use std::sync::Arc;
use futures_util::lock::Mutex;
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension};
use crate::{value, Response, ServerError, ValidationResult};
/// Analyzer extension
///
@ -7,32 +11,41 @@ use crate::{value, ValidationResult, Value};
pub struct Analyzer;
impl ExtensionFactory for Analyzer {
fn create(&self) -> Box<dyn Extension> {
Box::new(AnalyzerExtension::default())
fn create(&self) -> Arc<dyn Extension> {
Arc::new(AnalyzerExtension::default())
}
}
#[derive(Default)]
struct AnalyzerExtension {
complexity: usize,
depth: usize,
validation_result: Mutex<Option<ValidationResult>>,
}
#[async_trait::async_trait]
impl Extension for AnalyzerExtension {
fn name(&self) -> Option<&'static str> {
Some("analyzer")
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
let mut resp = next.request(ctx).await;
let validation_result = self.validation_result.lock().await.take();
if let Some(validation_result) = validation_result {
resp = resp.extension(
"analyzer",
value! ({
"complexity": validation_result.complexity,
"depth": validation_result.depth,
}),
);
}
resp
}
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) {
self.complexity = result.complexity;
self.depth = result.depth;
}
fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option<Value> {
Some(value! ({
"complexity": self.complexity,
"depth": self.depth,
}))
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
let res = next.validation(ctx).await?;
*self.validation_result.lock().await = Some(res);
Ok(res)
}
}
@ -78,7 +91,7 @@ mod tests {
.extension(extensions::Analyzer)
.finish();
let extensions = schema
let res = schema
.execute(
r#"{
value obj {
@ -93,15 +106,13 @@ mod tests {
.into_result()
.unwrap()
.extensions
.unwrap();
.remove("analyzer");
assert_eq!(
extensions,
value!({
"analyzer": {
"complexity": 5 + 10,
"depth": 3,
}
})
res,
Some(value!({
"complexity": 5 + 10,
"depth": 3,
}))
);
}
}

View File

@ -6,7 +6,7 @@ use futures_util::lock::Mutex;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension};
use crate::{from_value, Request, ServerError, ServerResult};
#[derive(Deserialize)]
@ -64,8 +64,8 @@ impl<T: CacheStorage> ApolloPersistedQueries<T> {
}
impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
fn create(&self) -> Box<dyn Extension> {
Box::new(ApolloPersistedQueriesExtension {
fn create(&self) -> Arc<dyn Extension> {
Arc::new(ApolloPersistedQueriesExtension {
storage: self.0.clone(),
})
}
@ -78,18 +78,19 @@ struct ApolloPersistedQueriesExtension<T> {
#[async_trait::async_trait]
impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
async fn prepare_request(
&mut self,
_ctx: &ExtensionContext<'_>,
&self,
ctx: &ExtensionContext<'_>,
mut request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> {
if let Some(value) = request.extensions.remove("persistedQuery") {
let res = if let Some(value) = request.extensions.remove("persistedQuery") {
let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
ServerError::new("Invalid \"PersistedQuery\" extension configuration.")
})?;
if persisted_query.version != 1 {
return Err(ServerError::new(
format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version),
));
));
}
if request.query.is_empty() {
@ -110,7 +111,8 @@ impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
}
} else {
Ok(request)
}
};
next.prepare_request(ctx, res?).await
}
}

View File

@ -1,36 +1,26 @@
use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use futures_util::lock::Mutex;
use serde::ser::SerializeMap;
use serde::{Serialize, Serializer};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::{value, Value};
use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::{value, Response, ServerResult, Value};
struct PendingResolve {
struct ResolveState {
path: Vec<String>,
field_name: String,
parent_type: String,
return_type: String,
start_time: DateTime<Utc>,
}
struct ResolveStat {
pending_resolve: PendingResolve,
end_time: DateTime<Utc>,
start_offset: i64,
}
impl Deref for ResolveStat {
type Target = PendingResolve;
fn deref(&self) -> &Self::Target {
&self.pending_resolve
}
}
impl Serialize for ResolveStat {
impl Serialize for ResolveState {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut map = serializer.serialize_map(None)?;
map.serialize_entry("path", &self.path)?;
@ -57,76 +47,79 @@ impl Serialize for ResolveStat {
pub struct ApolloTracing;
impl ExtensionFactory for ApolloTracing {
fn create(&self) -> Box<dyn Extension> {
Box::new(ApolloTracingExtension {
start_time: Utc::now(),
end_time: Utc::now(),
pending_resolves: Default::default(),
resolves: Default::default(),
fn create(&self) -> Arc<dyn Extension> {
Arc::new(ApolloTracingExtension {
inner: Mutex::new(Inner {
start_time: Utc::now(),
end_time: Utc::now(),
resolves: Default::default(),
}),
})
}
}
struct ApolloTracingExtension {
struct Inner {
start_time: DateTime<Utc>,
end_time: DateTime<Utc>,
pending_resolves: BTreeMap<usize, PendingResolve>,
resolves: Vec<ResolveStat>,
resolves: Vec<ResolveState>,
}
struct ApolloTracingExtension {
inner: Mutex<Inner>,
}
#[async_trait::async_trait]
impl Extension for ApolloTracingExtension {
fn name(&self) -> Option<&'static str> {
Some("tracing")
}
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
self.inner.lock().await.start_time = Utc::now();
let resp = next.execute(ctx).await;
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) {
self.start_time = Utc::now();
self.pending_resolves.clear();
self.resolves.clear();
}
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.end_time = Utc::now();
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.pending_resolves.insert(
info.resolve_id.current,
PendingResolve {
path: info.path_node.to_string_vec(),
field_name: info.path_node.field_name().to_string(),
parent_type: info.parent_type.to_string(),
return_type: info.return_type.to_string(),
start_time: Utc::now(),
},
);
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
if let Some(pending_resolve) = self.pending_resolves.remove(&info.resolve_id.current) {
let start_offset = (pending_resolve.start_time - self.start_time)
.num_nanoseconds()
.unwrap();
self.resolves.push(ResolveStat {
pending_resolve,
start_offset,
end_time: Utc::now(),
});
}
}
fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option<Value> {
self.resolves
let mut inner = self.inner.lock().await;
inner.end_time = Utc::now();
inner
.resolves
.sort_by(|a, b| a.start_offset.cmp(&b.start_offset));
resp.extension(
"tracing",
value!({
"version": 1,
"startTime": inner.start_time.to_rfc3339(),
"endTime": inner.end_time.to_rfc3339(),
"duration": (inner.end_time - inner.start_time).num_nanoseconds(),
"execution": {
"resolvers": inner.resolves
}
}),
)
}
Some(value!({
"version": 1,
"startTime": self.start_time.to_rfc3339(),
"endTime": self.end_time.to_rfc3339(),
"duration": (self.end_time - self.start_time).num_nanoseconds(),
"execution": {
"resolvers": self.resolves
}
}))
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
let path = info.path_node.to_string_vec();
let field_name = info.path_node.field_name().to_string();
let parent_type = info.parent_type.to_string();
let return_type = info.return_type.to_string();
let start_time = Utc::now();
let start_offset = (start_time - self.inner.lock().await.start_time)
.num_nanoseconds()
.unwrap();
let res = next.resolve(ctx, info).await;
let end_time = Utc::now();
self.inner.lock().await.resolves.push(ResolveState {
path,
field_name,
parent_type,
return_type,
start_time,
end_time,
start_offset,
});
res
}
}

View File

@ -1,117 +1,124 @@
use std::fmt::{self, Display, Formatter};
use std::sync::Arc;
use log::{error, info, trace};
use futures_util::lock::Mutex;
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::parser::types::{ExecutableDocument, OperationType, Selection};
use crate::{PathSegment, ServerError, Variables};
use crate::{PathSegment, ServerError, ServerResult, Value, Variables};
/// Logger extension
#[cfg_attr(docsrs, doc(cfg(feature = "log")))]
pub struct Logger;
impl ExtensionFactory for Logger {
fn create(&self) -> Box<dyn Extension> {
Box::new(LoggerExtension {
enabled: true,
query: String::new(),
variables: Default::default(),
fn create(&self) -> Arc<dyn Extension> {
Arc::new(LoggerExtension {
inner: Mutex::new(Inner {
enabled: true,
query: String::new(),
variables: Default::default(),
}),
})
}
}
struct LoggerExtension {
struct Inner {
enabled: bool,
query: String,
variables: Variables,
}
impl Extension for LoggerExtension {
fn parse_start(
&mut self,
_ctx: &ExtensionContext<'_>,
query_source: &str,
variables: &Variables,
) {
self.query = query_source.replace(char::is_whitespace, "");
self.variables = variables.clone();
}
struct LoggerExtension {
inner: Mutex<Inner>,
}
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {
#[async_trait::async_trait]
impl Extension for LoggerExtension {
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
let mut inner = self.inner.lock().await;
inner.query = query.replace(char::is_whitespace, "");
inner.variables = variables.clone();
let document = next.parse_query(ctx, query, variables).await?;
let is_schema = document
.operations
.iter()
.filter(|(_, operation)| operation.node.ty == OperationType::Query)
.any(|(_, operation)| operation.node.selection_set.node.items.iter().any(|selection| matches!(&selection.node, Selection::Field(field) if field.node.name.node == "__schema")));
if is_schema {
self.enabled = false;
return;
}
info!(target: "async-graphql", "[Query] query: \"{}\", variables: {}", &self.query, self.variables);
inner.enabled = !is_schema;
Ok(document)
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
if !self.enabled {
return;
}
trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", info.path_node);
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
if !self.enabled {
return;
}
trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", info.path_node);
}
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
struct DisplayError<'a> {
log: &'a LoggerExtension,
e: &'a ServerError,
}
impl<'a> Display for DisplayError<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "[Error] ")?;
if !self.e.path.is_empty() {
write!(f, "path: ")?;
for (i, segment) in self.e.path.iter().enumerate() {
if i != 0 {
write!(f, ".")?;
}
match segment {
PathSegment::Field(field) => write!(f, "{}", field),
PathSegment::Index(i) => write!(f, "{}", i),
}?;
}
write!(f, ", ")?;
}
if !self.e.locations.is_empty() {
write!(f, "pos: [")?;
for (i, location) in self.e.locations.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}:{}", location.line, location.column)?;
}
write!(f, "], ")?;
}
write!(f, r#"query: "{}", "#, self.log.query)?;
write!(f, "variables: {}", self.log.variables)?;
write!(f, "{}", self.e.message)
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
let enabled = self.inner.lock().await.enabled;
if enabled {
let path = info.path_node.to_string();
log::trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", path);
let res = next.resolve(ctx, info).await;
if let Err(err) = &res {
let inner = self.inner.lock().await;
log::error!(
target: "async-graphql",
"{}",
DisplayError { query:&inner.query,variables:&inner.variables, e: &err }
);
}
log::trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", path);
res
} else {
next.resolve(ctx, info).await
}
error!(
target: "async-graphql",
"{}",
DisplayError {
log: self,
e: err,
}
);
}
}
struct DisplayError<'a> {
query: &'a str,
variables: &'a Variables,
e: &'a ServerError,
}
impl<'a> Display for DisplayError<'a> {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
write!(f, "[Error] ")?;
if !self.e.path.is_empty() {
write!(f, "path: ")?;
for (i, segment) in self.e.path.iter().enumerate() {
if i != 0 {
write!(f, ".")?;
}
match segment {
PathSegment::Field(field) => write!(f, "{}", field),
PathSegment::Index(i) => write!(f, "{}", i),
}?;
}
write!(f, ", ")?;
}
if !self.e.locations.is_empty() {
write!(f, "pos: [")?;
for (i, location) in self.e.locations.iter().enumerate() {
if i != 0 {
write!(f, ", ")?;
}
write!(f, "{}:{}", location.line, location.column)?;
}
write!(f, "], ")?;
}
write!(f, r#"query: "{}", "#, self.query)?;
write!(f, "variables: {}", self.variables)?;
write!(f, "{}", self.e.message)
}
}

View File

@ -12,28 +12,28 @@ mod opentelemetry;
#[cfg(feature = "tracing")]
mod tracing;
use std::any::{Any, TypeId};
use std::collections::BTreeMap;
use std::sync::Arc;
use crate::context::{QueryPathNode, ResolveId};
use crate::parser::types::ExecutableDocument;
use crate::{
Data, Request, Result, SchemaEnv, ServerError, ServerResult, ValidationResult, Variables,
};
use crate::{Error, Name, Value};
pub use self::analyzer::Analyzer;
#[cfg(feature = "apollo_tracing")]
pub use self::apollo_tracing::ApolloTracing;
#[cfg(feature = "log")]
pub use self::logger::Logger;
#[cfg(feature = "opentelemetry")]
pub use self::opentelemetry::{OpenTelemetry, OpenTelemetryConfig};
pub use self::opentelemetry::OpenTelemetry;
#[cfg(feature = "tracing")]
pub use self::tracing::{Tracing, TracingConfig};
pub use self::tracing::Tracing;
pub(crate) type BoxExtension = Box<dyn Extension>;
use std::any::{Any, TypeId};
use std::future::Future;
use std::sync::Arc;
use futures_util::stream::BoxStream;
use futures_util::stream::StreamExt;
use crate::parser::types::ExecutableDocument;
use crate::{
Data, Error, QueryPathNode, Request, Response, Result, SchemaEnv, ServerError, ServerResult,
ValidationResult, Value, Variables,
};
/// Context for extension
pub struct ExtensionContext<'a> {
@ -86,10 +86,6 @@ impl<'a> ExtensionContext<'a> {
/// Parameters for `Extension::resolve_field_start`
pub struct ResolveInfo<'a> {
/// Because resolver is concurrent, `Extension::resolve_field_start` and `Extension::resolve_field_end` are
/// not strictly ordered, so each pair is identified by an id.
pub resolve_id: ResolveId,
/// Current path node, You can go through the entire path.
pub path_node: &'a QueryPathNode<'a>,
@ -100,120 +96,250 @@ pub struct ResolveInfo<'a> {
pub return_type: &'a str,
}
/// Represents a GraphQL extension
///
/// # Call order for query and mutation
///
/// - start
/// - prepare_request
/// - parse_start
/// - parse_end
/// - validation_start
/// - validation_end
/// - execution_start
/// - resolve_start
/// - resolve_end
/// - result
/// - execution_end
/// - end
///
/// # Call order for subscription
///
/// - start
/// - prepare_request
/// - parse_start
/// - parse_end
/// - validation_start
/// - validation_end
/// - execution_start
/// - resolve_start
/// - resolve_end
/// - execution_end
/// - result
/// ```
#[async_trait::async_trait]
#[allow(unused_variables)]
pub trait Extension: Sync + Send + 'static {
/// If this extension needs to output data to query results, you need to specify a name.
fn name(&self) -> Option<&'static str> {
None
type RequestFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);
type ParseFut<'a> = &'a mut (dyn Future<Output = ServerResult<ExecutableDocument>> + Send + Unpin);
type ValidationFut<'a> =
&'a mut (dyn Future<Output = Result<ValidationResult, Vec<ServerError>>> + Send + Unpin);
type ExecuteFut<'a> = &'a mut (dyn Future<Output = Response> + Send + Unpin);
type ResolveFut<'a> = &'a mut (dyn Future<Output = ServerResult<Option<Value>>> + Send + Unpin);
/// The remainder of an extension chain.
pub struct NextExtension<'a> {
chain: &'a [Arc<dyn Extension>],
request_fut: Option<RequestFut<'a>>,
parse_query_fut: Option<ParseFut<'a>>,
validation_fut: Option<ValidationFut<'a>>,
execute_fut: Option<ExecuteFut<'a>>,
resolve_fut: Option<ResolveFut<'a>>,
}
impl<'a> NextExtension<'a> {
#[inline]
pub(crate) fn new(chain: &'a [Arc<dyn Extension>]) -> Self {
Self {
chain,
request_fut: None,
parse_query_fut: None,
validation_fut: None,
execute_fut: None,
resolve_fut: None,
}
}
/// Called at the beginning of query.
fn start(&mut self, ctx: &ExtensionContext<'_>) {}
#[inline]
pub(crate) fn with_chain(self, chain: &'a [Arc<dyn Extension>]) -> Self {
Self { chain, ..self }
}
/// Called at the beginning of query.
fn end(&mut self, ctx: &ExtensionContext<'_>) {}
#[inline]
pub(crate) fn with_request(self, fut: RequestFut<'a>) -> Self {
Self {
request_fut: Some(fut),
..self
}
}
/// Called at prepare request.
async fn prepare_request(
&mut self,
#[inline]
pub(crate) fn with_parse_query(self, fut: ParseFut<'a>) -> Self {
Self {
parse_query_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_validation(self, fut: ValidationFut<'a>) -> Self {
Self {
validation_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_execute(self, fut: ExecuteFut<'a>) -> Self {
Self {
execute_fut: Some(fut),
..self
}
}
#[inline]
pub(crate) fn with_resolve(self, fut: ResolveFut<'a>) -> Self {
Self {
resolve_fut: Some(fut),
..self
}
}
/// Call the [Extension::request] function of the next extension.
pub async fn request(mut self, ctx: &ExtensionContext<'_>) -> Response {
if let Some((first, next)) = self.chain.split_first() {
first.request(ctx, self.with_chain(next)).await
} else {
self.request_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
}
/// Call the [Extension::subscribe] function of the next extension.
pub fn subscribe<'s>(
self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
) -> BoxStream<'s, Response> {
if let Some((first, next)) = self.chain.split_first() {
first.subscribe(ctx, stream, self.with_chain(next))
} else {
stream
}
}
/// Call the [Extension::prepare_request] function of the next extension.
pub async fn prepare_request(
self,
ctx: &ExtensionContext<'_>,
request: Request,
) -> ServerResult<Request> {
Ok(request)
if let Some((first, next)) = self.chain.split_first() {
first
.prepare_request(ctx, request, self.with_chain(next))
.await
} else {
Ok(request)
}
}
/// Called at the beginning of parse query source.
fn parse_start(
&mut self,
/// Call the [Extension::parse_query] function of the next extension.
pub async fn parse_query(
mut self,
ctx: &ExtensionContext<'_>,
query_source: &str,
query: &str,
variables: &Variables,
) {
}
/// Called at the end of parse query source.
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {}
/// Called at the beginning of the validation.
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the end of the validation.
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {}
/// Called at the beginning of execute a query.
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the end of execute a query.
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the beginning of resolve a field.
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {}
/// Called at the end of resolve a field.
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {}
/// Called when an error occurs.
fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {}
/// Get the results.
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<Value> {
None
}
}
pub(crate) trait ErrorLogger {
fn log_error(self, extensions: &Extensions) -> Self;
}
impl<T> ErrorLogger for ServerResult<T> {
fn log_error(self, extensions: &Extensions) -> Self {
if let Err(err) = &self {
extensions.error(err);
) -> ServerResult<ExecutableDocument> {
if let Some((first, next)) = self.chain.split_first() {
first
.parse_query(ctx, query, variables, self.with_chain(next))
.await
} else {
self.parse_query_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
}
/// Call the [Extension::validation] function of the next extension.
pub async fn validation(
mut self,
ctx: &ExtensionContext<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
if let Some((first, next)) = self.chain.split_first() {
first.validation(ctx, self.with_chain(next)).await
} else {
self.validation_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
}
/// Call the [Extension::execute] function of the next extension.
pub async fn execute(mut self, ctx: &ExtensionContext<'_>) -> Response {
if let Some((first, next)) = self.chain.split_first() {
first.execute(ctx, self.with_chain(next)).await
} else {
self.execute_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
}
/// Call the [Extension::resolve] function of the next extension.
pub async fn resolve(
mut self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
) -> ServerResult<Option<Value>> {
if let Some((first, next)) = self.chain.split_first() {
first.resolve(ctx, info, self.with_chain(next)).await
} else {
self.resolve_fut
.take()
.expect("You definitely called the wrong function.")
.await
}
self
}
}
impl<T> ErrorLogger for Result<T, Vec<ServerError>> {
fn log_error(self, extensions: &Extensions) -> Self {
if let Err(errors) = &self {
for error in errors {
extensions.error(error);
}
}
self
/// Represents a GraphQL extension
#[async_trait::async_trait]
#[allow(unused_variables)]
pub trait Extension: Sync + Send + 'static {
/// Called at the start of a query/mutation request.
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
next.request(ctx).await
}
/// Called at the start of a subscription request.
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
next.subscribe(ctx, stream)
}
/// Called at request preparation.
async fn prepare_request(
&self,
ctx: &ExtensionContext<'_>,
request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> {
next.prepare_request(ctx, request).await
}
/// Called at query parsing.
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
next.parse_query(ctx, query, variables).await
}
/// Called at query validation.
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
next.validation(ctx).await
}
/// Called at query execution.
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
next.execute(ctx).await
}
/// Called at field resolution.
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
next.resolve(ctx, info).await
}
}
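Because every hook now takes `&self` and instances are shared as `Arc<dyn Extension>`, extensions that keep per-request state need interior mutability; the reworked `Analyzer`, `Logger`, and `ApolloTracing` above all wrap their state in `futures_util::lock::Mutex`. Below is a hypothetical sketch of that pattern (the `FieldCounter` name and the counting logic are invented; the trait and `NextExtension` signatures are the ones defined above, and `async-trait` plus `futures-util` are assumed as dependencies):

```rust
use std::sync::Arc;

use async_graphql::extensions::{
    Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use async_graphql::{value, Response, ServerResult, Value};
use futures_util::lock::Mutex;

struct FieldCounter;

impl ExtensionFactory for FieldCounter {
    fn create(&self) -> Arc<dyn Extension> {
        Arc::new(FieldCounterExtension {
            count: Mutex::new(0),
        })
    }
}

struct FieldCounterExtension {
    // Hooks take `&self`, so mutable per-request state lives behind an async lock.
    count: Mutex<usize>,
}

#[async_trait::async_trait]
impl Extension for FieldCounterExtension {
    async fn resolve(
        &self,
        ctx: &ExtensionContext<'_>,
        info: ResolveInfo<'_>,
        next: NextExtension<'_>,
    ) -> ServerResult<Option<Value>> {
        // Count the field, then let the rest of the chain (and the resolver) run.
        *self.count.lock().await += 1;
        next.resolve(ctx, info).await
    }

    async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
        // `next.request` covers execution, so all `resolve` calls finish before this read.
        let resp = next.request(ctx).await;
        let count = *self.count.lock().await;
        resp.extension("fieldCount", value!({ "count": count }))
    }
}
```

As the per-request fields in the reworked `Logger` suggest, `ExtensionFactory::create` produces a fresh instance for each request, so a counter like this does not leak state across requests.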
@ -222,12 +348,13 @@ impl<T> ErrorLogger for Result<T, Vec<ServerError>> {
/// Used to create an extension instance.
pub trait ExtensionFactory: Send + Sync + 'static {
/// Create an extension instance.
fn create(&self) -> Box<dyn Extension>;
fn create(&self) -> Arc<dyn Extension>;
}
#[derive(Clone)]
#[doc(hidden)]
pub struct Extensions {
extensions: Option<spin::Mutex<Vec<BoxExtension>>>,
extensions: Vec<Arc<dyn Extension>>,
schema_env: SchemaEnv,
session_data: Arc<Data>,
query_data: Option<Arc<Data>>,
@ -235,17 +362,13 @@ pub struct Extensions {
#[doc(hidden)]
impl Extensions {
pub fn new(
extensions: Vec<BoxExtension>,
pub(crate) fn new(
extensions: impl IntoIterator<Item = Arc<dyn Extension>>,
schema_env: SchemaEnv,
session_data: Arc<Data>,
) -> Self {
Extensions {
extensions: if extensions.is_empty() {
None
} else {
Some(spin::Mutex::new(extensions))
},
extensions: extensions.into_iter().collect(),
schema_env,
session_data,
query_data: None,
@ -255,18 +378,14 @@ impl Extensions {
pub fn attach_query_data(&mut self, data: Arc<Data>) {
self.query_data = Some(data);
}
}
impl Drop for Extensions {
fn drop(&mut self) {
self.end();
}
}
#[doc(hidden)]
impl Extensions {
#[inline]
fn context(&self) -> ExtensionContext<'_> {
pub(crate) fn is_empty(&self) -> bool {
self.extensions.is_empty()
}
#[inline]
fn create_context(&self) -> ExtensionContext {
ExtensionContext {
schema_data: &self.schema_env.data,
session_data: &self.session_data,
@ -274,124 +393,79 @@ impl Extensions {
}
}
pub fn is_empty(&self) -> bool {
self.extensions.is_none()
}
pub fn start(&self) {
if let Some(e) = &self.extensions {
e.lock().iter_mut().for_each(|e| e.start(&self.context()));
pub async fn request(&self, request_fut: RequestFut<'_>) -> Response {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_request(request_fut);
next.request(&self.create_context()).await
} else {
request_fut.await
}
}
pub fn end(&self) {
if let Some(e) = &self.extensions {
e.lock().iter_mut().for_each(|e| e.end(&self.context()));
pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions);
next.subscribe(&self.create_context(), stream)
} else {
stream.boxed()
}
}
pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> {
let mut request = request;
if let Some(e) = &self.extensions {
for e in e.lock().iter_mut() {
request = e.prepare_request(&self.context(), request).await?;
}
}
Ok(request)
}
pub fn parse_start(&self, query_source: &str, variables: &Variables) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.parse_start(&self.context(), query_source, variables));
}
}
pub fn parse_end(&self, document: &ExecutableDocument) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.parse_end(&self.context(), document));
}
}
pub fn validation_start(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_start(&self.context()));
}
}
pub fn validation_end(&self, result: &ValidationResult) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_end(&self.context(), result));
}
}
pub fn execution_start(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_start(&self.context()));
}
}
pub fn execution_end(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_end(&self.context()));
}
}
pub fn resolve_start(&self, info: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_start(&self.context(), info));
}
}
pub fn resolve_end(&self, resolve_id: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_end(&self.context(), resolve_id));
}
}
pub fn error(&self, err: &ServerError) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.error(&self.context(), err));
}
}
pub fn result(&self) -> Option<Value> {
if let Some(e) = &self.extensions {
let value = e
.lock()
.iter_mut()
.filter_map(|e| {
if let Some(name) = e.name() {
e.result(&self.context()).map(|res| (Name::new(name), res))
} else {
None
}
})
.collect::<BTreeMap<_, _>>();
if value.is_empty() {
None
} else {
Some(Value::Object(value))
}
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions);
next.prepare_request(&self.create_context(), request).await
} else {
None
Ok(request)
}
}
pub async fn parse_query(
&self,
query: &str,
variables: &Variables,
parse_query_fut: ParseFut<'_>,
) -> ServerResult<ExecutableDocument> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_parse_query(parse_query_fut);
next.parse_query(&self.create_context(), query, variables)
.await
} else {
parse_query_fut.await
}
}
pub async fn validation(
&self,
validation_fut: ValidationFut<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_validation(validation_fut);
next.validation(&self.create_context()).await
} else {
validation_fut.await
}
}
pub async fn execute(&self, execute_fut: ExecuteFut<'_>) -> Response {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_execute(execute_fut);
next.execute(&self.create_context()).await
} else {
execute_fut.await
}
}
pub async fn resolve(
&self,
info: ResolveInfo<'_>,
resolve_fut: ResolveFut<'_>,
) -> ServerResult<Option<Value>> {
if !self.extensions.is_empty() {
let next = NextExtension::new(&self.extensions).with_resolve(resolve_fut);
next.resolve(&self.create_context(), info).await
} else {
resolve_fut.await
}
}
}

View File

@ -1,49 +1,25 @@
use std::collections::HashMap;
use std::sync::Arc;
use async_graphql_parser::types::ExecutableDocument;
use async_graphql_value::Variables;
use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};
use futures_util::stream::BoxStream;
use futures_util::TryFutureExt;
use opentelemetry::trace::{FutureExt, SpanKind, TraceContextExt, Tracer};
use opentelemetry::{Context as OpenTelemetryContext, Key};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::{ServerError, ValidationResult};
const REQUEST_CTX: usize = 0;
const PARSE_CTX: usize = 1;
const VALIDATION_CTX: usize = 2;
const EXECUTE_CTX: usize = 3;
#[inline]
fn resolve_ctx_id(resolver_id: usize) -> usize {
resolver_id + 10
}
use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::{Response, ServerError, ServerResult, ValidationResult, Value};
const KEY_SOURCE: Key = Key::from_static_str("graphql.source");
const KEY_VARIABLES: Key = Key::from_static_str("graphql.variables");
const KEY_PARENT_TYPE: Key = Key::from_static_str("graphql.parentType");
const KEY_RETURN_TYPE: Key = Key::from_static_str("graphql.returnType");
const KEY_RESOLVE_ID: Key = Key::from_static_str("graphql.resolveId");
const KEY_ERROR: Key = Key::from_static_str("graphql.error");
const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity");
const KEY_DEPTH: Key = Key::from_static_str("graphql.depth");
/// OpenTelemetry extension configuration for each request.
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetryConfig {
/// Use a context as the parent node of the entire query.
parent: spin::Mutex<Option<OpenTelemetryContext>>,
}
impl OpenTelemetryConfig {
/// Use a context as the parent of the entire query.
pub fn parent_context(mut self, cx: OpenTelemetryContext) -> Self {
*self.parent.get_mut() = Some(cx);
self
}
}
/// OpenTelemetry extension
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T> {
@ -63,159 +39,127 @@ impl<T> OpenTelemetry<T> {
}
impl<T: Tracer + Send + Sync> ExtensionFactory for OpenTelemetry<T> {
fn create(&self) -> Box<dyn Extension> {
Box::new(OpenTelemetryExtension {
fn create(&self) -> Arc<dyn Extension> {
Arc::new(OpenTelemetryExtension {
tracer: self.tracer.clone(),
contexts: Default::default(),
})
}
}
struct OpenTelemetryExtension<T> {
tracer: Arc<T>,
contexts: HashMap<usize, OpenTelemetryContext>,
}
impl<T> OpenTelemetryExtension<T> {
fn enter_context(&mut self, id: usize, cx: OpenTelemetryContext) {
let _ = cx.clone().attach();
self.contexts.insert(id, cx);
}
fn exit_context(&mut self, id: usize) -> Option<OpenTelemetryContext> {
if let Some(cx) = self.contexts.remove(&id) {
let _ = cx.clone().attach();
Some(cx)
} else {
None
}
}
}
#[async_trait::async_trait]
impl<T: Tracer + Send + Sync> Extension for OpenTelemetryExtension<T> {
fn start(&mut self, ctx: &ExtensionContext<'_>) {
let request_cx = ctx
.data_opt::<OpenTelemetryConfig>()
.and_then(|cfg| cfg.parent.lock().take())
.unwrap_or_else(|| {
OpenTelemetryContext::current_with_span(
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
next.request(ctx)
.with_context(OpenTelemetryContext::current_with_span(
self.tracer
.span_builder("request")
.with_kind(SpanKind::Server)
.start(&*self.tracer),
))
.await
}
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
Box::pin(
next.subscribe(ctx, stream)
.with_context(OpenTelemetryContext::current_with_span(
self.tracer
.span_builder("request")
.span_builder("subscribe")
.with_kind(SpanKind::Server)
.start(&*self.tracer),
)
});
self.enter_context(REQUEST_CTX, request_cx);
)),
)
}
fn end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_context(REQUEST_CTX);
}
fn parse_start(
&mut self,
_ctx: &ExtensionContext<'_>,
query_source: &str,
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
) {
if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() {
let attributes = vec![
KEY_SOURCE.string(query_source.to_string()),
KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()),
];
let parse_span = self
.tracer
.span_builder("parse")
.with_kind(SpanKind::Server)
.with_attributes(attributes)
.with_parent_context(parent_cx)
.start(&*self.tracer);
let parse_cx = OpenTelemetryContext::current_with_span(parse_span);
self.enter_context(PARSE_CTX, parse_cx);
}
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
let attributes = vec![
KEY_SOURCE.string(query.to_string()),
KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()),
];
let span = self
.tracer
.span_builder("parse")
.with_kind(SpanKind::Server)
.with_attributes(attributes)
.start(&*self.tracer);
next.parse_query(ctx, query, variables)
.with_context(OpenTelemetryContext::current_with_span(span))
.await
}
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) {
self.exit_context(PARSE_CTX);
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
let span = self
.tracer
.span_builder("validation")
.with_kind(SpanKind::Server)
.start(&*self.tracer);
next.validation(ctx)
.with_context(OpenTelemetryContext::current_with_span(span))
.map_ok(|res| {
let current_cx = OpenTelemetryContext::current();
let span = current_cx.span();
span.set_attribute(KEY_COMPLEXITY.i64(res.complexity as i64));
span.set_attribute(KEY_DEPTH.i64(res.depth as i64));
res
})
.await
}
fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) {
if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() {
let span = self
.tracer
.span_builder("validation")
.with_kind(SpanKind::Server)
.with_parent_context(parent_cx)
.start(&*self.tracer);
let validation_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(VALIDATION_CTX, validation_cx);
}
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
let span = self
.tracer
.span_builder("execute")
.with_kind(SpanKind::Server)
.start(&*self.tracer);
next.execute(ctx)
.with_context(OpenTelemetryContext::current_with_span(span))
.await
}
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) {
if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) {
let span = validation_cx.span();
span.set_attribute(KEY_COMPLEXITY.i64(result.complexity as i64));
span.set_attribute(KEY_DEPTH.i64(result.depth as i64));
}
}
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) {
let span = match self.contexts.get(&REQUEST_CTX).cloned() {
Some(parent_cx) => self
.tracer
.span_builder("execute")
.with_kind(SpanKind::Server)
.with_parent_context(parent_cx)
.start(&*self.tracer),
None => self
.tracer
.span_builder("execute")
.with_kind(SpanKind::Server)
.start(&*self.tracer),
};
let execute_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(EXECUTE_CTX, execute_cx);
}
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_context(EXECUTE_CTX);
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
let parent_cx = match info.resolve_id.parent {
Some(parent_id) if parent_id > 0 => self.contexts.get(&resolve_ctx_id(parent_id)),
_ => self.contexts.get(&EXECUTE_CTX),
}
.cloned();
if let Some(parent_cx) = parent_cx {
let attributes = vec![
KEY_RESOLVE_ID.i64(info.resolve_id.current as i64),
KEY_PARENT_TYPE.string(info.parent_type.to_string()),
KEY_RETURN_TYPE.string(info.return_type.to_string()),
];
let span = self
.tracer
.span_builder(&info.path_node.to_string())
.with_kind(SpanKind::Server)
.with_parent_context(parent_cx)
.with_attributes(attributes)
.start(&*self.tracer);
let resolve_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx);
}
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.exit_context(resolve_ctx_id(info.resolve_id.current));
}
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() {
parent_cx
.span()
.add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]);
}
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
let attributes = vec![
KEY_PARENT_TYPE.string(info.parent_type.to_string()),
KEY_RETURN_TYPE.string(info.return_type.to_string()),
];
let span = self
.tracer
.span_builder(&info.path_node.to_string())
.with_kind(SpanKind::Server)
.with_attributes(attributes)
.start(&*self.tracer);
next.resolve(ctx, info)
.with_context(OpenTelemetryContext::current_with_span(span))
.map_err(|err| {
let current_cx = OpenTelemetryContext::current();
current_cx
.span()
.add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]);
err
})
.await
}
}

View File

@ -1,36 +1,15 @@
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{span, Level, Span};
use futures_util::stream::BoxStream;
use tracing_futures::Instrument;
use tracinglib::{span, Level};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use crate::futures_util::TryFutureExt;
use crate::parser::types::ExecutableDocument;
use crate::{ServerError, ValidationResult, Variables};
/// Tracing extension configuration for each request.
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub struct TracingConfig {
/// Use a span as the parent node of the entire query.
parent: spin::Mutex<Option<Span>>,
}
impl TracingConfig {
/// Use a span as the parent of the entire query.
pub fn parent_span(mut self, span: Span) -> Self {
*self.parent.get_mut() = Some(span);
self
}
}
const REQUEST_CTX: usize = 0;
const PARSE_CTX: usize = 1;
const VALIDATION_CTX: usize = 2;
const EXECUTE_CTX: usize = 3;
#[inline]
fn resolve_span_id(resolver_id: usize) -> usize {
resolver_id + 10
}
use crate::{Response, ServerError, ServerResult, ValidationResult, Value, Variables};
/// Tracing extension
///
@ -42,7 +21,7 @@ fn resolve_span_id(resolver_id: usize) -> usize {
///
/// ```no_run
/// use async_graphql::*;
/// use async_graphql::extensions::{Tracing, TracingConfig};
/// use async_graphql::extensions::Tracing;
/// use tracing::{span, Level, Instrument};
///
/// #[derive(SimpleObject)]
@ -50,164 +29,112 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// value: i32,
/// }
///
/// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription).
/// extension(Tracing::default())
/// let schema = Schema::build(Query { value: 100 }, EmptyMutation, EmptySubscription)
/// .extension(Tracing)
/// .finish();
///
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// schema.execute(Request::new("{ value }")).await;
/// });
///
/// // tracing in parent span
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!(
/// parent: None,
/// Level::INFO,
/// "span root"
/// );
/// schema.execute(Request::new("{ value }")).instrument(root_span).await;
/// });
///
/// // replace root span
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!(
/// parent: None,
/// Level::INFO,
/// "span root"
/// );
/// schema.execute(Request::new("{ value }").data(TracingConfig::default().parent_span(root_span))).await;
/// });
/// ```
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub struct Tracing;
impl ExtensionFactory for Tracing {
fn create(&self) -> Box<dyn Extension> {
Box::new(TracingExtension::default())
fn create(&self) -> Arc<dyn Extension> {
Arc::new(TracingExtension::default())
}
}
#[derive(Default)]
struct TracingExtension {
spans: HashMap<usize, Span>,
}
impl TracingExtension {
fn enter_span(&mut self, id: usize, span: Span) -> &Span {
let _ = span.enter();
self.spans.insert(id, span);
self.spans.get(&id).unwrap()
}
fn exit_span(&mut self, id: usize) {
if let Some(span) = self.spans.remove(&id) {
let _ = span.enter();
}
}
}
struct TracingExtension;
#[async_trait::async_trait]
impl Extension for TracingExtension {
fn start(&mut self, ctx: &ExtensionContext<'_>) {
let request_span = ctx
.data_opt::<TracingConfig>()
.and_then(|cfg| cfg.parent.lock().take())
.unwrap_or_else(|| {
span!(
target: "async_graphql::graphql",
Level::INFO,
"request",
)
});
self.enter_span(REQUEST_CTX, request_span);
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
next.request(ctx)
.instrument(span!(
target: "async_graphql::graphql",
Level::INFO,
"request",
))
.await
}
fn end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_span(REQUEST_CTX);
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
Box::pin(next.subscribe(ctx, stream).instrument(span!(
target: "async_graphql::graphql",
Level::INFO,
"subscribe",
)))
}
fn parse_start(
&mut self,
_ctx: &ExtensionContext<'_>,
query_source: &str,
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query: &str,
variables: &Variables,
) {
if let Some(parent) = self.spans.get(&REQUEST_CTX) {
let variables = serde_json::to_string(&variables).unwrap();
let parse_span = span!(
target: "async_graphql::graphql",
parent: parent,
Level::INFO,
"parse",
source = query_source,
variables = %variables,
);
self.enter_span(PARSE_CTX, parse_span);
}
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
let span = span!(
target: "async_graphql::graphql",
Level::INFO,
"parse",
source = query,
variables = %serde_json::to_string(&variables).unwrap(),
);
next.parse_query(ctx, query, variables)
.instrument(span)
.await
}
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) {
self.exit_span(PARSE_CTX);
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
let span = span!(
target: "async_graphql::graphql",
Level::INFO,
"validation"
);
next.validation(ctx).instrument(span).await
}
fn validation_start(&mut self, _ctx: &ExtensionContext<'_>) {
if let Some(parent) = self.spans.get(&REQUEST_CTX) {
let span = span!(
target: "async_graphql::graphql",
parent: parent,
Level::INFO,
"validation"
);
self.enter_span(VALIDATION_CTX, span);
}
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
let span = span!(
target: "async_graphql::graphql",
Level::INFO,
"execute"
);
next.execute(ctx).instrument(span).await
}
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, _result: &ValidationResult) {
self.exit_span(VALIDATION_CTX);
}
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) {
if let Some(parent) = self.spans.get(&REQUEST_CTX) {
let span = span!(
target: "async_graphql::graphql",
parent: parent,
Level::INFO,
"execute"
);
self.enter_span(EXECUTE_CTX, span);
};
}
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_span(EXECUTE_CTX);
}
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
let parent = match info.resolve_id.parent {
Some(parent_id) if parent_id > 0 => self.spans.get(&resolve_span_id(parent_id)),
_ => self.spans.get(&EXECUTE_CTX),
};
if let Some(parent) = parent {
let span = span!(
target: "async_graphql::graphql",
parent: parent,
Level::INFO,
"field",
id = %info.resolve_id.current,
path = %info.path_node,
parent_type = %info.parent_type,
return_type = %info.return_type,
);
self.enter_span(resolve_span_id(info.resolve_id.current), span);
}
}
fn resolve_end(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.exit_span(resolve_span_id(info.resolve_id.current));
}
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
tracing::error!(target: "async_graphql::graphql", error = %err.message);
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<Value>> {
let span = span!(
target: "async_graphql::graphql",
Level::INFO,
"field",
path = %info.path_node,
parent_type = %info.parent_type,
return_type = %info.return_type,
);
next.resolve(ctx, info)
.instrument(span)
.map_err(|err| {
tracinglib::error!(target: "async_graphql::graphql", error = %err.message);
err
})
.await
}
}

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use crate::extensions::{ErrorLogger, ResolveInfo};
use crate::extensions::ResolveInfo;
use crate::parser::types::Selection;
use crate::registry::MetaType;
use crate::{
@ -174,18 +174,18 @@ impl<'a> Fields<'a> {
async move {
let ctx_field = ctx.with_field(field);
let field_name = ctx_field.item.node.response_key().node.clone();
let extensions = &ctx.query_env.extensions;
let res = if ctx_field.query_env.extensions.is_empty() {
if extensions.is_empty() {
match root.resolve_field(&ctx_field).await {
Ok(value) => Ok((field_name, value.unwrap_or_default())),
Err(e) => {
Err(e.path(PathSegment::Field(field_name.to_string())))
}
}?
}
} else {
let type_name = T::type_name();
let resolve_info = ResolveInfo {
resolve_id: ctx_field.resolve_id,
path_node: ctx_field.path_node.as_ref().unwrap(),
parent_type: &type_name,
return_type: match ctx_field
@ -210,20 +210,16 @@ impl<'a> Fields<'a> {
},
};
ctx_field.query_env.extensions.resolve_start(&resolve_info);
let res = match root.resolve_field(&ctx_field).await {
let resolve_fut = async { root.resolve_field(&ctx_field).await };
futures_util::pin_mut!(resolve_fut);
let res = extensions.resolve(resolve_info, &mut resolve_fut).await;
match res {
Ok(value) => Ok((field_name, value.unwrap_or_default())),
Err(e) => {
Err(e.path(PathSegment::Field(field_name.to_string())))
}
}
.log_error(&ctx_field.query_env.extensions)?;
ctx_field.query_env.extensions.resolve_end(&resolve_info);
res
};
Ok(res)
}
}
}));
}

View File

@ -1,4 +1,4 @@
use crate::extensions::{ErrorLogger, ResolveInfo};
use crate::extensions::ResolveInfo;
use crate::parser::types::Field;
use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value};
@ -9,38 +9,49 @@ pub async fn resolve_list<'a, T: OutputType + 'a>(
iter: impl IntoIterator<Item = T>,
len: Option<usize>,
) -> ServerResult<Value> {
let mut futures = len.map(Vec::with_capacity).unwrap_or_default();
let extensions = &ctx.query_env.extensions;
for (idx, item) in iter.into_iter().enumerate() {
let ctx_idx = ctx.with_index(idx);
futures.push(async move {
if ctx_idx.query_env.extensions.is_empty() {
if extensions.is_empty() {
let mut futures = len.map(Vec::with_capacity).unwrap_or_default();
for (idx, item) in iter.into_iter().enumerate() {
futures.push({
let ctx = ctx.clone();
async move {
let ctx_idx = ctx.with_index(idx);
let resolve_info = ResolveInfo {
path_node: ctx_idx.path_node.as_ref().unwrap(),
parent_type: &Vec::<T>::type_name(),
return_type: &T::qualified_type_name(),
};
let resolve_fut = async {
OutputType::resolve(&item, &ctx_idx, field)
.await
.map(Option::Some)
.map_err(|e| e.path(PathSegment::Index(idx)))
};
futures_util::pin_mut!(resolve_fut);
extensions
.resolve(resolve_info, &mut resolve_fut)
.await
.map(|value| value.expect("You definitely encountered a bug!"))
}
});
}
Ok(Value::List(
futures_util::future::try_join_all(futures).await?,
))
} else {
let mut futures = len.map(Vec::with_capacity).unwrap_or_default();
for (idx, item) in iter.into_iter().enumerate() {
let ctx_idx = ctx.with_index(idx);
futures.push(async move {
OutputType::resolve(&item, &ctx_idx, field)
.await
.map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_idx.query_env.extensions)
} else {
let resolve_info = ResolveInfo {
resolve_id: ctx_idx.resolve_id,
path_node: ctx_idx.path_node.as_ref().unwrap(),
parent_type: &Vec::<T>::type_name(),
return_type: &T::qualified_type_name(),
};
ctx_idx.query_env.extensions.resolve_start(&resolve_info);
let res = OutputType::resolve(&item, &ctx_idx, field)
.await
.map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_idx.query_env.extensions)?;
ctx_idx.query_env.extensions.resolve_end(&resolve_info);
Ok(res)
}
});
});
}
Ok(Value::List(
futures_util::future::try_join_all(futures).await?,
))
}
Ok(Value::List(
futures_util::future::try_join_all(futures).await?,
))
}

View File

@ -1,5 +1,6 @@
use http::header::HeaderMap;
use std::collections::BTreeMap;
use http::header::HeaderMap;
use serde::{Deserialize, Serialize};
use crate::{CacheControl, Result, ServerError, Value};
@ -12,8 +13,8 @@ pub struct Response {
pub data: Value,
/// Extensions result
#[serde(skip_serializing_if = "Option::is_none", default)]
pub extensions: Option<Value>,
#[serde(skip_serializing_if = "BTreeMap::is_empty", default)]
pub extensions: BTreeMap<String, Value>,
/// Cache control value
#[serde(skip)]
@ -47,10 +48,11 @@ impl Response {
}
}
/// Set the extensions result of the response.
/// Set the extension result of the response.
#[must_use]
pub fn extensions(self, extensions: Option<Value>) -> Self {
Self { extensions, ..self }
pub fn extension(mut self, name: impl Into<String>, value: Value) -> Self {
self.extensions.insert(name.into(), value);
self
}
/// Set the http headers of the response.
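
Not part of the diff, just a usage sketch of the per-key builder added above; the key names are made up. `value!` is the crate's own macro, as used in the tests below.

```rust
use async_graphql::Response;

fn main() {
    // Extension outputs are now named entries in a map instead of one opaque value.
    let resp = Response::new(async_graphql::value!({ "hello": "world" }))
        .extension("traceId", async_graphql::value!("00-abc-01"))
        .extension("complexity", async_graphql::value!(5));

    assert_eq!(resp.extensions.len(), 2);
}
```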

View File

@ -1,14 +1,12 @@
use std::any::Any;
use std::collections::BTreeMap;
use std::ops::Deref;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
use futures_util::stream::{self, Stream, StreamExt};
use indexmap::map::IndexMap;
use crate::context::{Data, QueryEnvInner, ResolveId};
use crate::extensions::{ErrorLogger, ExtensionFactory, Extensions};
use crate::context::{Data, QueryEnvInner};
use crate::extensions::{ExtensionFactory, Extensions};
use crate::model::__DirectiveLocation;
use crate::parser::parse_query;
use crate::parser::types::{DocumentOperations, OperationType};
@ -19,7 +17,7 @@ use crate::types::QueryRoot;
use crate::validation::{check_rules, ValidationMode};
use crate::{
BatchRequest, BatchResponse, CacheControl, ContextBase, ObjectType, QueryEnv, Request,
Response, ServerError, SubscriptionType, Type, Value, ID,
Response, ServerError, SubscriptionType, Type, ID,
};
/// Schema builder
@ -354,17 +352,11 @@ where
}
fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
let extensions = Extensions::new(
self.0
.extensions
.iter()
.map(|factory| factory.create())
.collect::<Vec<_>>(),
Extensions::new(
self.extensions.iter().map(|f| f.create()),
self.env.clone(),
session_data,
);
extensions.start();
extensions
)
}
async fn prepare_request(
@ -376,36 +368,41 @@ where
let mut request = request;
let query_data = Arc::new(std::mem::take(&mut request.data));
extensions.attach_query_data(query_data.clone());
let request = extensions.prepare_request(request).await?;
extensions.parse_start(&request.query, &request.variables);
let document = parse_query(&request.query)
.map_err(Into::<ServerError>::into)
.log_error(&extensions)?;
extensions.parse_end(&document);
let request = extensions.prepare_request(request).await?;
let document = {
let query = &request.query;
let fut_parse = async { parse_query(&query).map_err(Into::<ServerError>::into) };
futures_util::pin_mut!(fut_parse);
extensions
.parse_query(&query, &request.variables, &mut fut_parse)
.await?
};
// check rules
extensions.validation_start();
let validation_result = check_rules(
&self.env.registry,
&document,
Some(&request.variables),
self.validation_mode,
)
.log_error(&extensions)?;
extensions.validation_end(&validation_result);
let validation_result = {
let validation_fut = async {
check_rules(
&self.env.registry,
&document,
Some(&request.variables),
self.validation_mode,
)
};
futures_util::pin_mut!(validation_fut);
extensions.validation(&mut validation_fut).await?
};
// check limit
if let Some(limit_complexity) = self.complexity {
if validation_result.complexity > limit_complexity {
return Err(vec![ServerError::new("Query is too complex.")]).log_error(&extensions);
return Err(vec![ServerError::new("Query is too complex.")]);
}
}
if let Some(limit_depth) = self.depth {
if validation_result.depth > limit_depth {
return Err(vec![ServerError::new("Query is nested too deep.")])
.log_error(&extensions);
return Err(vec![ServerError::new("Query is nested too deep.")]);
}
}
@ -432,10 +429,7 @@ where
};
let operation = match operation {
Ok(operation) => operation,
Err(e) => {
extensions.error(&e);
return Err(vec![e]);
}
Err(e) => return Err(vec![e]),
};
let env = QueryEnvInner {
@ -454,49 +448,58 @@ where
async fn execute_once(&self, env: QueryEnv) -> Response {
// execute
let inc_resolve_id = AtomicUsize::default();
let ctx = ContextBase {
path_node: None,
resolve_id: ResolveId::root(),
inc_resolve_id: &inc_resolve_id,
item: &env.operation.node.selection_set,
schema_env: &self.env,
query_env: &env,
};
env.extensions.execution_start();
let data = match &env.operation.node.ty {
let res = match &env.operation.node.ty {
OperationType::Query => resolve_container(&ctx, &self.query).await,
OperationType::Mutation => resolve_container_serial(&ctx, &self.mutation).await,
OperationType::Subscription => {
return Response::from_errors(vec![ServerError::new(
"Subscriptions are not supported on this transport.",
)])
)]);
}
};
env.extensions.execution_end();
let extensions = env.extensions.result();
match data {
Ok(data) => Response::new(data),
Err(e) => Response::from_errors(vec![e]),
match res {
Ok(data) => {
let resp = Response::new(data);
resp.http_headers(std::mem::take(&mut *env.http_headers.lock()))
}
Err(err) => Response::from_errors(vec![err]),
}
.extensions(extensions)
.http_headers(std::mem::take(&mut *env.http_headers.lock()))
}
/// Execute a GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into();
let extensions = self.create_extensions(Default::default());
match self
.prepare_request(extensions, request, Default::default())
.await
{
Ok((env, cache_control)) => self.execute_once(env).await.cache_control(cache_control),
Err(errors) => Response::from_errors(errors),
}
let request_fut = {
let extensions = extensions.clone();
async move {
match self
.prepare_request(extensions, request, Default::default())
.await
{
Ok((env, cache_control)) => {
let fut = async {
self.execute_once(env.clone())
.await
.cache_control(cache_control)
};
futures_util::pin_mut!(fut);
env.extensions.execute(&mut fut).await
}
Err(errors) => Response::from_errors(errors),
}
}
};
futures_util::pin_mut!(request_fut);
extensions.request(&mut request_fut).await
}
/// Execute a GraphQL batch query.
@ -518,68 +521,52 @@ where
&self,
request: impl Into<Request> + Send,
session_data: Arc<Data>,
) -> impl Stream<Item = Response> + Send {
) -> impl Stream<Item = Response> + Send + Unpin {
let schema = self.clone();
let request = request.into();
let extensions = self.create_extensions(session_data.clone());
async_stream::stream! {
let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await {
Ok(res) => res,
Err(errors) => {
yield Response::from_errors(errors);
let stream = futures_util::stream::StreamExt::boxed({
let extensions = extensions.clone();
async_stream::stream! {
let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await {
Ok(res) => res,
Err(errors) => {
yield Response::from_errors(errors);
return;
}
};
if env.operation.node.ty != OperationType::Subscription {
yield schema.execute_once(env).await.cache_control(cache_control);
return;
}
};
if env.operation.node.ty != OperationType::Subscription {
yield schema
.execute_once(env)
.await
.cache_control(cache_control);
return;
}
let ctx = env.create_context(
&schema.env,
None,
&env.operation.node.selection_set,
);
let resolve_id = AtomicUsize::default();
let ctx = env.create_context(
&schema.env,
None,
&env.operation.node.selection_set,
ResolveId::root(),
&resolve_id,
);
let mut streams = Vec::new();
if let Err(err) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
yield Response::from_errors(vec![err]);
}
let mut streams = Vec::new();
if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
env.extensions.execution_end();
yield Response::from_errors(vec![e]);
return;
}
let mut stream = stream::select_all(streams);
while let Some(data) = stream.next().await {
let is_err = data.is_err();
let extensions = env.extensions.result();
yield match data {
Ok((name, value)) => {
let mut map = BTreeMap::new();
map.insert(name, value);
Response::new(Value::Object(map))
},
Err(e) => Response::from_errors(vec![e]),
}.extensions(extensions);
if is_err {
break;
let mut stream = stream::select_all(streams);
while let Some(resp) = stream.next().await {
yield resp;
}
}
}
});
extensions.subscribe(stream)
}
/// Execute a GraphQL subscription.
pub fn execute_stream(
&self,
request: impl Into<Request>,
) -> impl Stream<Item = Response> + Send {
) -> impl Stream<Item = Response> + Send + Unpin {
self.execute_stream_with_session_data(request.into(), Default::default())
}
}
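
Because `execute_stream` now returns an `Unpin` stream, callers can poll it directly without `.boxed()` (see the test changes below). A small end-to-end sketch, not part of the diff, assuming a Tokio runtime; the schema itself is made up:

```rust
use async_graphql::{EmptyMutation, Object, Schema, Subscription};
use futures_util::{Stream, StreamExt};

struct Query;

#[Object]
impl Query {
    async fn value(&self) -> i32 {
        1
    }
}

struct SubscriptionRoot;

#[Subscription]
impl SubscriptionRoot {
    async fn values(&self) -> impl Stream<Item = i32> {
        futures_util::stream::iter(0..3)
    }
}

#[tokio::main]
async fn main() {
    let schema = Schema::new(Query, EmptyMutation, SubscriptionRoot);

    // The returned stream is Unpin, so it can be polled in place.
    let mut stream = schema.execute_stream("subscription { values }");
    while let Some(resp) = stream.next().await {
        println!("{}", resp.data);
    }
}
```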

View File

@ -3,9 +3,7 @@ use std::pin::Pin;
use futures_util::stream::{Stream, StreamExt};
use crate::parser::types::{Selection, TypeCondition};
use crate::{
Context, ContextSelectionSet, Name, PathSegment, ServerError, ServerResult, Type, Value,
};
use crate::{Context, ContextSelectionSet, PathSegment, Response, ServerError, ServerResult, Type};
/// A GraphQL subscription object
pub trait SubscriptionType: Type + Send + Sync {
@ -19,10 +17,10 @@ pub trait SubscriptionType: Type + Send + Sync {
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>>;
) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>>;
}
type BoxFieldStream<'a> = Pin<Box<dyn Stream<Item = ServerResult<(Name, Value)>> + 'a + Send>>;
type BoxFieldStream<'a> = Pin<Box<dyn Stream<Item = Response> + 'a + Send>>;
pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>(
ctx: &ContextSelectionSet<'a>,
@ -41,16 +39,14 @@ pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>(
let field_name = ctx.item.node.response_key().node.clone();
let stream = root.create_field_stream(&ctx);
if let Some(mut stream) = stream {
while let Some(item) = stream.next().await {
yield match item {
Ok(value) => Ok((field_name.to_owned(), value)),
Err(e) => Err(e.path(PathSegment::Field(field_name.to_string()))),
};
while let Some(resp) = stream.next().await {
yield resp;
}
} else {
yield Err(ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name()))
let err = ServerError::new(format!(r#"Cannot query field "{}" on type "{}"."#, field_name, T::type_name()))
.at(ctx.item.pos)
.path(PathSegment::Field(field_name.to_string())));
.path(PathSegment::Field(field_name.to_string()));
yield Response::from_errors(vec![err]);
}
}
})),
@ -98,7 +94,7 @@ impl<T: SubscriptionType> SubscriptionType for &T {
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>> {
) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>> {
T::create_field_stream(*self, ctx)
}
}

View File

@ -3,7 +3,7 @@ use std::pin::Pin;
use futures_util::stream::{self, Stream};
use crate::{registry, Context, ServerError, ServerResult, SubscriptionType, Type, Value};
use crate::{registry, Context, Response, ServerError, SubscriptionType, Type};
/// Empty subscription
///
@ -37,12 +37,13 @@ impl SubscriptionType for EmptySubscription {
fn create_field_stream<'a>(
&'a self,
ctx: &'a Context<'_>,
) -> Option<Pin<Box<dyn Stream<Item = ServerResult<Value>> + Send + 'a>>>
) -> Option<Pin<Box<dyn Stream<Item = Response> + Send + 'a>>>
where
Self: Send + Sync + 'static + Sized,
{
Some(Box::pin(stream::once(async move {
Err(ServerError::new("Schema is not configured for subscriptions.").at(ctx.item.pos))
let err = ServerError::new("Schema is not configured for subscriptions.").at(ctx.item.pos);
Response::from_errors(vec![err])
})))
}
}

View File

@ -16,6 +16,7 @@ pub use visitor::VisitorContext;
use visitor::{visit, VisitorNil};
/// Validation results.
#[derive(Debug, Copy, Clone)]
pub struct ValidationResult {
/// Cache control
pub cache_control: CacheControl,

View File

@ -1,12 +1,15 @@
use std::sync::Arc;
use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use async_graphql::extensions::{
Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo,
};
use async_graphql::futures_util::stream::BoxStream;
use async_graphql::parser::types::ExecutableDocument;
use async_graphql::*;
use async_graphql_value::ConstValue;
use futures_util::lock::Mutex;
use futures_util::stream::Stream;
use futures_util::StreamExt;
use spin::Mutex;
#[tokio::test]
pub async fn test_extension_ctx() {
@ -18,7 +21,7 @@ pub async fn test_extension_ctx() {
#[Object]
impl Query {
async fn value(&self, ctx: &Context<'_>) -> i32 {
*ctx.data_unchecked::<MyData>().0.lock()
*ctx.data_unchecked::<MyData>().0.lock().await
}
}
@ -27,7 +30,7 @@ pub async fn test_extension_ctx() {
#[Subscription]
impl Subscription {
async fn value(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let data = *ctx.data_unchecked::<MyData>().0.lock();
let data = *ctx.data_unchecked::<MyData>().0.lock().await;
futures_util::stream::once(async move { data })
}
}
@ -36,23 +39,25 @@ pub async fn test_extension_ctx() {
#[async_trait::async_trait]
impl Extension for MyExtensionImpl {
fn parse_start(
&mut self,
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
_query_source: &str,
_variables: &Variables,
) {
query: &str,
variables: &Variables,
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
if let Ok(data) = ctx.data::<MyData>() {
*data.0.lock() = 100;
*data.0.lock().await = 100;
}
next.parse_query(ctx, query, variables).await
}
}
struct MyExtension;
impl ExtensionFactory for MyExtension {
fn create(&self) -> Box<dyn Extension> {
Box::new(MyExtensionImpl)
fn create(&self) -> Arc<dyn Extension> {
Arc::new(MyExtensionImpl)
}
}
@ -104,12 +109,10 @@ pub async fn test_extension_ctx() {
let mut data = Data::default();
data.insert(MyData::default());
let mut stream = schema
.execute_stream_with_session_data(
Request::new("subscription { value }"),
Arc::new(data),
)
.boxed();
let mut stream = schema.execute_stream_with_session_data(
Request::new("subscription { value }"),
Arc::new(data),
);
assert_eq!(
stream.next().await.unwrap().into_result().unwrap().data,
value! ({
@ -128,67 +131,83 @@ pub async fn test_extension_call_order() {
#[async_trait::async_trait]
#[allow(unused_variables)]
impl Extension for MyExtensionImpl {
fn name(&self) -> Option<&'static str> {
Some("test")
async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
self.calls.lock().await.push("request_start");
let res = next.request(ctx).await;
self.calls.lock().await.push("request_end");
res
}
fn start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("start");
}
fn end(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("end");
fn subscribe<'s>(
&self,
ctx: &ExtensionContext<'_>,
mut stream: BoxStream<'s, Response>,
next: NextExtension<'_>,
) -> BoxStream<'s, Response> {
let calls = self.calls.clone();
let stream = async_stream::stream! {
calls.lock().await.push("subscribe_start");
while let Some(item) = stream.next().await {
yield item;
}
calls.lock().await.push("subscribe_end");
};
Box::pin(stream)
}
async fn prepare_request(
&mut self,
&self,
ctx: &ExtensionContext<'_>,
request: Request,
next: NextExtension<'_>,
) -> ServerResult<Request> {
self.calls.lock().push("prepare_request");
Ok(request)
self.calls.lock().await.push("prepare_request_start");
let res = next.prepare_request(ctx, request).await;
self.calls.lock().await.push("prepare_request_end");
res
}
fn parse_start(
&mut self,
async fn parse_query(
&self,
ctx: &ExtensionContext<'_>,
query_source: &str,
query: &str,
variables: &Variables,
) {
self.calls.lock().push("parse_start");
next: NextExtension<'_>,
) -> ServerResult<ExecutableDocument> {
self.calls.lock().await.push("parse_query_start");
let res = next.parse_query(ctx, query, variables).await;
self.calls.lock().await.push("parse_query_end");
res
}
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {
self.calls.lock().push("parse_end");
async fn validation(
&self,
ctx: &ExtensionContext<'_>,
next: NextExtension<'_>,
) -> Result<ValidationResult, Vec<ServerError>> {
self.calls.lock().await.push("validation_start");
let res = next.validation(ctx).await;
self.calls.lock().await.push("validation_end");
res
}
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("validation_start");
async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response {
self.calls.lock().await.push("execute_start");
let res = next.execute(ctx).await;
self.calls.lock().await.push("execute_end");
res
}
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {
self.calls.lock().push("validation_end");
}
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("execution_start");
}
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("execution_end");
}
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_start");
}
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_end");
}
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<ConstValue> {
self.calls.lock().push("result");
None
async fn resolve(
&self,
ctx: &ExtensionContext<'_>,
info: ResolveInfo<'_>,
next: NextExtension<'_>,
) -> ServerResult<Option<ConstValue>> {
self.calls.lock().await.push("resolve_start");
let res = next.resolve(ctx, info).await;
self.calls.lock().await.push("resolve_end");
res
}
}
@ -197,8 +216,8 @@ pub async fn test_extension_call_order() {
}
impl ExtensionFactory for MyExtension {
fn create(&self) -> Box<dyn Extension> {
Box::new(MyExtensionImpl {
fn create(&self) -> Arc<dyn Extension> {
Arc::new(MyExtensionImpl {
calls: self.calls.clone(),
})
}
@ -238,24 +257,24 @@ pub async fn test_extension_call_order() {
.await
.into_result()
.unwrap();
let calls = calls.lock();
let calls = calls.lock().await;
assert_eq!(
&*calls,
&vec![
"start",
"prepare_request",
"parse_start",
"parse_end",
"request_start",
"prepare_request_start",
"prepare_request_end",
"parse_query_start",
"parse_query_end",
"validation_start",
"validation_end",
"execution_start",
"execute_start",
"resolve_start",
"resolve_end",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"end",
"execute_end",
"request_end",
]
);
}
@ -267,34 +286,36 @@ pub async fn test_extension_call_order() {
calls: calls.clone(),
})
.finish();
let mut stream = schema.execute_stream("subscription { value }").boxed();
let mut stream = schema.execute_stream("subscription { value }");
while let Some(_) = stream.next().await {}
let calls = calls.lock();
let calls = calls.lock().await;
assert_eq!(
&*calls,
&vec![
"start",
"prepare_request",
"parse_start",
"parse_end",
"subscribe_start",
"prepare_request_start",
"prepare_request_end",
"parse_query_start",
"parse_query_end",
"validation_start",
"validation_end",
"execution_start",
// push 1
"execute_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"execution_start",
"execute_end",
// push 2
"execute_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"execution_start",
"execute_end",
// push 3
"execute_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"end",
"execute_end",
// end
"subscribe_end",
]
);
}

View File

@ -119,7 +119,7 @@ pub async fn test_field_features() {
}]
);
let mut stream = schema.execute_stream("subscription { values }").boxed();
let mut stream = schema.execute_stream("subscription { values }");
assert_eq!(
stream
.next()
@ -131,7 +131,7 @@ pub async fn test_field_features() {
})
);
let mut stream = schema.execute_stream("subscription { valuesBson }").boxed();
let mut stream = schema.execute_stream("subscription { valuesBson }");
assert_eq!(
stream.next().await.map(|resp| resp.data).unwrap(),
value!({
@ -142,7 +142,6 @@ pub async fn test_field_features() {
assert_eq!(
schema
.execute_stream("subscription { valuesAbc }")
.boxed()
.next()
.await
.unwrap()

View File

@ -284,8 +284,7 @@ pub async fn test_generic_subscription() {
{
let mut stream = schema
.execute_stream("subscription { values }")
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
for i in 1..=2 {
assert_eq!(value!({ "values": i }), stream.next().await.unwrap());
}

View File

@ -117,7 +117,6 @@ pub async fn test_guard_simple_rule() {
assert_eq!(
schema
.execute_stream(Request::new("subscription { values }").data(Role::Guest))
.boxed()
.next()
.await
.unwrap()

View File

@ -196,8 +196,7 @@ pub async fn test_merged_subscription() {
{
let mut stream = schema
.execute_stream("subscription { events1 }")
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
for i in 0i32..10 {
assert_eq!(
value!({
@ -212,8 +211,7 @@ pub async fn test_merged_subscription() {
{
let mut stream = schema
.execute_stream("subscription { events2 }")
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
for i in 10i32..20 {
assert_eq!(
value!({

View File

@ -66,8 +66,7 @@ pub async fn test_input_value_custom_error() {
let mut stream = schema
.execute_stream("subscription { type }")
.map(|resp| resp.into_result())
.map_ok(|resp| resp.data)
.boxed();
.map_ok(|resp| resp.data);
for i in 0..10 {
assert_eq!(value!({ "type": i }), stream.next().await.unwrap().unwrap());
}

View File

@ -127,7 +127,6 @@ pub async fn test_subscription() {
assert_eq!(
Schema::new(Query, EmptyMutation, Subscription)
.execute_stream("subscription { CREATE_OBJECT(objectid: 100) }")
.boxed()
.next()
.await
.unwrap()

View File

@ -36,8 +36,7 @@ pub async fn test_subscription() {
{
let mut stream = schema
.execute_stream("subscription { values(start: 10, end: 20) }")
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
for i in 10..20 {
assert_eq!(value!({ "values": i }), stream.next().await.unwrap());
}
@ -47,8 +46,7 @@ pub async fn test_subscription() {
{
let mut stream = schema
.execute_stream("subscription { events(start: 10, end: 20) { a b } }")
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
for i in 10..20 {
assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }),
@ -98,8 +96,7 @@ pub async fn test_subscription_with_ctx_data() {
{
let mut stream = schema
.execute_stream(Request::new("subscription { values objects { value } }").data(100i32))
.map(|resp| resp.data)
.boxed();
.map(|resp| resp.data);
assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap());
assert_eq!(
value!({ "objects": { "value": 100 } }),
@ -141,8 +138,7 @@ pub async fn test_subscription_with_token() {
.execute_stream(
Request::new("subscription { values }").data(Token("123456".to_string())),
)
.map(|resp| resp.into_result().unwrap().data)
.boxed();
.map(|resp| resp.into_result().unwrap().data);
assert_eq!(value!({ "values": 100 }), stream.next().await.unwrap());
assert!(stream.next().await.is_none());
}
@ -152,7 +148,6 @@ pub async fn test_subscription_with_token() {
.execute_stream(
Request::new("subscription { values }").data(Token("654321".to_string()))
)
.boxed()
.next()
.await
.unwrap()
@ -200,8 +195,7 @@ pub async fn test_subscription_inline_fragment() {
}
"#,
)
.map(|resp| resp.data)
.boxed();
.map(|resp| resp.data);
for i in 10..20 {
assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }),
@ -250,8 +244,7 @@ pub async fn test_subscription_fragment() {
}
"#,
)
.map(|resp| resp.data)
.boxed();
.map(|resp| resp.data);
for i in 10i32..20 {
assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }),
@ -301,8 +294,7 @@ pub async fn test_subscription_fragment2() {
}
"#,
)
.map(|resp| resp.data)
.boxed();
.map(|resp| resp.data);
for i in 10..20 {
assert_eq!(
value!({ "events": {"a": i, "b": i * 10} }),
@ -342,8 +334,7 @@ pub async fn test_subscription_error() {
let mut stream = schema
.execute_stream("subscription { events { value } }")
.map(|resp| resp.into_result())
.map_ok(|resp| resp.data)
.boxed();
.map_ok(|resp| resp.data);
for i in 0i32..5 {
assert_eq!(
value!({ "events": { "value": i } }),
@ -388,8 +379,7 @@ pub async fn test_subscription_fieldresult() {
let mut stream = schema
.execute_stream("subscription { values }")
.map(|resp| resp.into_result())
.map_ok(|resp| resp.data)
.boxed();
.map_ok(|resp| resp.data);
for i in 0i32..5 {
assert_eq!(
value!({ "values": i }),