diff --git a/Cargo.toml b/Cargo.toml index 929d2ca4..dcf80521 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -77,5 +77,4 @@ members = [ "integrations/rocket", "integrations/tide", "integrations/warp", - "benchmark", ] diff --git a/benchmark/Cargo.toml b/benchmark/Cargo.toml deleted file mode 100644 index f01ee33a..00000000 --- a/benchmark/Cargo.toml +++ /dev/null @@ -1,29 +0,0 @@ -[package] -name = "graphql-benchmark" -version = "2.6.5" -authors = ["sunli ", "Koxiaet"] -edition = "2018" - -[dependencies] -async-std = { version = "1.9.0", features = ["attributes"] } -futures = "0.3.13" -serde_json = "1.0.64" -async-graphql-parser = { path = "../parser" } -async-graphql = { path = ".." } -jemallocator = { version = "0.3.2", optional = true } - -[dev-dependencies] -criterion = "0.3.4" -simple = { path = "simple" } -chat = { path = "chat" } - -[features] -jemalloc = ["jemallocator"] - -[[bench]] -name = "simple" -harness = false - -[[bench]] -name = "chat" -harness = false diff --git a/benchmark/benches/chat.rs b/benchmark/benches/chat.rs deleted file mode 100644 index fca31a17..00000000 --- a/benchmark/benches/chat.rs +++ /dev/null @@ -1,13 +0,0 @@ -use chat::{Q, S}; -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use graphql_benchmark::{parse, run, serialize}; - -pub fn bench(c: &mut Criterion) { - c.bench_function("chat run", |b| b.iter(|| run(&S, black_box(Q)))); - c.bench_function("chat parse", |b| b.iter(|| parse(black_box(Q)))); - let res = run(&S, Q); - c.bench_function("chat serialize", |b| b.iter(|| serialize(black_box(&res)))); -} - -criterion_group!(chat, bench); -criterion_main!(chat); diff --git a/benchmark/benches/simple.rs b/benchmark/benches/simple.rs deleted file mode 100644 index 9a436f27..00000000 --- a/benchmark/benches/simple.rs +++ /dev/null @@ -1,15 +0,0 @@ -use criterion::{black_box, criterion_group, criterion_main, Criterion}; -use graphql_benchmark::{parse, run, serialize}; -use simple::{Q, S}; - -pub fn bench(c: &mut Criterion) { - c.bench_function("simple run", |b| b.iter(|| run(&S, black_box(Q)))); - c.bench_function("simple parse", |b| b.iter(|| parse(black_box(Q)))); - let res = run(&S, Q); - c.bench_function("simple serialize", |b| { - b.iter(|| serialize(black_box(&res))) - }); -} - -criterion_group!(simple, bench); -criterion_main!(simple); diff --git a/benchmark/chat/Cargo.toml b/benchmark/chat/Cargo.toml deleted file mode 100644 index 382bf3d8..00000000 --- a/benchmark/chat/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "chat" -version = "2.6.5" -authors = ["Ivan Plesskih "] -edition = "2018" - -[dependencies] -async-graphql = { path = "../.." } -lazy_static = "1.4.0" diff --git a/benchmark/chat/src/lib.rs b/benchmark/chat/src/lib.rs deleted file mode 100644 index 6965fc67..00000000 --- a/benchmark/chat/src/lib.rs +++ /dev/null @@ -1,249 +0,0 @@ -use async_graphql::*; - -pub struct ChatData { - pub id: String, - pub created_at: String, - pub title: String, - pub caption: String, - pub creator_user_id: String, - pub state: String, -} - -pub struct UserData { - pub id: String, - pub is_operator: bool, - pub phone: u64, - pub join_date: String, - - pub state: String, -} - -pub struct ProfileData { - pub first_name: String, - pub last_name: String, - pub city: Option, - pub job_title: Option, - pub email: String, -} - -pub struct MessageData { - pub id: String, - pub user_id: String, - pub timestamp: String, - pub edited: bool, - pub order: i32, - pub message: String, -} - -lazy_static::lazy_static! { - pub static ref CHAT: ChatData = ChatData { - id: "1".to_string(), - created_at: "today".to_string(), - title: "chat".to_string(), - caption: "asdasd".to_string(), - creator_user_id: "123".to_string(), - state: "ACTIVE".to_string(), - }; - - pub static ref USER: UserData = UserData { - id: "123".to_string(), - is_operator: false, - phone: 79_123_273_936, - join_date: "today".to_string(), - state: "ACTIVE".to_string(), - }; - - pub static ref PROFILE: ProfileData = ProfileData { - first_name: "Ivan".to_string(), - last_name: "Plesskih".to_string(), - city: Some("Che".to_string()), - job_title: Some("progr".to_string()), - email: "asd@qwe.ru".to_string(), - }; - - pub static ref MESSAGE: MessageData = MessageData { - id: "456".to_string(), - user_id: "123".to_string(), - timestamp: "today".to_string(), - edited: false, - order: 123, - message: "Hello, world!".to_string(), - }; -} - -pub struct Chat; - -#[Object] -impl Chat { - pub async fn id(&self) -> ID { - ID::from(&CHAT.id) - } - - pub async fn messages(&self) -> Vec { - let mut res = vec![]; - for _ in 0..30 { - res.push(Message); - } - res - } - - pub async fn users(&self) -> Vec { - let mut res = vec![]; - for _ in 0..5 { - res.push(User); - } - res - } - - pub async fn creator(&self) -> User { - User - } - - #[graphql(name = "created_at")] - pub async fn created_at(&self) -> &String { - &CHAT.created_at - } - pub async fn title(&self) -> &String { - &CHAT.title - } - pub async fn caption(&self) -> &String { - &CHAT.caption - } - pub async fn state(&self) -> &String { - &CHAT.state - } -} - -pub struct Message; - -#[Object] -impl Message { - pub async fn id(&self) -> ID { - ID::from(&MESSAGE.id) - } - - pub async fn user(&self) -> User { - User - } - pub async fn timestamp(&self) -> &String { - &MESSAGE.timestamp - } - pub async fn message(&self) -> &String { - &MESSAGE.message - } - pub async fn order(&self) -> i32 { - MESSAGE.order - } - pub async fn edited(&self) -> bool { - MESSAGE.edited - } -} - -pub struct User; - -#[Object] -impl User { - pub async fn id(&self) -> ID { - ID::from(&USER.id) - } - - pub async fn profile(&self) -> Option { - Some(UserProfile) - } - - #[graphql(name = "is_operator")] - pub async fn is_operator(&self) -> bool { - USER.is_operator - } - pub async fn phone(&self) -> String { - USER.phone.to_string() - } - #[graphql(name = "join_date")] - pub async fn join_date(&self) -> &String { - &USER.join_date - } - pub async fn state(&self) -> &String { - &USER.state - } -} - -pub struct UserProfile; - -#[Object] -impl UserProfile { - pub async fn email(&self) -> &String { - &PROFILE.email - } - #[graphql(name = "first_name")] - pub async fn first_name(&self) -> &String { - &PROFILE.first_name - } - #[graphql(name = "last_name")] - pub async fn last_name(&self) -> &String { - &PROFILE.last_name - } - #[graphql(name = "job_title")] - pub async fn job_title(&self) -> &Option { - &PROFILE.job_title - } - pub async fn city(&self) -> &Option { - &PROFILE.city - } -} - -pub struct Query; - -#[Object] -impl Query { - async fn chats(&self) -> Vec { - let mut res = vec![]; - for _ in 0..30 { - res.push(Chat); - } - res - } -} - -lazy_static::lazy_static! { - pub static ref S: Schema = Schema::new(Query, EmptyMutation, EmptySubscription); -} - -pub const Q: &str = r#" -fragment User on User { - id - is_operator - phone - join_date - state - profile { - email - first_name - last_name - job_title - city - } -} - -{ - chats { - id - created_at - title - caption - state - creator { - ...User - } - messages { - id - timestamp - edited - message - order - } - users { - ...User - } - } -} -"#; diff --git a/benchmark/simple/Cargo.toml b/benchmark/simple/Cargo.toml deleted file mode 100644 index d46f634a..00000000 --- a/benchmark/simple/Cargo.toml +++ /dev/null @@ -1,9 +0,0 @@ -[package] -name = "simple" -version = "2.6.5" -authors = ["Ivan Plesskih "] -edition = "2018" - -[dependencies] -async-graphql = { path = "../.." } -lazy_static = "1.4.0" diff --git a/benchmark/simple/src/lib.rs b/benchmark/simple/src/lib.rs deleted file mode 100644 index 2b4d2e9b..00000000 --- a/benchmark/simple/src/lib.rs +++ /dev/null @@ -1,70 +0,0 @@ -use async_graphql::*; - -pub struct QueryRoot; - -#[Object] -impl QueryRoot { - async fn value_i32(&self) -> i32 { - 999 - } - - async fn obj(&self) -> MyObj { - MyObj - } -} - -pub struct MyObj; - -#[Object] -impl MyObj { - async fn value_i32(&self) -> i32 { - 999 - } - - async fn value_list(&self) -> &[i32] { - &[1, 2, 3, 4, 5, 6, 7, 8, 9] - } - - async fn obj(&self) -> MyObj { - MyObj - } -} - -lazy_static::lazy_static! { - pub static ref S: Schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); - // static ref D: Document = parse_query(Q).unwrap(); -} - -pub const Q: &str = r#"{ - valueI32 obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList obj { - valueI32 valueList - } - } - } - } - } - } - } - } - } - } - } - } - } - } - } -}"#; diff --git a/benchmark/src/lib.rs b/benchmark/src/lib.rs deleted file mode 100644 index f722670c..00000000 --- a/benchmark/src/lib.rs +++ /dev/null @@ -1,35 +0,0 @@ -use async_graphql::{ObjectType, Response, Schema, SubscriptionType}; -use async_graphql_parser::{parse_query, types::ExecutableDocument}; -use async_std::task; - -#[cfg(feature = "jemalloc")] -#[global_allocator] -static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc; - -pub fn run( - s: &Schema, - q: &str, -) -> Response -where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - task::block_on(async { s.execute(q).await.into_result().unwrap() }) -} - -pub fn parse(q: &str) -> ExecutableDocument { - parse_query(q).unwrap() -} - -// pub fn validate() { -// check_rules(&S.env.registry, &D, S.validation_mode).unwrap(); -// } -// -// pub fn resolve() { -// do_resolve(...).unwrap(); -// } - -pub fn serialize(r: &async_graphql::Response) -> String { - serde_json::to_string(r).unwrap() -} diff --git a/derive/src/subscription.rs b/derive/src/subscription.rs index b5d93964..ec97f08a 100644 --- a/derive/src/subscription.rs +++ b/derive/src/subscription.rs @@ -360,12 +360,8 @@ pub fn generate( resolve_id, &inc_resolve_id, ); - let ctx_extension = #crate_name::extensions::ExtensionContext { - schema_data: &schema_env.data, - query_data: &query_env.ctx_data, - }; - query_env.extensions.execution_start(&ctx_extension); + query_env.extensions.execution_start(); #[allow(bare_trait_objects)] let ri = #crate_name::extensions::ResolveInfo { @@ -375,12 +371,10 @@ pub fn generate( return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(), }; - query_env.extensions.resolve_start(&ctx_extension, &ri); - + query_env.extensions.resolve_start(&ri); let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await; - - query_env.extensions.resolve_end(&ctx_extension, &ri); - query_env.extensions.execution_end(&ctx_extension); + query_env.extensions.resolve_end(&ri); + query_env.extensions.execution_end(); res } diff --git a/src/context.rs b/src/context.rs index 6cb7e80b..1de5b60a 100644 --- a/src/context.rs +++ b/src/context.rs @@ -243,6 +243,7 @@ pub struct QueryEnvInner { pub operation: Positioned, pub fragments: HashMap>, pub uploads: Vec, + pub session_data: Arc, pub ctx_data: Arc, pub http_headers: spin::Mutex>, } @@ -363,6 +364,7 @@ impl<'a, T> ContextBase<'a, T> { .ctx_data .0 .get(&TypeId::of::()) + .or_else(|| self.query_env.session_data.0.get(&TypeId::of::())) .or_else(|| self.schema_env.data.0.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) } diff --git a/src/extensions/apollo_tracing.rs b/src/extensions/apollo_tracing.rs index 5793c55d..d54990b6 100644 --- a/src/extensions/apollo_tracing.rs +++ b/src/extensions/apollo_tracing.rs @@ -6,7 +6,7 @@ use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; -use crate::{value, Value, Variables}; +use crate::{value, Value}; struct PendingResolve { path: Vec, @@ -79,13 +79,10 @@ impl Extension for ApolloTracingExtension { Some("tracing") } - fn parse_start( - &mut self, - _ctx: &ExtensionContext<'_>, - _query_source: &str, - _variables: &Variables, - ) { + fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { self.start_time = Utc::now(); + self.pending_resolves.clear(); + self.resolves.clear(); } fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 30282afe..f6799d33 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -14,10 +14,13 @@ mod tracing; use std::any::{Any, TypeId}; use std::collections::BTreeMap; +use std::sync::Arc; use crate::context::{QueryPathNode, ResolveId}; use crate::parser::types::ExecutableDocument; -use crate::{Data, Request, Result, ServerError, ServerResult, ValidationResult, Variables}; +use crate::{ + Data, Request, Result, SchemaEnv, ServerError, ServerResult, ValidationResult, Variables, +}; use crate::{Error, Name, Value}; pub use self::analyzer::Analyzer; @@ -26,9 +29,9 @@ pub use self::apollo_tracing::ApolloTracing; #[cfg(feature = "log")] pub use self::logger::Logger; #[cfg(feature = "opentelemetry")] -pub use self::opentelemetry::OpenTelemetry; +pub use self::opentelemetry::{OpenTelemetry, OpenTelemetryConfig}; #[cfg(feature = "tracing")] -pub use self::tracing::Tracing; +pub use self::tracing::{Tracing, TracingConfig}; pub(crate) type BoxExtension = Box; @@ -38,7 +41,10 @@ pub struct ExtensionContext<'a> { pub schema_data: &'a Data, #[doc(hidden)] - pub query_data: &'a Data, + pub session_data: &'a Data, + + #[doc(hidden)] + pub query_data: Option<&'a Data>, } impl<'a> ExtensionContext<'a> { @@ -71,7 +77,8 @@ impl<'a> ExtensionContext<'a> { /// Gets the global data defined in the `Context` or `Schema` or `None` if the specified type data does not exist. pub fn data_opt(&self) -> Option<&'a D> { self.query_data - .get(&TypeId::of::()) + .and_then(|query_data| query_data.get(&TypeId::of::())) + .or_else(|| self.session_data.get(&TypeId::of::())) .or_else(|| self.schema_data.get(&TypeId::of::())) .and_then(|d| d.downcast_ref::()) } @@ -94,6 +101,36 @@ pub struct ResolveInfo<'a> { } /// Represents a GraphQL extension +/// +/// # Call order for query and mutation +/// +/// - start +/// - prepare_request +/// - parse_start +/// - parse_end +/// - validation_start +/// - validation_end +/// - execution_start +/// - resolve_start +/// - resolve_end +/// - result +/// - execution_end +/// - end +/// +/// # Call order for subscription +/// +/// - start +/// - prepare_request +/// - parse_start +/// - parse_end +/// - validation_start +/// - validation_end +/// - execution_start +/// - resolve_start +/// - resolve_end +/// - execution_end +/// - result +/// ``` #[async_trait::async_trait] #[allow(unused_variables)] pub trait Extension: Sync + Send + 'static { @@ -102,7 +139,13 @@ pub trait Extension: Sync + Send + 'static { None } - /// Called at the prepare request + /// Called at the beginning of query. + fn start(&mut self, ctx: &ExtensionContext<'_>) {} + + /// Called at the beginning of query. + fn end(&mut self, ctx: &ExtensionContext<'_>) {} + + /// Called at prepare request. async fn prepare_request( &mut self, ctx: &ExtensionContext<'_>, @@ -111,7 +154,7 @@ pub trait Extension: Sync + Send + 'static { Ok(request) } - /// Called at the begin of the parse. + /// Called at the beginning of parse query source. fn parse_start( &mut self, ctx: &ExtensionContext<'_>, @@ -120,54 +163,54 @@ pub trait Extension: Sync + Send + 'static { ) { } - /// Called at the end of the parse. + /// Called at the end of parse query source. fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {} - /// Called at the begin of the validation. + /// Called at the beginning of the validation. fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {} /// Called at the end of the validation. fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {} - /// Called at the begin of the execution. + /// Called at the beginning of execute a query. fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {} - /// Called at the end of the execution. + /// Called at the end of execute a query. fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {} - /// Called at the begin of the resolve field. + /// Called at the beginning of resolve a field. fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} - /// Called at the end of the resolve field. + /// Called at the end of resolve a field. fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} /// Called when an error occurs. fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {} - /// Get the results + /// Get the results. fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option { None } } pub(crate) trait ErrorLogger { - fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self; + fn log_error(self, extensions: &Extensions) -> Self; } impl ErrorLogger for ServerResult { - fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self { + fn log_error(self, extensions: &Extensions) -> Self { if let Err(err) = &self { - extensions.error(ctx, err); + extensions.error(err); } self } } impl ErrorLogger for Result> { - fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self { + fn log_error(self, extensions: &Extensions) -> Self { if let Err(errors) = &self { for error in errors { - extensions.error(ctx, error); + extensions.error(error); } } self @@ -183,113 +226,160 @@ pub trait ExtensionFactory: Send + Sync + 'static { } #[doc(hidden)] -pub struct Extensions(Option>>); +pub struct Extensions { + extensions: Option>>, + schema_env: SchemaEnv, + session_data: Arc, + query_data: Option>, +} -impl From> for Extensions { - fn from(extensions: Vec) -> Self { - Self(if extensions.is_empty() { - None - } else { - Some(spin::Mutex::new(extensions)) - }) +#[doc(hidden)] +impl Extensions { + pub fn new( + extensions: Vec, + schema_env: SchemaEnv, + session_data: Arc, + ) -> Self { + Extensions { + extensions: if extensions.is_empty() { + None + } else { + Some(spin::Mutex::new(extensions)) + }, + schema_env, + session_data, + query_data: None, + } + } + + pub fn attach_query_data(&mut self, data: Arc) { + self.query_data = Some(data); + } +} + +impl Drop for Extensions { + fn drop(&mut self) { + self.end(); } } #[doc(hidden)] impl Extensions { - pub fn is_empty(&self) -> bool { - self.0.is_none() + #[inline] + fn context(&self) -> ExtensionContext<'_> { + ExtensionContext { + schema_data: &self.schema_env.data, + session_data: &self.session_data, + query_data: self.query_data.as_deref(), + } } - pub async fn prepare_request( - &mut self, - ctx: &ExtensionContext<'_>, - request: Request, - ) -> ServerResult { + pub fn is_empty(&self) -> bool { + self.extensions.is_none() + } + + pub fn start(&self) { + if let Some(e) = &self.extensions { + e.lock().iter_mut().for_each(|e| e.start(&self.context())); + } + } + + pub fn end(&self) { + if let Some(e) = &self.extensions { + e.lock().iter_mut().for_each(|e| e.end(&self.context())); + } + } + + pub async fn prepare_request(&self, request: Request) -> ServerResult { let mut request = request; - if let Some(e) = &mut self.0 { - for e in e.get_mut().iter_mut() { - request = e.prepare_request(ctx, request).await?; + if let Some(e) = &self.extensions { + for e in e.lock().iter_mut() { + request = e.prepare_request(&self.context(), request).await?; } } Ok(request) } - pub fn parse_start( - &mut self, - ctx: &ExtensionContext<'_>, - query_source: &str, - variables: &Variables, - ) { - if let Some(e) = &mut self.0 { - e.get_mut() - .iter_mut() - .for_each(|e| e.parse_start(ctx, query_source, variables)); - } - } - - pub fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { - if let Some(e) = &mut self.0 { - e.get_mut() - .iter_mut() - .for_each(|e| e.parse_end(ctx, document)); - } - } - - pub fn validation_start(&mut self, ctx: &ExtensionContext<'_>) { - if let Some(e) = &mut self.0 { - e.get_mut().iter_mut().for_each(|e| e.validation_start(ctx)); - } - } - - pub fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) { - if let Some(e) = &mut self.0 { - e.get_mut() - .iter_mut() - .for_each(|e| e.validation_end(ctx, result)); - } - } - - pub fn execution_start(&self, ctx: &ExtensionContext<'_>) { - if let Some(e) = &self.0 { - e.lock().iter_mut().for_each(|e| e.execution_start(ctx)); - } - } - - pub fn execution_end(&self, ctx: &ExtensionContext<'_>) { - if let Some(e) = &self.0 { - e.lock().iter_mut().for_each(|e| e.execution_end(ctx)); - } - } - - pub fn resolve_start(&self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { - if let Some(e) = &self.0 { - e.lock().iter_mut().for_each(|e| e.resolve_start(ctx, info)); - } - } - - pub fn resolve_end(&self, ctx: &ExtensionContext<'_>, resolve_id: &ResolveInfo<'_>) { - if let Some(e) = &self.0 { + pub fn parse_start(&self, query_source: &str, variables: &Variables) { + if let Some(e) = &self.extensions { e.lock() .iter_mut() - .for_each(|e| e.resolve_end(ctx, resolve_id)); + .for_each(|e| e.parse_start(&self.context(), query_source, variables)); } } - pub fn error(&self, ctx: &ExtensionContext<'_>, err: &ServerError) { - if let Some(e) = &self.0 { - e.lock().iter_mut().for_each(|e| e.error(ctx, err)); + pub fn parse_end(&self, document: &ExecutableDocument) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.parse_end(&self.context(), document)); } } - pub fn result(&self, ctx: &ExtensionContext<'_>) -> Option { - if let Some(e) = &self.0 { + pub fn validation_start(&self) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.validation_start(&self.context())); + } + } + + pub fn validation_end(&self, result: &ValidationResult) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.validation_end(&self.context(), result)); + } + } + + pub fn execution_start(&self) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.execution_start(&self.context())); + } + } + + pub fn execution_end(&self) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.execution_end(&self.context())); + } + } + + pub fn resolve_start(&self, info: &ResolveInfo<'_>) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.resolve_start(&self.context(), info)); + } + } + + pub fn resolve_end(&self, resolve_id: &ResolveInfo<'_>) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.resolve_end(&self.context(), resolve_id)); + } + } + + pub fn error(&self, err: &ServerError) { + if let Some(e) = &self.extensions { + e.lock() + .iter_mut() + .for_each(|e| e.error(&self.context(), err)); + } + } + + pub fn result(&self) -> Option { + if let Some(e) = &self.extensions { let value = e .lock() .iter_mut() .filter_map(|e| { if let Some(name) = e.name() { - e.result(ctx).map(|res| (Name::new(name), res)) + e.result(&self.context()).map(|res| (Name::new(name), res)) } else { None } diff --git a/src/extensions/opentelemetry.rs b/src/extensions/opentelemetry.rs index 722d2d00..8392e4bd 100644 --- a/src/extensions/opentelemetry.rs +++ b/src/extensions/opentelemetry.rs @@ -1,9 +1,10 @@ use std::collections::HashMap; +use std::sync::Arc; use async_graphql_parser::types::ExecutableDocument; use async_graphql_value::Variables; use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer}; -use opentelemetry::{Context as OpenTelemetryContext, Key, KeyValue}; +use opentelemetry::{Context as OpenTelemetryContext, Key}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::{ServerError, ValidationResult}; @@ -27,42 +28,41 @@ const KEY_ERROR: Key = Key::from_static_str("graphql.error"); const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity"); const KEY_DEPTH: Key = Key::from_static_str("graphql.depth"); -/// OpenTelemetry extension +/// OpenTelemetry extension configuration for each request. #[derive(Default)] #[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] -pub struct OpenTelemetry { - tracer: T, - #[allow(dead_code)] - uninstall: U, +pub struct OpenTelemetryConfig { + /// Use a context as the parent node of the entire query. + parent: spin::Mutex>, } -impl OpenTelemetry { +impl OpenTelemetryConfig { + /// Use a context as the parent of the entire query. + pub fn parent_context(mut self, cx: OpenTelemetryContext) -> Self { + *self.parent.get_mut() = Some(cx); + self + } +} + +/// OpenTelemetry extension +#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] +pub struct OpenTelemetry { + tracer: Arc, +} + +impl OpenTelemetry { /// Use `tracer` to create an OpenTelemetry extension. - pub fn new(tracer: T) -> OpenTelemetry + pub fn new(tracer: T) -> OpenTelemetry where - T: Tracer + Send + Sync + Clone, + T: Tracer + Send + Sync, { Self { - tracer, - uninstall: (), - } - } - - /// Attach an Uninstall instance to this extension. - pub fn with_uninstall(self, uninstall: U) -> OpenTelemetry - where - U: Send + Sync + 'static, - { - OpenTelemetry { - tracer: self.tracer, - uninstall, + tracer: Arc::new(tracer), } } } -impl ExtensionFactory - for OpenTelemetry -{ +impl ExtensionFactory for OpenTelemetry { fn create(&self) -> Box { Box::new(OpenTelemetryExtension { tracer: self.tracer.clone(), @@ -72,7 +72,7 @@ impl ExtensionFactory } struct OpenTelemetryExtension { - tracer: T, + tracer: Arc, contexts: HashMap, } @@ -92,36 +92,46 @@ impl OpenTelemetryExtension { } } -impl Extension for OpenTelemetryExtension { +impl Extension for OpenTelemetryExtension { + fn start(&mut self, ctx: &ExtensionContext<'_>) { + let request_cx = ctx + .data_opt::() + .and_then(|cfg| cfg.parent.lock().take()) + .unwrap_or_else(|| { + OpenTelemetryContext::current_with_span( + self.tracer + .span_builder("request") + .with_kind(SpanKind::Server) + .start(&*self.tracer), + ) + }); + self.enter_context(REQUEST_CTX, request_cx); + } + + fn end(&mut self, _ctx: &ExtensionContext<'_>) { + self.exit_context(REQUEST_CTX); + } + fn parse_start( &mut self, _ctx: &ExtensionContext<'_>, query_source: &str, variables: &Variables, ) { - let request_span = self - .tracer - .span_builder("request") - .with_kind(SpanKind::Server) - .start(&self.tracer); - let request_cx = OpenTelemetryContext::current_with_span(request_span); - self.enter_context(REQUEST_CTX, request_cx.clone()); - - let mut attributes = Vec::with_capacity(2); - attributes.push(KeyValue::new(KEY_SOURCE, query_source.to_string())); - attributes.push(KeyValue::new( - KEY_VARIABLES, - serde_json::to_string(variables).unwrap(), - )); - let parse_span = self - .tracer - .span_builder("parse") - .with_kind(SpanKind::Server) - .with_attributes(attributes) - .with_parent_context(request_cx) - .start(&self.tracer); - let parse_cx = OpenTelemetryContext::current_with_span(parse_span); - self.enter_context(PARSE_CTX, parse_cx); + if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() { + let mut attributes = Vec::with_capacity(2); + attributes.push(KEY_SOURCE.string(query_source.to_string())); + attributes.push(KEY_VARIABLES.string(serde_json::to_string(variables).unwrap())); + let parse_span = self + .tracer + .span_builder("parse") + .with_kind(SpanKind::Server) + .with_attributes(attributes) + .with_parent_context(parent_cx) + .start(&*self.tracer); + let parse_cx = OpenTelemetryContext::current_with_span(parse_span); + self.enter_context(PARSE_CTX, parse_cx); + } } fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { @@ -135,7 +145,7 @@ impl Extension for OpenTelemetryExtension { .span_builder("validation") .with_kind(SpanKind::Server) .with_parent_context(parent_cx) - .start(&self.tracer); + .start(&*self.tracer); let validation_cx = OpenTelemetryContext::current_with_span(span); self.enter_context(VALIDATION_CTX, validation_cx); } @@ -144,8 +154,8 @@ impl Extension for OpenTelemetryExtension { fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) { let span = validation_cx.span(); - span.set_attribute(KeyValue::new(KEY_COMPLEXITY, result.complexity as i64)); - span.set_attribute(KeyValue::new(KEY_DEPTH, result.depth as i64)); + span.set_attribute(KEY_COMPLEXITY.i64(result.complexity as i64)); + span.set_attribute(KEY_DEPTH.i64(result.depth as i64)); } } @@ -156,12 +166,12 @@ impl Extension for OpenTelemetryExtension { .span_builder("execute") .with_kind(SpanKind::Server) .with_parent_context(parent_cx) - .start(&self.tracer), + .start(&*self.tracer), None => self .tracer .span_builder("execute") .with_kind(SpanKind::Server) - .start(&self.tracer), + .start(&*self.tracer), }; let execute_cx = OpenTelemetryContext::current_with_span(span); self.enter_context(EXECUTE_CTX, execute_cx); @@ -169,7 +179,6 @@ impl Extension for OpenTelemetryExtension { fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { self.exit_context(EXECUTE_CTX); - self.exit_context(REQUEST_CTX); } fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { @@ -181,19 +190,16 @@ impl Extension for OpenTelemetryExtension { if let Some(parent_cx) = parent_cx { let mut attributes = Vec::with_capacity(3); - attributes.push(KeyValue::new( - KEY_RESOLVE_ID, - info.resolve_id.current as i64, - )); - attributes.push(KeyValue::new(KEY_PARENT_TYPE, info.parent_type.to_string())); - attributes.push(KeyValue::new(KEY_RETURN_TYPE, info.return_type.to_string())); + attributes.push(KEY_RESOLVE_ID.i64(info.resolve_id.current as i64)); + attributes.push(KEY_PARENT_TYPE.string(info.parent_type.to_string())); + attributes.push(KEY_RETURN_TYPE.string(info.return_type.to_string())); let span = self .tracer .span_builder(&info.path_node.to_string()) .with_kind(SpanKind::Server) .with_parent_context(parent_cx) .with_attributes(attributes) - .start(&self.tracer); + .start(&*self.tracer); let resolve_cx = OpenTelemetryContext::current_with_span(span); self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx); } @@ -205,10 +211,9 @@ impl Extension for OpenTelemetryExtension { fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() { - parent_cx.span().add_event( - "error".to_string(), - vec![KeyValue::new(KEY_ERROR, err.to_string())], - ); + parent_cx + .span() + .add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]); } } } diff --git a/src/extensions/tracing.rs b/src/extensions/tracing.rs index 762fbe24..0b278dbf 100644 --- a/src/extensions/tracing.rs +++ b/src/extensions/tracing.rs @@ -6,6 +6,22 @@ use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveIn use crate::parser::types::ExecutableDocument; use crate::{ServerError, ValidationResult, Variables}; +/// Tracing extension configuration for each request. +#[derive(Default)] +#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] +pub struct TracingConfig { + /// Use a span as the parent node of the entire query. + parent: spin::Mutex>, +} + +impl TracingConfig { + /// Use a span as the parent of the entire query. + pub fn parent_span(mut self, span: Span) -> Self { + *self.parent.get_mut() = Some(span); + self + } +} + const REQUEST_CTX: usize = 0; const PARSE_CTX: usize = 1; const VALIDATION_CTX: usize = 2; @@ -26,7 +42,7 @@ fn resolve_span_id(resolver_id: usize) -> usize { /// /// ```no_run /// use async_graphql::*; -/// use async_graphql::extensions::Tracing; +/// use async_graphql::extensions::{Tracing, TracingConfig}; /// use tracing::{span, Level, Instrument}; /// /// #[derive(SimpleObject)] @@ -42,7 +58,7 @@ fn resolve_span_id(resolver_id: usize) -> usize { /// schema.execute(Request::new("{ value }")).await; /// }); /// -/// // tracing in custom parent span +/// // tracing in parent span /// tokio::runtime::Runtime::new().unwrap().block_on(async { /// let root_span = span!( /// parent: None, @@ -51,6 +67,16 @@ fn resolve_span_id(resolver_id: usize) -> usize { /// ); /// schema.execute(Request::new("{ value }")).instrument(root_span).await; /// }); +/// +/// // replace root span +/// tokio::runtime::Runtime::new().unwrap().block_on(async { +/// let root_span = span!( +/// parent: None, +/// Level::INFO, +/// "span root" +/// ); +/// schema.execute(Request::new("{ value }").data(TracingConfig::default().parent_span(root_span))).await; +/// }); /// ``` #[derive(Default)] #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] @@ -82,30 +108,42 @@ impl TracingExtension { } impl Extension for TracingExtension { + fn start(&mut self, ctx: &ExtensionContext<'_>) { + let request_span = ctx + .data_opt::() + .and_then(|cfg| cfg.parent.lock().take()) + .unwrap_or_else(|| { + span!( + target: "async_graphql::graphql", + Level::INFO, + "request", + ) + }); + self.enter_span(REQUEST_CTX, request_span); + } + + fn end(&mut self, _ctx: &ExtensionContext<'_>) { + self.exit_span(REQUEST_CTX); + } + fn parse_start( &mut self, _ctx: &ExtensionContext<'_>, query_source: &str, variables: &Variables, ) { - let request_span = span!( - target: "async_graphql::graphql", - Level::INFO, - "request", - ); - - let variables = serde_json::to_string(&variables).unwrap(); - let parse_span = span!( - target: "async_graphql::graphql", - parent: &request_span, - Level::INFO, - "parse", - source = query_source, - variables = %variables, - ); - - self.enter_span(REQUEST_CTX, request_span); - self.enter_span(PARSE_CTX, parse_span); + if let Some(parent) = self.spans.get(&REQUEST_CTX) { + let variables = serde_json::to_string(&variables).unwrap(); + let parse_span = span!( + target: "async_graphql::graphql", + parent: parent, + Level::INFO, + "parse", + source = query_source, + variables = %variables, + ); + self.enter_span(PARSE_CTX, parse_span); + } } fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { @@ -129,27 +167,19 @@ impl Extension for TracingExtension { } fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { - let span = match self.spans.get(&REQUEST_CTX) { - Some(parent) => span!( + if let Some(parent) = self.spans.get(&REQUEST_CTX) { + let span = span!( target: "async_graphql::graphql", parent: parent, Level::INFO, "execute" - ), - None => span!( - target: "async_graphql::graphql", - parent: None, - Level::INFO, - "execute" - ), + ); + self.enter_span(EXECUTE_CTX, span); }; - - self.enter_span(EXECUTE_CTX, span); } fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { self.exit_span(EXECUTE_CTX); - self.exit_span(REQUEST_CTX); } fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { diff --git a/src/http/websocket.rs b/src/http/websocket.rs index feb2fafe..35ac64e9 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -177,7 +177,9 @@ where if let Some(data) = this.data.clone() { this.streams.insert( id, - Box::pin(this.schema.execute_stream_with_ctx_data(request, data)), + Box::pin( + this.schema.execute_stream_with_session_data(request, data), + ), ); } else { return Poll::Ready(Some(WsMessage::Close( diff --git a/src/resolver_utils/container.rs b/src/resolver_utils/container.rs index 49c2bf59..e670f6b9 100644 --- a/src/resolver_utils/container.rs +++ b/src/resolver_utils/container.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use std::future::Future; use std::pin::Pin; -use crate::extensions::{ErrorLogger, ExtensionContext, ResolveInfo}; +use crate::extensions::{ErrorLogger, ResolveInfo}; use crate::parser::types::Selection; use crate::registry::MetaType; use crate::{ @@ -183,11 +183,6 @@ impl<'a> Fields<'a> { } }? } else { - let ctx_extension = ExtensionContext { - schema_data: &ctx.schema_env.data, - query_data: &ctx.query_env.ctx_data, - }; - let type_name = T::type_name(); let resolve_info = ResolveInfo { resolve_id: ctx_field.resolve_id, @@ -215,10 +210,7 @@ impl<'a> Fields<'a> { }, }; - ctx_field - .query_env - .extensions - .resolve_start(&ctx_extension, &resolve_info); + ctx_field.query_env.extensions.resolve_start(&resolve_info); let res = match root.resolve_field(&ctx_field).await { Ok(value) => Ok((field_name, value.unwrap_or_default())), @@ -226,13 +218,8 @@ impl<'a> Fields<'a> { Err(e.path(PathSegment::Field(field_name.to_string()))) } } - .log_error(&ctx_extension, &ctx_field.query_env.extensions)?; - - ctx_field - .query_env - .extensions - .resolve_end(&ctx_extension, &resolve_info); - + .log_error(&ctx_field.query_env.extensions)?; + ctx_field.query_env.extensions.resolve_end(&resolve_info); res }; diff --git a/src/resolver_utils/list.rs b/src/resolver_utils/list.rs index d75a6572..ced30ea9 100644 --- a/src/resolver_utils/list.rs +++ b/src/resolver_utils/list.rs @@ -1,4 +1,4 @@ -use crate::extensions::{ErrorLogger, ExtensionContext, ResolveInfo}; +use crate::extensions::{ErrorLogger, ResolveInfo}; use crate::parser::types::Field; use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value}; @@ -14,16 +14,11 @@ pub async fn resolve_list<'a, T: OutputType + 'a>( for (idx, item) in iter.into_iter().enumerate() { let ctx_idx = ctx.with_index(idx); futures.push(async move { - let ctx_extension = ExtensionContext { - schema_data: &ctx.schema_env.data, - query_data: &ctx.query_env.ctx_data, - }; - if ctx_idx.query_env.extensions.is_empty() { OutputType::resolve(&item, &ctx_idx, field) .await .map_err(|e| e.path(PathSegment::Index(idx))) - .log_error(&ctx_extension, &ctx_idx.query_env.extensions) + .log_error(&ctx_idx.query_env.extensions) } else { let resolve_info = ResolveInfo { resolve_id: ctx_idx.resolve_id, @@ -32,21 +27,14 @@ pub async fn resolve_list<'a, T: OutputType + 'a>( return_type: &T::qualified_type_name(), }; - ctx_idx - .query_env - .extensions - .resolve_start(&ctx_extension, &resolve_info); + ctx_idx.query_env.extensions.resolve_start(&resolve_info); let res = OutputType::resolve(&item, &ctx_idx, field) .await .map_err(|e| e.path(PathSegment::Index(idx))) - .log_error(&ctx_extension, &ctx_idx.query_env.extensions)?; - - ctx_idx - .query_env - .extensions - .resolve_end(&ctx_extension, &resolve_info); + .log_error(&ctx_idx.query_env.extensions)?; + ctx_idx.query_env.extensions.resolve_end(&resolve_info); Ok(res) } }); diff --git a/src/schema.rs b/src/schema.rs index 549fc1a1..b0b1b205 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -8,7 +8,7 @@ use futures_util::stream::{self, Stream, StreamExt}; use indexmap::map::IndexMap; use crate::context::{Data, QueryEnvInner, ResolveId}; -use crate::extensions::{ErrorLogger, ExtensionContext, ExtensionFactory, Extensions}; +use crate::extensions::{ErrorLogger, ExtensionFactory, Extensions}; use crate::model::__DirectiveLocation; use crate::parser::parse_query; use crate::parser::types::{DocumentOperations, OperationType}; @@ -359,57 +359,59 @@ where self.0.env.registry.names() } + fn create_extensions(&self, session_data: Arc) -> Extensions { + let extensions = Extensions::new( + self.0 + .extensions + .iter() + .map(|factory| factory.create()) + .collect::>(), + self.env.clone(), + session_data, + ); + extensions.start(); + extensions + } + async fn prepare_request( &self, + mut extensions: Extensions, request: Request, - ) -> Result<(QueryEnvInner, CacheControl), Vec> { - // create extension instances - let mut extensions: Extensions = self - .0 - .extensions - .iter() - .map(|factory| factory.create()) - .collect::>() - .into(); - + session_data: Arc, + ) -> Result<(QueryEnv, CacheControl), Vec> { let mut request = request; - let data = std::mem::take(&mut request.data); - let ctx_extension = ExtensionContext { - schema_data: &self.env.data, - query_data: &data, - }; + let query_data = Arc::new(std::mem::take(&mut request.data)); + extensions.attach_query_data(query_data.clone()); + let request = extensions.prepare_request(request).await?; - let request = extensions.prepare_request(&ctx_extension, request).await?; - - extensions.parse_start(&ctx_extension, &request.query, &request.variables); + extensions.parse_start(&request.query, &request.variables); let document = parse_query(&request.query) .map_err(Into::::into) - .log_error(&ctx_extension, &extensions)?; - extensions.parse_end(&ctx_extension, &document); + .log_error(&extensions)?; + extensions.parse_end(&document); // check rules - extensions.validation_start(&ctx_extension); + extensions.validation_start(); let validation_result = check_rules( &self.env.registry, &document, Some(&request.variables), self.validation_mode, ) - .log_error(&ctx_extension, &extensions)?; - extensions.validation_end(&ctx_extension, &validation_result); + .log_error(&extensions)?; + extensions.validation_end(&validation_result); // check limit if let Some(limit_complexity) = self.complexity { if validation_result.complexity > limit_complexity { - return Err(vec![ServerError::new("Query is too complex.")]) - .log_error(&ctx_extension, &extensions); + return Err(vec![ServerError::new("Query is too complex.")]).log_error(&extensions); } } if let Some(limit_depth) = self.depth { if validation_result.depth > limit_depth { return Err(vec![ServerError::new("Query is nested too deep.")]) - .log_error(&ctx_extension, &extensions); + .log_error(&extensions); } } @@ -437,7 +439,7 @@ where let operation = match operation { Ok(operation) => operation, Err(e) => { - extensions.error(&ctx_extension, &e); + extensions.error(&e); return Err(vec![e]); } }; @@ -448,10 +450,11 @@ where operation, fragments: document.fragments, uploads: request.uploads, - ctx_data: Arc::new(data), + session_data, + ctx_data: query_data, http_headers: Default::default(), }; - Ok((env, validation_result.cache_control)) + Ok((QueryEnv::new(env), validation_result.cache_control)) } async fn execute_once(&self, env: QueryEnv) -> Response { @@ -465,12 +468,7 @@ where schema_env: &self.env, query_env: &env, }; - let ctx_extension = ExtensionContext { - schema_data: &self.env.data, - query_data: &env.ctx_data, - }; - - env.extensions.execution_start(&ctx_extension); + env.extensions.execution_start(); let data = match &env.operation.node.ty { OperationType::Query => resolve_container(&ctx, &self.query).await, @@ -482,8 +480,8 @@ where } }; - env.extensions.execution_end(&ctx_extension); - let extensions = env.extensions.result(&ctx_extension); + env.extensions.execution_end(); + let extensions = env.extensions.result(); match data { Ok(data) => Response::new(data), @@ -496,11 +494,12 @@ where /// Execute a GraphQL query. pub async fn execute(&self, request: impl Into) -> Response { let request = request.into(); - match self.prepare_request(request).await { - Ok((env, cache_control)) => self - .execute_once(QueryEnv::new(env)) - .await - .cache_control(cache_control), + let extensions = self.create_extensions(Default::default()); + match self + .prepare_request(extensions, request, Default::default()) + .await + { + Ok((env, cache_control)) => self.execute_once(env).await.cache_control(cache_control), Err(errors) => Response::from_errors(errors), } } @@ -518,24 +517,25 @@ where } } - pub(crate) fn execute_stream_with_ctx_data( + /// Execute a GraphQL subscription with session data. + #[doc(hidden)] + pub fn execute_stream_with_session_data( &self, request: impl Into + Send, - ctx_data: Arc, + session_data: Arc, ) -> impl Stream + Send { let schema = self.clone(); + let request = request.into(); + let extensions = self.create_extensions(session_data.clone()); async_stream::stream! { - let request = request.into(); - let (mut env, cache_control) = match schema.prepare_request(request).await { + let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await { Ok(res) => res, Err(errors) => { yield Response::from_errors(errors); return; } }; - env.ctx_data = ctx_data; - let env = QueryEnv::new(env); if env.operation.node.ty != OperationType::Subscription { yield schema @@ -553,26 +553,18 @@ where ResolveId::root(), &resolve_id, ); - let ctx_extension = ExtensionContext { - schema_data: &schema.env.data, - query_data: &env.ctx_data, - }; - - env.extensions.execution_start(&ctx_extension); let mut streams = Vec::new(); if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { - env.extensions.execution_end(&ctx_extension); + env.extensions.execution_end(); yield Response::from_errors(vec![e]); return; } - env.extensions.execution_end(&ctx_extension); - let mut stream = stream::select_all(streams); while let Some(data) = stream.next().await { let is_err = data.is_err(); - let extensions = env.extensions.result(&ctx_extension); + let extensions = env.extensions.result(); yield match data { Ok((name, value)) => { let mut map = BTreeMap::new(); @@ -593,8 +585,6 @@ where &self, request: impl Into, ) -> impl Stream + Send { - let mut request = request.into(); - let ctx_data = std::mem::take(&mut request.data); - self.execute_stream_with_ctx_data(request, Arc::new(ctx_data)) + self.execute_stream_with_session_data(request.into(), Default::default()) } } diff --git a/src/subscription.rs b/src/subscription.rs index 772e2810..4cafd4da 100644 --- a/src/subscription.rs +++ b/src/subscription.rs @@ -38,13 +38,7 @@ pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>( let ctx = ctx.clone(); async_stream::stream! { let ctx = ctx.with_field(field); - let field_name = ctx - .item - .node - .response_key() - .node - .clone(); - + let field_name = ctx.item.node.response_key().node.clone(); let stream = root.create_field_stream(&ctx); if let Some(mut stream) = stream { while let Some(item) = stream.next().await { diff --git a/tests/extension.rs b/tests/extension.rs index e2104810..6f9c263a 100644 --- a/tests/extension.rs +++ b/tests/extension.rs @@ -1,8 +1,13 @@ -use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory}; -use async_graphql::*; -use spin::Mutex; use std::sync::Arc; +use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; +use async_graphql::parser::types::ExecutableDocument; +use async_graphql::*; +use async_graphql_value::ConstValue; +use futures_util::stream::Stream; +use futures_util::StreamExt; +use spin::Mutex; + #[tokio::test] pub async fn test_extension_ctx() { #[derive(Default, Clone)] @@ -12,8 +17,18 @@ pub async fn test_extension_ctx() { #[Object] impl Query { - async fn value(&self) -> bool { - true + async fn value(&self, ctx: &Context<'_>) -> i32 { + *ctx.data_unchecked::().0.lock() + } + } + + struct Subscription; + + #[Subscription] + impl Subscription { + async fn value(&self, ctx: &Context<'_>) -> impl Stream { + let data = *ctx.data_unchecked::().0.lock(); + futures_util::stream::once(async move { data }) } } @@ -27,7 +42,9 @@ pub async fn test_extension_ctx() { _query_source: &str, _variables: &Variables, ) { - *ctx.data_unchecked::().0.lock() = 100; + if let Ok(data) = ctx.data::() { + *data.0.lock() = 100; + } } } @@ -46,9 +63,17 @@ pub async fn test_extension_ctx() { .data(data.clone()) .extension(MyExtension) .finish(); - - schema.execute("{ value }").await.into_result().unwrap(); - assert_eq!(*data.0.lock(), 100); + assert_eq!( + schema + .execute("{ value }") + .await + .into_result() + .unwrap() + .data, + value! ({ + "value": 100 + }) + ); } // data in request @@ -58,11 +83,219 @@ pub async fn test_extension_ctx() { .extension(MyExtension) .finish(); - schema - .execute(Request::new("{ value }").data(data.clone())) + assert_eq!( + schema + .execute(Request::new("{ value }").data(data.clone())) + .await + .into_result() + .unwrap() + .data, + value! ({ + "value": 100 + }) + ); + } + + // data in session + { + let schema = Schema::build(Query, EmptyMutation, Subscription) + .extension(MyExtension) + .finish(); + + let mut data = Data::default(); + data.insert(MyData::default()); + let mut stream = schema + .execute_stream_with_session_data( + Request::new("subscription { value }"), + Arc::new(data), + ) + .boxed(); + assert_eq!( + stream.next().await.unwrap().into_result().unwrap().data, + value! ({ + "value": 100 + }) + ); + } +} + +#[tokio::test] +pub async fn test_extension_call_order() { + struct MyExtensionImpl { + calls: Arc>>, + } + + #[async_trait::async_trait] + #[allow(unused_variables)] + impl Extension for MyExtensionImpl { + fn name(&self) -> Option<&'static str> { + Some("test") + } + + fn start(&mut self, ctx: &ExtensionContext<'_>) { + self.calls.lock().push("start"); + } + + fn end(&mut self, ctx: &ExtensionContext<'_>) { + self.calls.lock().push("end"); + } + + async fn prepare_request( + &mut self, + ctx: &ExtensionContext<'_>, + request: Request, + ) -> ServerResult { + self.calls.lock().push("prepare_request"); + Ok(request) + } + + fn parse_start( + &mut self, + ctx: &ExtensionContext<'_>, + query_source: &str, + variables: &Variables, + ) { + self.calls.lock().push("parse_start"); + } + + fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) { + self.calls.lock().push("parse_end"); + } + + fn validation_start(&mut self, ctx: &ExtensionContext<'_>) { + self.calls.lock().push("validation_start"); + } + + fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) { + self.calls.lock().push("validation_end"); + } + + fn execution_start(&mut self, ctx: &ExtensionContext<'_>) { + self.calls.lock().push("execution_start"); + } + + fn execution_end(&mut self, ctx: &ExtensionContext<'_>) { + self.calls.lock().push("execution_end"); + } + + fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { + self.calls.lock().push("resolve_start"); + } + + fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { + self.calls.lock().push("resolve_end"); + } + + fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option { + self.calls.lock().push("result"); + None + } + } + + struct MyExtension { + calls: Arc>>, + } + + impl ExtensionFactory for MyExtension { + fn create(&self) -> Box { + Box::new(MyExtensionImpl { + calls: self.calls.clone(), + }) + } + } + + struct Query; + + #[Object] + impl Query { + async fn value1(&self) -> i32 { + 10 + } + + async fn value2(&self) -> i32 { + 10 + } + } + + struct Subscription; + + #[Subscription] + impl Subscription { + async fn value(&self) -> impl Stream { + futures_util::stream::iter(vec![1, 2, 3]) + } + } + + { + let calls: Arc>> = Default::default(); + let schema = Schema::build(Query, EmptyMutation, Subscription) + .extension(MyExtension { + calls: calls.clone(), + }) + .finish(); + let _ = schema + .execute("{ value1 value2 }") .await .into_result() .unwrap(); - assert_eq!(*data.0.lock(), 100); + let calls = calls.lock(); + assert_eq!( + &*calls, + &vec![ + "start", + "prepare_request", + "parse_start", + "parse_end", + "validation_start", + "validation_end", + "execution_start", + "resolve_start", + "resolve_end", + "resolve_start", + "resolve_end", + "execution_end", + "result", + "end", + ] + ); + } + + { + let calls: Arc>> = Default::default(); + let schema = Schema::build(Query, EmptyMutation, Subscription) + .extension(MyExtension { + calls: calls.clone(), + }) + .finish(); + let mut stream = schema.execute_stream("subscription { value }").boxed(); + while let Some(_) = stream.next().await {} + let calls = calls.lock(); + assert_eq!( + &*calls, + &vec![ + "start", + "prepare_request", + "parse_start", + "parse_end", + "validation_start", + "validation_end", + "execution_start", + "resolve_start", + "resolve_end", + "execution_end", + "result", + "execution_start", + "resolve_start", + "resolve_end", + "execution_end", + "result", + "execution_start", + "resolve_start", + "resolve_end", + "execution_end", + "result", + "end", + ] + ); } }