Rework Extension & TracingExtension & OpenTelemetryExtension

This commit is contained in:
Sunli 2021-03-25 16:33:11 +08:00
parent 845f6a514a
commit cc937b70b5
21 changed files with 643 additions and 761 deletions

View File

@ -77,5 +77,4 @@ members = [
"integrations/rocket", "integrations/rocket",
"integrations/tide", "integrations/tide",
"integrations/warp", "integrations/warp",
"benchmark",
] ]

View File

@ -1,29 +0,0 @@
[package]
name = "graphql-benchmark"
version = "2.6.5"
authors = ["sunli <scott_s829@163.com>", "Koxiaet"]
edition = "2018"
[dependencies]
async-std = { version = "1.9.0", features = ["attributes"] }
futures = "0.3.13"
serde_json = "1.0.64"
async-graphql-parser = { path = "../parser" }
async-graphql = { path = ".." }
jemallocator = { version = "0.3.2", optional = true }
[dev-dependencies]
criterion = "0.3.4"
simple = { path = "simple" }
chat = { path = "chat" }
[features]
jemalloc = ["jemallocator"]
[[bench]]
name = "simple"
harness = false
[[bench]]
name = "chat"
harness = false

View File

@ -1,13 +0,0 @@
use chat::{Q, S};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use graphql_benchmark::{parse, run, serialize};
pub fn bench(c: &mut Criterion) {
c.bench_function("chat run", |b| b.iter(|| run(&S, black_box(Q))));
c.bench_function("chat parse", |b| b.iter(|| parse(black_box(Q))));
let res = run(&S, Q);
c.bench_function("chat serialize", |b| b.iter(|| serialize(black_box(&res))));
}
criterion_group!(chat, bench);
criterion_main!(chat);

View File

@ -1,15 +0,0 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use graphql_benchmark::{parse, run, serialize};
use simple::{Q, S};
pub fn bench(c: &mut Criterion) {
c.bench_function("simple run", |b| b.iter(|| run(&S, black_box(Q))));
c.bench_function("simple parse", |b| b.iter(|| parse(black_box(Q))));
let res = run(&S, Q);
c.bench_function("simple serialize", |b| {
b.iter(|| serialize(black_box(&res)))
});
}
criterion_group!(simple, bench);
criterion_main!(simple);

View File

@ -1,9 +0,0 @@
[package]
name = "chat"
version = "2.6.5"
authors = ["Ivan Plesskih <terma95@gmail.com>"]
edition = "2018"
[dependencies]
async-graphql = { path = "../.." }
lazy_static = "1.4.0"

View File

@ -1,249 +0,0 @@
use async_graphql::*;
pub struct ChatData {
pub id: String,
pub created_at: String,
pub title: String,
pub caption: String,
pub creator_user_id: String,
pub state: String,
}
pub struct UserData {
pub id: String,
pub is_operator: bool,
pub phone: u64,
pub join_date: String,
pub state: String,
}
pub struct ProfileData {
pub first_name: String,
pub last_name: String,
pub city: Option<String>,
pub job_title: Option<String>,
pub email: String,
}
pub struct MessageData {
pub id: String,
pub user_id: String,
pub timestamp: String,
pub edited: bool,
pub order: i32,
pub message: String,
}
lazy_static::lazy_static! {
pub static ref CHAT: ChatData = ChatData {
id: "1".to_string(),
created_at: "today".to_string(),
title: "chat".to_string(),
caption: "asdasd".to_string(),
creator_user_id: "123".to_string(),
state: "ACTIVE".to_string(),
};
pub static ref USER: UserData = UserData {
id: "123".to_string(),
is_operator: false,
phone: 79_123_273_936,
join_date: "today".to_string(),
state: "ACTIVE".to_string(),
};
pub static ref PROFILE: ProfileData = ProfileData {
first_name: "Ivan".to_string(),
last_name: "Plesskih".to_string(),
city: Some("Che".to_string()),
job_title: Some("progr".to_string()),
email: "asd@qwe.ru".to_string(),
};
pub static ref MESSAGE: MessageData = MessageData {
id: "456".to_string(),
user_id: "123".to_string(),
timestamp: "today".to_string(),
edited: false,
order: 123,
message: "Hello, world!".to_string(),
};
}
pub struct Chat;
#[Object]
impl Chat {
pub async fn id(&self) -> ID {
ID::from(&CHAT.id)
}
pub async fn messages(&self) -> Vec<Message> {
let mut res = vec![];
for _ in 0..30 {
res.push(Message);
}
res
}
pub async fn users(&self) -> Vec<User> {
let mut res = vec![];
for _ in 0..5 {
res.push(User);
}
res
}
pub async fn creator(&self) -> User {
User
}
#[graphql(name = "created_at")]
pub async fn created_at(&self) -> &String {
&CHAT.created_at
}
pub async fn title(&self) -> &String {
&CHAT.title
}
pub async fn caption(&self) -> &String {
&CHAT.caption
}
pub async fn state(&self) -> &String {
&CHAT.state
}
}
pub struct Message;
#[Object]
impl Message {
pub async fn id(&self) -> ID {
ID::from(&MESSAGE.id)
}
pub async fn user(&self) -> User {
User
}
pub async fn timestamp(&self) -> &String {
&MESSAGE.timestamp
}
pub async fn message(&self) -> &String {
&MESSAGE.message
}
pub async fn order(&self) -> i32 {
MESSAGE.order
}
pub async fn edited(&self) -> bool {
MESSAGE.edited
}
}
pub struct User;
#[Object]
impl User {
pub async fn id(&self) -> ID {
ID::from(&USER.id)
}
pub async fn profile(&self) -> Option<UserProfile> {
Some(UserProfile)
}
#[graphql(name = "is_operator")]
pub async fn is_operator(&self) -> bool {
USER.is_operator
}
pub async fn phone(&self) -> String {
USER.phone.to_string()
}
#[graphql(name = "join_date")]
pub async fn join_date(&self) -> &String {
&USER.join_date
}
pub async fn state(&self) -> &String {
&USER.state
}
}
pub struct UserProfile;
#[Object]
impl UserProfile {
pub async fn email(&self) -> &String {
&PROFILE.email
}
#[graphql(name = "first_name")]
pub async fn first_name(&self) -> &String {
&PROFILE.first_name
}
#[graphql(name = "last_name")]
pub async fn last_name(&self) -> &String {
&PROFILE.last_name
}
#[graphql(name = "job_title")]
pub async fn job_title(&self) -> &Option<String> {
&PROFILE.job_title
}
pub async fn city(&self) -> &Option<String> {
&PROFILE.city
}
}
pub struct Query;
#[Object]
impl Query {
async fn chats(&self) -> Vec<Chat> {
let mut res = vec![];
for _ in 0..30 {
res.push(Chat);
}
res
}
}
lazy_static::lazy_static! {
pub static ref S: Schema<Query, EmptyMutation, EmptySubscription> = Schema::new(Query, EmptyMutation, EmptySubscription);
}
pub const Q: &str = r#"
fragment User on User {
id
is_operator
phone
join_date
state
profile {
email
first_name
last_name
job_title
city
}
}
{
chats {
id
created_at
title
caption
state
creator {
...User
}
messages {
id
timestamp
edited
message
order
}
users {
...User
}
}
}
"#;

View File

@ -1,9 +0,0 @@
[package]
name = "simple"
version = "2.6.5"
authors = ["Ivan Plesskih <terma95@gmail.com>"]
edition = "2018"
[dependencies]
async-graphql = { path = "../.." }
lazy_static = "1.4.0"

View File

@ -1,70 +0,0 @@
use async_graphql::*;
pub struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value_i32(&self) -> i32 {
999
}
async fn obj(&self) -> MyObj {
MyObj
}
}
pub struct MyObj;
#[Object]
impl MyObj {
async fn value_i32(&self) -> i32 {
999
}
async fn value_list(&self) -> &[i32] {
&[1, 2, 3, 4, 5, 6, 7, 8, 9]
}
async fn obj(&self) -> MyObj {
MyObj
}
}
lazy_static::lazy_static! {
pub static ref S: Schema<QueryRoot, EmptyMutation, EmptySubscription> = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
// static ref D: Document = parse_query(Q).unwrap();
}
pub const Q: &str = r#"{
valueI32 obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList obj {
valueI32 valueList
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}
}"#;

View File

@ -1,35 +0,0 @@
use async_graphql::{ObjectType, Response, Schema, SubscriptionType};
use async_graphql_parser::{parse_query, types::ExecutableDocument};
use async_std::task;
#[cfg(feature = "jemalloc")]
#[global_allocator]
static ALLOC: jemallocator::Jemalloc = jemallocator::Jemalloc;
pub fn run<Query, Mutation, Subscription>(
s: &Schema<Query, Mutation, Subscription>,
q: &str,
) -> Response
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
task::block_on(async { s.execute(q).await.into_result().unwrap() })
}
pub fn parse(q: &str) -> ExecutableDocument {
parse_query(q).unwrap()
}
// pub fn validate() {
// check_rules(&S.env.registry, &D, S.validation_mode).unwrap();
// }
//
// pub fn resolve() {
// do_resolve(...).unwrap();
// }
pub fn serialize(r: &async_graphql::Response) -> String {
serde_json::to_string(r).unwrap()
}

View File

@ -360,12 +360,8 @@ pub fn generate(
resolve_id, resolve_id,
&inc_resolve_id, &inc_resolve_id,
); );
let ctx_extension = #crate_name::extensions::ExtensionContext {
schema_data: &schema_env.data,
query_data: &query_env.ctx_data,
};
query_env.extensions.execution_start(&ctx_extension); query_env.extensions.execution_start();
#[allow(bare_trait_objects)] #[allow(bare_trait_objects)]
let ri = #crate_name::extensions::ResolveInfo { let ri = #crate_name::extensions::ResolveInfo {
@ -375,12 +371,10 @@ pub fn generate(
return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(), return_type: &<<#stream_ty as #crate_name::futures_util::stream::Stream>::Item as #crate_name::Type>::qualified_type_name(),
}; };
query_env.extensions.resolve_start(&ctx_extension, &ri); query_env.extensions.resolve_start(&ri);
let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await; let res = #crate_name::OutputType::resolve(&msg, &ctx_selection_set, &*field).await;
query_env.extensions.resolve_end(&ri);
query_env.extensions.resolve_end(&ctx_extension, &ri); query_env.extensions.execution_end();
query_env.extensions.execution_end(&ctx_extension);
res res
} }

View File

@ -243,6 +243,7 @@ pub struct QueryEnvInner {
pub operation: Positioned<OperationDefinition>, pub operation: Positioned<OperationDefinition>,
pub fragments: HashMap<Name, Positioned<FragmentDefinition>>, pub fragments: HashMap<Name, Positioned<FragmentDefinition>>,
pub uploads: Vec<UploadValue>, pub uploads: Vec<UploadValue>,
pub session_data: Arc<Data>,
pub ctx_data: Arc<Data>, pub ctx_data: Arc<Data>,
pub http_headers: spin::Mutex<HeaderMap<String>>, pub http_headers: spin::Mutex<HeaderMap<String>>,
} }
@ -363,6 +364,7 @@ impl<'a, T> ContextBase<'a, T> {
.ctx_data .ctx_data
.0 .0
.get(&TypeId::of::<D>()) .get(&TypeId::of::<D>())
.or_else(|| self.query_env.session_data.0.get(&TypeId::of::<D>()))
.or_else(|| self.schema_env.data.0.get(&TypeId::of::<D>())) .or_else(|| self.schema_env.data.0.get(&TypeId::of::<D>()))
.and_then(|d| d.downcast_ref::<D>()) .and_then(|d| d.downcast_ref::<D>())
} }

View File

@ -6,7 +6,7 @@ use serde::ser::SerializeMap;
use serde::{Serialize, Serializer}; use serde::{Serialize, Serializer};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::{value, Value, Variables}; use crate::{value, Value};
struct PendingResolve { struct PendingResolve {
path: Vec<String>, path: Vec<String>,
@ -79,13 +79,10 @@ impl Extension for ApolloTracingExtension {
Some("tracing") Some("tracing")
} }
fn parse_start( fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) {
&mut self,
_ctx: &ExtensionContext<'_>,
_query_source: &str,
_variables: &Variables,
) {
self.start_time = Utc::now(); self.start_time = Utc::now();
self.pending_resolves.clear();
self.resolves.clear();
} }
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {

View File

@ -14,10 +14,13 @@ mod tracing;
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::collections::BTreeMap; use std::collections::BTreeMap;
use std::sync::Arc;
use crate::context::{QueryPathNode, ResolveId}; use crate::context::{QueryPathNode, ResolveId};
use crate::parser::types::ExecutableDocument; use crate::parser::types::ExecutableDocument;
use crate::{Data, Request, Result, ServerError, ServerResult, ValidationResult, Variables}; use crate::{
Data, Request, Result, SchemaEnv, ServerError, ServerResult, ValidationResult, Variables,
};
use crate::{Error, Name, Value}; use crate::{Error, Name, Value};
pub use self::analyzer::Analyzer; pub use self::analyzer::Analyzer;
@ -26,9 +29,9 @@ pub use self::apollo_tracing::ApolloTracing;
#[cfg(feature = "log")] #[cfg(feature = "log")]
pub use self::logger::Logger; pub use self::logger::Logger;
#[cfg(feature = "opentelemetry")] #[cfg(feature = "opentelemetry")]
pub use self::opentelemetry::OpenTelemetry; pub use self::opentelemetry::{OpenTelemetry, OpenTelemetryConfig};
#[cfg(feature = "tracing")] #[cfg(feature = "tracing")]
pub use self::tracing::Tracing; pub use self::tracing::{Tracing, TracingConfig};
pub(crate) type BoxExtension = Box<dyn Extension>; pub(crate) type BoxExtension = Box<dyn Extension>;
@ -38,7 +41,10 @@ pub struct ExtensionContext<'a> {
pub schema_data: &'a Data, pub schema_data: &'a Data,
#[doc(hidden)] #[doc(hidden)]
pub query_data: &'a Data, pub session_data: &'a Data,
#[doc(hidden)]
pub query_data: Option<&'a Data>,
} }
impl<'a> ExtensionContext<'a> { impl<'a> ExtensionContext<'a> {
@ -71,7 +77,8 @@ impl<'a> ExtensionContext<'a> {
/// Gets the global data defined in the `Context` or `Schema` or `None` if the specified type data does not exist. /// Gets the global data defined in the `Context` or `Schema` or `None` if the specified type data does not exist.
pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> { pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&'a D> {
self.query_data self.query_data
.get(&TypeId::of::<D>()) .and_then(|query_data| query_data.get(&TypeId::of::<D>()))
.or_else(|| self.session_data.get(&TypeId::of::<D>()))
.or_else(|| self.schema_data.get(&TypeId::of::<D>())) .or_else(|| self.schema_data.get(&TypeId::of::<D>()))
.and_then(|d| d.downcast_ref::<D>()) .and_then(|d| d.downcast_ref::<D>())
} }
@ -94,6 +101,36 @@ pub struct ResolveInfo<'a> {
} }
/// Represents a GraphQL extension /// Represents a GraphQL extension
///
/// # Call order for query and mutation
///
/// - start
/// - prepare_request
/// - parse_start
/// - parse_end
/// - validation_start
/// - validation_end
/// - execution_start
/// - resolve_start
/// - resolve_end
/// - result
/// - execution_end
/// - end
///
/// # Call order for subscription
///
/// - start
/// - prepare_request
/// - parse_start
/// - parse_end
/// - validation_start
/// - validation_end
/// - execution_start
/// - resolve_start
/// - resolve_end
/// - execution_end
/// - result
/// ```
#[async_trait::async_trait] #[async_trait::async_trait]
#[allow(unused_variables)] #[allow(unused_variables)]
pub trait Extension: Sync + Send + 'static { pub trait Extension: Sync + Send + 'static {
@ -102,7 +139,13 @@ pub trait Extension: Sync + Send + 'static {
None None
} }
/// Called at the prepare request /// Called at the beginning of query.
fn start(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the beginning of query.
fn end(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at prepare request.
async fn prepare_request( async fn prepare_request(
&mut self, &mut self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
@ -111,7 +154,7 @@ pub trait Extension: Sync + Send + 'static {
Ok(request) Ok(request)
} }
/// Called at the begin of the parse. /// Called at the beginning of parse query source.
fn parse_start( fn parse_start(
&mut self, &mut self,
ctx: &ExtensionContext<'_>, ctx: &ExtensionContext<'_>,
@ -120,54 +163,54 @@ pub trait Extension: Sync + Send + 'static {
) { ) {
} }
/// Called at the end of the parse. /// Called at the end of parse query source.
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {} fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {}
/// Called at the begin of the validation. /// Called at the beginning of the validation.
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {} fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the end of the validation. /// Called at the end of the validation.
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {} fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {}
/// Called at the begin of the execution. /// Called at the beginning of execute a query.
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {} fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the end of the execution. /// Called at the end of execute a query.
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {} fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {}
/// Called at the begin of the resolve field. /// Called at the beginning of resolve a field.
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {}
/// Called at the end of the resolve field. /// Called at the end of resolve a field.
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {} fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {}
/// Called when an error occurs. /// Called when an error occurs.
fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {} fn error(&mut self, ctx: &ExtensionContext<'_>, err: &ServerError) {}
/// Get the results /// Get the results.
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<Value> { fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<Value> {
None None
} }
} }
pub(crate) trait ErrorLogger { pub(crate) trait ErrorLogger {
fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self; fn log_error(self, extensions: &Extensions) -> Self;
} }
impl<T> ErrorLogger for ServerResult<T> { impl<T> ErrorLogger for ServerResult<T> {
fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self { fn log_error(self, extensions: &Extensions) -> Self {
if let Err(err) = &self { if let Err(err) = &self {
extensions.error(ctx, err); extensions.error(err);
} }
self self
} }
} }
impl<T> ErrorLogger for Result<T, Vec<ServerError>> { impl<T> ErrorLogger for Result<T, Vec<ServerError>> {
fn log_error(self, ctx: &ExtensionContext<'_>, extensions: &Extensions) -> Self { fn log_error(self, extensions: &Extensions) -> Self {
if let Err(errors) = &self { if let Err(errors) = &self {
for error in errors { for error in errors {
extensions.error(ctx, error); extensions.error(error);
} }
} }
self self
@ -183,113 +226,160 @@ pub trait ExtensionFactory: Send + Sync + 'static {
} }
#[doc(hidden)] #[doc(hidden)]
pub struct Extensions(Option<spin::Mutex<Vec<BoxExtension>>>); pub struct Extensions {
extensions: Option<spin::Mutex<Vec<BoxExtension>>>,
schema_env: SchemaEnv,
session_data: Arc<Data>,
query_data: Option<Arc<Data>>,
}
impl From<Vec<BoxExtension>> for Extensions { #[doc(hidden)]
fn from(extensions: Vec<BoxExtension>) -> Self { impl Extensions {
Self(if extensions.is_empty() { pub fn new(
None extensions: Vec<BoxExtension>,
} else { schema_env: SchemaEnv,
Some(spin::Mutex::new(extensions)) session_data: Arc<Data>,
}) ) -> Self {
Extensions {
extensions: if extensions.is_empty() {
None
} else {
Some(spin::Mutex::new(extensions))
},
schema_env,
session_data,
query_data: None,
}
}
pub fn attach_query_data(&mut self, data: Arc<Data>) {
self.query_data = Some(data);
}
}
impl Drop for Extensions {
fn drop(&mut self) {
self.end();
} }
} }
#[doc(hidden)] #[doc(hidden)]
impl Extensions { impl Extensions {
pub fn is_empty(&self) -> bool { #[inline]
self.0.is_none() fn context(&self) -> ExtensionContext<'_> {
ExtensionContext {
schema_data: &self.schema_env.data,
session_data: &self.session_data,
query_data: self.query_data.as_deref(),
}
} }
pub async fn prepare_request( pub fn is_empty(&self) -> bool {
&mut self, self.extensions.is_none()
ctx: &ExtensionContext<'_>, }
request: Request,
) -> ServerResult<Request> { pub fn start(&self) {
if let Some(e) = &self.extensions {
e.lock().iter_mut().for_each(|e| e.start(&self.context()));
}
}
pub fn end(&self) {
if let Some(e) = &self.extensions {
e.lock().iter_mut().for_each(|e| e.end(&self.context()));
}
}
pub async fn prepare_request(&self, request: Request) -> ServerResult<Request> {
let mut request = request; let mut request = request;
if let Some(e) = &mut self.0 { if let Some(e) = &self.extensions {
for e in e.get_mut().iter_mut() { for e in e.lock().iter_mut() {
request = e.prepare_request(ctx, request).await?; request = e.prepare_request(&self.context(), request).await?;
} }
} }
Ok(request) Ok(request)
} }
pub fn parse_start( pub fn parse_start(&self, query_source: &str, variables: &Variables) {
&mut self, if let Some(e) = &self.extensions {
ctx: &ExtensionContext<'_>,
query_source: &str,
variables: &Variables,
) {
if let Some(e) = &mut self.0 {
e.get_mut()
.iter_mut()
.for_each(|e| e.parse_start(ctx, query_source, variables));
}
}
pub fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {
if let Some(e) = &mut self.0 {
e.get_mut()
.iter_mut()
.for_each(|e| e.parse_end(ctx, document));
}
}
pub fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {
if let Some(e) = &mut self.0 {
e.get_mut().iter_mut().for_each(|e| e.validation_start(ctx));
}
}
pub fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {
if let Some(e) = &mut self.0 {
e.get_mut()
.iter_mut()
.for_each(|e| e.validation_end(ctx, result));
}
}
pub fn execution_start(&self, ctx: &ExtensionContext<'_>) {
if let Some(e) = &self.0 {
e.lock().iter_mut().for_each(|e| e.execution_start(ctx));
}
}
pub fn execution_end(&self, ctx: &ExtensionContext<'_>) {
if let Some(e) = &self.0 {
e.lock().iter_mut().for_each(|e| e.execution_end(ctx));
}
}
pub fn resolve_start(&self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
if let Some(e) = &self.0 {
e.lock().iter_mut().for_each(|e| e.resolve_start(ctx, info));
}
}
pub fn resolve_end(&self, ctx: &ExtensionContext<'_>, resolve_id: &ResolveInfo<'_>) {
if let Some(e) = &self.0 {
e.lock() e.lock()
.iter_mut() .iter_mut()
.for_each(|e| e.resolve_end(ctx, resolve_id)); .for_each(|e| e.parse_start(&self.context(), query_source, variables));
} }
} }
pub fn error(&self, ctx: &ExtensionContext<'_>, err: &ServerError) { pub fn parse_end(&self, document: &ExecutableDocument) {
if let Some(e) = &self.0 { if let Some(e) = &self.extensions {
e.lock().iter_mut().for_each(|e| e.error(ctx, err)); e.lock()
.iter_mut()
.for_each(|e| e.parse_end(&self.context(), document));
} }
} }
pub fn result(&self, ctx: &ExtensionContext<'_>) -> Option<Value> { pub fn validation_start(&self) {
if let Some(e) = &self.0 { if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_start(&self.context()));
}
}
pub fn validation_end(&self, result: &ValidationResult) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.validation_end(&self.context(), result));
}
}
pub fn execution_start(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_start(&self.context()));
}
}
pub fn execution_end(&self) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.execution_end(&self.context()));
}
}
pub fn resolve_start(&self, info: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_start(&self.context(), info));
}
}
pub fn resolve_end(&self, resolve_id: &ResolveInfo<'_>) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.resolve_end(&self.context(), resolve_id));
}
}
pub fn error(&self, err: &ServerError) {
if let Some(e) = &self.extensions {
e.lock()
.iter_mut()
.for_each(|e| e.error(&self.context(), err));
}
}
pub fn result(&self) -> Option<Value> {
if let Some(e) = &self.extensions {
let value = e let value = e
.lock() .lock()
.iter_mut() .iter_mut()
.filter_map(|e| { .filter_map(|e| {
if let Some(name) = e.name() { if let Some(name) = e.name() {
e.result(ctx).map(|res| (Name::new(name), res)) e.result(&self.context()).map(|res| (Name::new(name), res))
} else { } else {
None None
} }

View File

@ -1,9 +1,10 @@
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::Arc;
use async_graphql_parser::types::ExecutableDocument; use async_graphql_parser::types::ExecutableDocument;
use async_graphql_value::Variables; use async_graphql_value::Variables;
use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer}; use opentelemetry::trace::{SpanKind, TraceContextExt, Tracer};
use opentelemetry::{Context as OpenTelemetryContext, Key, KeyValue}; use opentelemetry::{Context as OpenTelemetryContext, Key};
use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use crate::{ServerError, ValidationResult}; use crate::{ServerError, ValidationResult};
@ -27,42 +28,41 @@ const KEY_ERROR: Key = Key::from_static_str("graphql.error");
const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity"); const KEY_COMPLEXITY: Key = Key::from_static_str("graphql.complexity");
const KEY_DEPTH: Key = Key::from_static_str("graphql.depth"); const KEY_DEPTH: Key = Key::from_static_str("graphql.depth");
/// OpenTelemetry extension /// OpenTelemetry extension configuration for each request.
#[derive(Default)] #[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))] #[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T, U> { pub struct OpenTelemetryConfig {
tracer: T, /// Use a context as the parent node of the entire query.
#[allow(dead_code)] parent: spin::Mutex<Option<OpenTelemetryContext>>,
uninstall: U,
} }
impl<T> OpenTelemetry<T, ()> { impl OpenTelemetryConfig {
/// Use a context as the parent of the entire query.
pub fn parent_context(mut self, cx: OpenTelemetryContext) -> Self {
*self.parent.get_mut() = Some(cx);
self
}
}
/// OpenTelemetry extension
#[cfg_attr(docsrs, doc(cfg(feature = "opentelemetry")))]
pub struct OpenTelemetry<T> {
tracer: Arc<T>,
}
impl<T> OpenTelemetry<T> {
/// Use `tracer` to create an OpenTelemetry extension. /// Use `tracer` to create an OpenTelemetry extension.
pub fn new(tracer: T) -> OpenTelemetry<T, ()> pub fn new(tracer: T) -> OpenTelemetry<T>
where where
T: Tracer + Send + Sync + Clone, T: Tracer + Send + Sync,
{ {
Self { Self {
tracer, tracer: Arc::new(tracer),
uninstall: (),
}
}
/// Attach an Uninstall instance to this extension.
pub fn with_uninstall<U>(self, uninstall: U) -> OpenTelemetry<T, U>
where
U: Send + Sync + 'static,
{
OpenTelemetry {
tracer: self.tracer,
uninstall,
} }
} }
} }
impl<T: Tracer + Send + Sync + Clone, U: Send + Sync + 'static> ExtensionFactory impl<T: Tracer + Send + Sync> ExtensionFactory for OpenTelemetry<T> {
for OpenTelemetry<T, U>
{
fn create(&self) -> Box<dyn Extension> { fn create(&self) -> Box<dyn Extension> {
Box::new(OpenTelemetryExtension { Box::new(OpenTelemetryExtension {
tracer: self.tracer.clone(), tracer: self.tracer.clone(),
@ -72,7 +72,7 @@ impl<T: Tracer + Send + Sync + Clone, U: Send + Sync + 'static> ExtensionFactory
} }
struct OpenTelemetryExtension<T> { struct OpenTelemetryExtension<T> {
tracer: T, tracer: Arc<T>,
contexts: HashMap<usize, OpenTelemetryContext>, contexts: HashMap<usize, OpenTelemetryContext>,
} }
@ -92,36 +92,46 @@ impl<T> OpenTelemetryExtension<T> {
} }
} }
impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> { impl<T: Tracer + Send + Sync> Extension for OpenTelemetryExtension<T> {
fn start(&mut self, ctx: &ExtensionContext<'_>) {
let request_cx = ctx
.data_opt::<OpenTelemetryConfig>()
.and_then(|cfg| cfg.parent.lock().take())
.unwrap_or_else(|| {
OpenTelemetryContext::current_with_span(
self.tracer
.span_builder("request")
.with_kind(SpanKind::Server)
.start(&*self.tracer),
)
});
self.enter_context(REQUEST_CTX, request_cx);
}
fn end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_context(REQUEST_CTX);
}
fn parse_start( fn parse_start(
&mut self, &mut self,
_ctx: &ExtensionContext<'_>, _ctx: &ExtensionContext<'_>,
query_source: &str, query_source: &str,
variables: &Variables, variables: &Variables,
) { ) {
let request_span = self if let Some(parent_cx) = self.contexts.get(&REQUEST_CTX).cloned() {
.tracer let mut attributes = Vec::with_capacity(2);
.span_builder("request") attributes.push(KEY_SOURCE.string(query_source.to_string()));
.with_kind(SpanKind::Server) attributes.push(KEY_VARIABLES.string(serde_json::to_string(variables).unwrap()));
.start(&self.tracer); let parse_span = self
let request_cx = OpenTelemetryContext::current_with_span(request_span); .tracer
self.enter_context(REQUEST_CTX, request_cx.clone()); .span_builder("parse")
.with_kind(SpanKind::Server)
let mut attributes = Vec::with_capacity(2); .with_attributes(attributes)
attributes.push(KeyValue::new(KEY_SOURCE, query_source.to_string())); .with_parent_context(parent_cx)
attributes.push(KeyValue::new( .start(&*self.tracer);
KEY_VARIABLES, let parse_cx = OpenTelemetryContext::current_with_span(parse_span);
serde_json::to_string(variables).unwrap(), self.enter_context(PARSE_CTX, parse_cx);
)); }
let parse_span = self
.tracer
.span_builder("parse")
.with_kind(SpanKind::Server)
.with_attributes(attributes)
.with_parent_context(request_cx)
.start(&self.tracer);
let parse_cx = OpenTelemetryContext::current_with_span(parse_span);
self.enter_context(PARSE_CTX, parse_cx);
} }
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) {
@ -135,7 +145,7 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
.span_builder("validation") .span_builder("validation")
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.with_parent_context(parent_cx) .with_parent_context(parent_cx)
.start(&self.tracer); .start(&*self.tracer);
let validation_cx = OpenTelemetryContext::current_with_span(span); let validation_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(VALIDATION_CTX, validation_cx); self.enter_context(VALIDATION_CTX, validation_cx);
} }
@ -144,8 +154,8 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) {
if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) { if let Some(validation_cx) = self.exit_context(VALIDATION_CTX) {
let span = validation_cx.span(); let span = validation_cx.span();
span.set_attribute(KeyValue::new(KEY_COMPLEXITY, result.complexity as i64)); span.set_attribute(KEY_COMPLEXITY.i64(result.complexity as i64));
span.set_attribute(KeyValue::new(KEY_DEPTH, result.depth as i64)); span.set_attribute(KEY_DEPTH.i64(result.depth as i64));
} }
} }
@ -156,12 +166,12 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
.span_builder("execute") .span_builder("execute")
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.with_parent_context(parent_cx) .with_parent_context(parent_cx)
.start(&self.tracer), .start(&*self.tracer),
None => self None => self
.tracer .tracer
.span_builder("execute") .span_builder("execute")
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.start(&self.tracer), .start(&*self.tracer),
}; };
let execute_cx = OpenTelemetryContext::current_with_span(span); let execute_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(EXECUTE_CTX, execute_cx); self.enter_context(EXECUTE_CTX, execute_cx);
@ -169,7 +179,6 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_context(EXECUTE_CTX); self.exit_context(EXECUTE_CTX);
self.exit_context(REQUEST_CTX);
} }
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
@ -181,19 +190,16 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
if let Some(parent_cx) = parent_cx { if let Some(parent_cx) = parent_cx {
let mut attributes = Vec::with_capacity(3); let mut attributes = Vec::with_capacity(3);
attributes.push(KeyValue::new( attributes.push(KEY_RESOLVE_ID.i64(info.resolve_id.current as i64));
KEY_RESOLVE_ID, attributes.push(KEY_PARENT_TYPE.string(info.parent_type.to_string()));
info.resolve_id.current as i64, attributes.push(KEY_RETURN_TYPE.string(info.return_type.to_string()));
));
attributes.push(KeyValue::new(KEY_PARENT_TYPE, info.parent_type.to_string()));
attributes.push(KeyValue::new(KEY_RETURN_TYPE, info.return_type.to_string()));
let span = self let span = self
.tracer .tracer
.span_builder(&info.path_node.to_string()) .span_builder(&info.path_node.to_string())
.with_kind(SpanKind::Server) .with_kind(SpanKind::Server)
.with_parent_context(parent_cx) .with_parent_context(parent_cx)
.with_attributes(attributes) .with_attributes(attributes)
.start(&self.tracer); .start(&*self.tracer);
let resolve_cx = OpenTelemetryContext::current_with_span(span); let resolve_cx = OpenTelemetryContext::current_with_span(span);
self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx); self.enter_context(resolve_ctx_id(info.resolve_id.current), resolve_cx);
} }
@ -205,10 +211,9 @@ impl<T: Tracer + Send + Sync + Clone> Extension for OpenTelemetryExtension<T> {
fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) { fn error(&mut self, _ctx: &ExtensionContext<'_>, err: &ServerError) {
if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() { if let Some(parent_cx) = self.contexts.get(&EXECUTE_CTX).cloned() {
parent_cx.span().add_event( parent_cx
"error".to_string(), .span()
vec![KeyValue::new(KEY_ERROR, err.to_string())], .add_event("error".to_string(), vec![KEY_ERROR.string(err.to_string())]);
);
} }
} }
} }

View File

@ -6,6 +6,22 @@ use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveIn
use crate::parser::types::ExecutableDocument; use crate::parser::types::ExecutableDocument;
use crate::{ServerError, ValidationResult, Variables}; use crate::{ServerError, ValidationResult, Variables};
/// Tracing extension configuration for each request.
#[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
pub struct TracingConfig {
/// Use a span as the parent node of the entire query.
parent: spin::Mutex<Option<Span>>,
}
impl TracingConfig {
/// Use a span as the parent of the entire query.
pub fn parent_span(mut self, span: Span) -> Self {
*self.parent.get_mut() = Some(span);
self
}
}
const REQUEST_CTX: usize = 0; const REQUEST_CTX: usize = 0;
const PARSE_CTX: usize = 1; const PARSE_CTX: usize = 1;
const VALIDATION_CTX: usize = 2; const VALIDATION_CTX: usize = 2;
@ -26,7 +42,7 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// ///
/// ```no_run /// ```no_run
/// use async_graphql::*; /// use async_graphql::*;
/// use async_graphql::extensions::Tracing; /// use async_graphql::extensions::{Tracing, TracingConfig};
/// use tracing::{span, Level, Instrument}; /// use tracing::{span, Level, Instrument};
/// ///
/// #[derive(SimpleObject)] /// #[derive(SimpleObject)]
@ -42,7 +58,7 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// schema.execute(Request::new("{ value }")).await; /// schema.execute(Request::new("{ value }")).await;
/// }); /// });
/// ///
/// // tracing in custom parent span /// // tracing in parent span
/// tokio::runtime::Runtime::new().unwrap().block_on(async { /// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!( /// let root_span = span!(
/// parent: None, /// parent: None,
@ -51,6 +67,16 @@ fn resolve_span_id(resolver_id: usize) -> usize {
/// ); /// );
/// schema.execute(Request::new("{ value }")).instrument(root_span).await; /// schema.execute(Request::new("{ value }")).instrument(root_span).await;
/// }); /// });
///
/// // replace root span
/// tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let root_span = span!(
/// parent: None,
/// Level::INFO,
/// "span root"
/// );
/// schema.execute(Request::new("{ value }").data(TracingConfig::default().parent_span(root_span))).await;
/// });
/// ``` /// ```
#[derive(Default)] #[derive(Default)]
#[cfg_attr(docsrs, doc(cfg(feature = "tracing")))] #[cfg_attr(docsrs, doc(cfg(feature = "tracing")))]
@ -82,30 +108,42 @@ impl TracingExtension {
} }
impl Extension for TracingExtension { impl Extension for TracingExtension {
fn start(&mut self, ctx: &ExtensionContext<'_>) {
let request_span = ctx
.data_opt::<TracingConfig>()
.and_then(|cfg| cfg.parent.lock().take())
.unwrap_or_else(|| {
span!(
target: "async_graphql::graphql",
Level::INFO,
"request",
)
});
self.enter_span(REQUEST_CTX, request_span);
}
fn end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_span(REQUEST_CTX);
}
fn parse_start( fn parse_start(
&mut self, &mut self,
_ctx: &ExtensionContext<'_>, _ctx: &ExtensionContext<'_>,
query_source: &str, query_source: &str,
variables: &Variables, variables: &Variables,
) { ) {
let request_span = span!( if let Some(parent) = self.spans.get(&REQUEST_CTX) {
target: "async_graphql::graphql", let variables = serde_json::to_string(&variables).unwrap();
Level::INFO, let parse_span = span!(
"request", target: "async_graphql::graphql",
); parent: parent,
Level::INFO,
let variables = serde_json::to_string(&variables).unwrap(); "parse",
let parse_span = span!( source = query_source,
target: "async_graphql::graphql", variables = %variables,
parent: &request_span, );
Level::INFO, self.enter_span(PARSE_CTX, parse_span);
"parse", }
source = query_source,
variables = %variables,
);
self.enter_span(REQUEST_CTX, request_span);
self.enter_span(PARSE_CTX, parse_span);
} }
fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) { fn parse_end(&mut self, _ctx: &ExtensionContext<'_>, _document: &ExecutableDocument) {
@ -129,27 +167,19 @@ impl Extension for TracingExtension {
} }
fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) { fn execution_start(&mut self, _ctx: &ExtensionContext<'_>) {
let span = match self.spans.get(&REQUEST_CTX) { if let Some(parent) = self.spans.get(&REQUEST_CTX) {
Some(parent) => span!( let span = span!(
target: "async_graphql::graphql", target: "async_graphql::graphql",
parent: parent, parent: parent,
Level::INFO, Level::INFO,
"execute" "execute"
), );
None => span!( self.enter_span(EXECUTE_CTX, span);
target: "async_graphql::graphql",
parent: None,
Level::INFO,
"execute"
),
}; };
self.enter_span(EXECUTE_CTX, span);
} }
fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) { fn execution_end(&mut self, _ctx: &ExtensionContext<'_>) {
self.exit_span(EXECUTE_CTX); self.exit_span(EXECUTE_CTX);
self.exit_span(REQUEST_CTX);
} }
fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) { fn resolve_start(&mut self, _ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {

View File

@ -177,7 +177,9 @@ where
if let Some(data) = this.data.clone() { if let Some(data) = this.data.clone() {
this.streams.insert( this.streams.insert(
id, id,
Box::pin(this.schema.execute_stream_with_ctx_data(request, data)), Box::pin(
this.schema.execute_stream_with_session_data(request, data),
),
); );
} else { } else {
return Poll::Ready(Some(WsMessage::Close( return Poll::Ready(Some(WsMessage::Close(

View File

@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use crate::extensions::{ErrorLogger, ExtensionContext, ResolveInfo}; use crate::extensions::{ErrorLogger, ResolveInfo};
use crate::parser::types::Selection; use crate::parser::types::Selection;
use crate::registry::MetaType; use crate::registry::MetaType;
use crate::{ use crate::{
@ -183,11 +183,6 @@ impl<'a> Fields<'a> {
} }
}? }?
} else { } else {
let ctx_extension = ExtensionContext {
schema_data: &ctx.schema_env.data,
query_data: &ctx.query_env.ctx_data,
};
let type_name = T::type_name(); let type_name = T::type_name();
let resolve_info = ResolveInfo { let resolve_info = ResolveInfo {
resolve_id: ctx_field.resolve_id, resolve_id: ctx_field.resolve_id,
@ -215,10 +210,7 @@ impl<'a> Fields<'a> {
}, },
}; };
ctx_field ctx_field.query_env.extensions.resolve_start(&resolve_info);
.query_env
.extensions
.resolve_start(&ctx_extension, &resolve_info);
let res = match root.resolve_field(&ctx_field).await { let res = match root.resolve_field(&ctx_field).await {
Ok(value) => Ok((field_name, value.unwrap_or_default())), Ok(value) => Ok((field_name, value.unwrap_or_default())),
@ -226,13 +218,8 @@ impl<'a> Fields<'a> {
Err(e.path(PathSegment::Field(field_name.to_string()))) Err(e.path(PathSegment::Field(field_name.to_string())))
} }
} }
.log_error(&ctx_extension, &ctx_field.query_env.extensions)?; .log_error(&ctx_field.query_env.extensions)?;
ctx_field.query_env.extensions.resolve_end(&resolve_info);
ctx_field
.query_env
.extensions
.resolve_end(&ctx_extension, &resolve_info);
res res
}; };

View File

@ -1,4 +1,4 @@
use crate::extensions::{ErrorLogger, ExtensionContext, ResolveInfo}; use crate::extensions::{ErrorLogger, ResolveInfo};
use crate::parser::types::Field; use crate::parser::types::Field;
use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value}; use crate::{ContextSelectionSet, OutputType, PathSegment, Positioned, ServerResult, Type, Value};
@ -14,16 +14,11 @@ pub async fn resolve_list<'a, T: OutputType + 'a>(
for (idx, item) in iter.into_iter().enumerate() { for (idx, item) in iter.into_iter().enumerate() {
let ctx_idx = ctx.with_index(idx); let ctx_idx = ctx.with_index(idx);
futures.push(async move { futures.push(async move {
let ctx_extension = ExtensionContext {
schema_data: &ctx.schema_env.data,
query_data: &ctx.query_env.ctx_data,
};
if ctx_idx.query_env.extensions.is_empty() { if ctx_idx.query_env.extensions.is_empty() {
OutputType::resolve(&item, &ctx_idx, field) OutputType::resolve(&item, &ctx_idx, field)
.await .await
.map_err(|e| e.path(PathSegment::Index(idx))) .map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_extension, &ctx_idx.query_env.extensions) .log_error(&ctx_idx.query_env.extensions)
} else { } else {
let resolve_info = ResolveInfo { let resolve_info = ResolveInfo {
resolve_id: ctx_idx.resolve_id, resolve_id: ctx_idx.resolve_id,
@ -32,21 +27,14 @@ pub async fn resolve_list<'a, T: OutputType + 'a>(
return_type: &T::qualified_type_name(), return_type: &T::qualified_type_name(),
}; };
ctx_idx ctx_idx.query_env.extensions.resolve_start(&resolve_info);
.query_env
.extensions
.resolve_start(&ctx_extension, &resolve_info);
let res = OutputType::resolve(&item, &ctx_idx, field) let res = OutputType::resolve(&item, &ctx_idx, field)
.await .await
.map_err(|e| e.path(PathSegment::Index(idx))) .map_err(|e| e.path(PathSegment::Index(idx)))
.log_error(&ctx_extension, &ctx_idx.query_env.extensions)?; .log_error(&ctx_idx.query_env.extensions)?;
ctx_idx
.query_env
.extensions
.resolve_end(&ctx_extension, &resolve_info);
ctx_idx.query_env.extensions.resolve_end(&resolve_info);
Ok(res) Ok(res)
} }
}); });

View File

@ -8,7 +8,7 @@ use futures_util::stream::{self, Stream, StreamExt};
use indexmap::map::IndexMap; use indexmap::map::IndexMap;
use crate::context::{Data, QueryEnvInner, ResolveId}; use crate::context::{Data, QueryEnvInner, ResolveId};
use crate::extensions::{ErrorLogger, ExtensionContext, ExtensionFactory, Extensions}; use crate::extensions::{ErrorLogger, ExtensionFactory, Extensions};
use crate::model::__DirectiveLocation; use crate::model::__DirectiveLocation;
use crate::parser::parse_query; use crate::parser::parse_query;
use crate::parser::types::{DocumentOperations, OperationType}; use crate::parser::types::{DocumentOperations, OperationType};
@ -359,57 +359,59 @@ where
self.0.env.registry.names() self.0.env.registry.names()
} }
fn create_extensions(&self, session_data: Arc<Data>) -> Extensions {
let extensions = Extensions::new(
self.0
.extensions
.iter()
.map(|factory| factory.create())
.collect::<Vec<_>>(),
self.env.clone(),
session_data,
);
extensions.start();
extensions
}
async fn prepare_request( async fn prepare_request(
&self, &self,
mut extensions: Extensions,
request: Request, request: Request,
) -> Result<(QueryEnvInner, CacheControl), Vec<ServerError>> { session_data: Arc<Data>,
// create extension instances ) -> Result<(QueryEnv, CacheControl), Vec<ServerError>> {
let mut extensions: Extensions = self
.0
.extensions
.iter()
.map(|factory| factory.create())
.collect::<Vec<_>>()
.into();
let mut request = request; let mut request = request;
let data = std::mem::take(&mut request.data); let query_data = Arc::new(std::mem::take(&mut request.data));
let ctx_extension = ExtensionContext { extensions.attach_query_data(query_data.clone());
schema_data: &self.env.data, let request = extensions.prepare_request(request).await?;
query_data: &data,
};
let request = extensions.prepare_request(&ctx_extension, request).await?; extensions.parse_start(&request.query, &request.variables);
extensions.parse_start(&ctx_extension, &request.query, &request.variables);
let document = parse_query(&request.query) let document = parse_query(&request.query)
.map_err(Into::<ServerError>::into) .map_err(Into::<ServerError>::into)
.log_error(&ctx_extension, &extensions)?; .log_error(&extensions)?;
extensions.parse_end(&ctx_extension, &document); extensions.parse_end(&document);
// check rules // check rules
extensions.validation_start(&ctx_extension); extensions.validation_start();
let validation_result = check_rules( let validation_result = check_rules(
&self.env.registry, &self.env.registry,
&document, &document,
Some(&request.variables), Some(&request.variables),
self.validation_mode, self.validation_mode,
) )
.log_error(&ctx_extension, &extensions)?; .log_error(&extensions)?;
extensions.validation_end(&ctx_extension, &validation_result); extensions.validation_end(&validation_result);
// check limit // check limit
if let Some(limit_complexity) = self.complexity { if let Some(limit_complexity) = self.complexity {
if validation_result.complexity > limit_complexity { if validation_result.complexity > limit_complexity {
return Err(vec![ServerError::new("Query is too complex.")]) return Err(vec![ServerError::new("Query is too complex.")]).log_error(&extensions);
.log_error(&ctx_extension, &extensions);
} }
} }
if let Some(limit_depth) = self.depth { if let Some(limit_depth) = self.depth {
if validation_result.depth > limit_depth { if validation_result.depth > limit_depth {
return Err(vec![ServerError::new("Query is nested too deep.")]) return Err(vec![ServerError::new("Query is nested too deep.")])
.log_error(&ctx_extension, &extensions); .log_error(&extensions);
} }
} }
@ -437,7 +439,7 @@ where
let operation = match operation { let operation = match operation {
Ok(operation) => operation, Ok(operation) => operation,
Err(e) => { Err(e) => {
extensions.error(&ctx_extension, &e); extensions.error(&e);
return Err(vec![e]); return Err(vec![e]);
} }
}; };
@ -448,10 +450,11 @@ where
operation, operation,
fragments: document.fragments, fragments: document.fragments,
uploads: request.uploads, uploads: request.uploads,
ctx_data: Arc::new(data), session_data,
ctx_data: query_data,
http_headers: Default::default(), http_headers: Default::default(),
}; };
Ok((env, validation_result.cache_control)) Ok((QueryEnv::new(env), validation_result.cache_control))
} }
async fn execute_once(&self, env: QueryEnv) -> Response { async fn execute_once(&self, env: QueryEnv) -> Response {
@ -465,12 +468,7 @@ where
schema_env: &self.env, schema_env: &self.env,
query_env: &env, query_env: &env,
}; };
let ctx_extension = ExtensionContext { env.extensions.execution_start();
schema_data: &self.env.data,
query_data: &env.ctx_data,
};
env.extensions.execution_start(&ctx_extension);
let data = match &env.operation.node.ty { let data = match &env.operation.node.ty {
OperationType::Query => resolve_container(&ctx, &self.query).await, OperationType::Query => resolve_container(&ctx, &self.query).await,
@ -482,8 +480,8 @@ where
} }
}; };
env.extensions.execution_end(&ctx_extension); env.extensions.execution_end();
let extensions = env.extensions.result(&ctx_extension); let extensions = env.extensions.result();
match data { match data {
Ok(data) => Response::new(data), Ok(data) => Response::new(data),
@ -496,11 +494,12 @@ where
/// Execute a GraphQL query. /// Execute a GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response { pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into(); let request = request.into();
match self.prepare_request(request).await { let extensions = self.create_extensions(Default::default());
Ok((env, cache_control)) => self match self
.execute_once(QueryEnv::new(env)) .prepare_request(extensions, request, Default::default())
.await .await
.cache_control(cache_control), {
Ok((env, cache_control)) => self.execute_once(env).await.cache_control(cache_control),
Err(errors) => Response::from_errors(errors), Err(errors) => Response::from_errors(errors),
} }
} }
@ -518,24 +517,25 @@ where
} }
} }
pub(crate) fn execute_stream_with_ctx_data( /// Execute a GraphQL subscription with session data.
#[doc(hidden)]
pub fn execute_stream_with_session_data(
&self, &self,
request: impl Into<Request> + Send, request: impl Into<Request> + Send,
ctx_data: Arc<Data>, session_data: Arc<Data>,
) -> impl Stream<Item = Response> + Send { ) -> impl Stream<Item = Response> + Send {
let schema = self.clone(); let schema = self.clone();
let request = request.into();
let extensions = self.create_extensions(session_data.clone());
async_stream::stream! { async_stream::stream! {
let request = request.into(); let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await {
let (mut env, cache_control) = match schema.prepare_request(request).await {
Ok(res) => res, Ok(res) => res,
Err(errors) => { Err(errors) => {
yield Response::from_errors(errors); yield Response::from_errors(errors);
return; return;
} }
}; };
env.ctx_data = ctx_data;
let env = QueryEnv::new(env);
if env.operation.node.ty != OperationType::Subscription { if env.operation.node.ty != OperationType::Subscription {
yield schema yield schema
@ -553,26 +553,18 @@ where
ResolveId::root(), ResolveId::root(),
&resolve_id, &resolve_id,
); );
let ctx_extension = ExtensionContext {
schema_data: &schema.env.data,
query_data: &env.ctx_data,
};
env.extensions.execution_start(&ctx_extension);
let mut streams = Vec::new(); let mut streams = Vec::new();
if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) { if let Err(e) = collect_subscription_streams(&ctx, &schema.subscription, &mut streams) {
env.extensions.execution_end(&ctx_extension); env.extensions.execution_end();
yield Response::from_errors(vec![e]); yield Response::from_errors(vec![e]);
return; return;
} }
env.extensions.execution_end(&ctx_extension);
let mut stream = stream::select_all(streams); let mut stream = stream::select_all(streams);
while let Some(data) = stream.next().await { while let Some(data) = stream.next().await {
let is_err = data.is_err(); let is_err = data.is_err();
let extensions = env.extensions.result(&ctx_extension); let extensions = env.extensions.result();
yield match data { yield match data {
Ok((name, value)) => { Ok((name, value)) => {
let mut map = BTreeMap::new(); let mut map = BTreeMap::new();
@ -593,8 +585,6 @@ where
&self, &self,
request: impl Into<Request>, request: impl Into<Request>,
) -> impl Stream<Item = Response> + Send { ) -> impl Stream<Item = Response> + Send {
let mut request = request.into(); self.execute_stream_with_session_data(request.into(), Default::default())
let ctx_data = std::mem::take(&mut request.data);
self.execute_stream_with_ctx_data(request, Arc::new(ctx_data))
} }
} }

View File

@ -38,13 +38,7 @@ pub(crate) fn collect_subscription_streams<'a, T: SubscriptionType + 'static>(
let ctx = ctx.clone(); let ctx = ctx.clone();
async_stream::stream! { async_stream::stream! {
let ctx = ctx.with_field(field); let ctx = ctx.with_field(field);
let field_name = ctx let field_name = ctx.item.node.response_key().node.clone();
.item
.node
.response_key()
.node
.clone();
let stream = root.create_field_stream(&ctx); let stream = root.create_field_stream(&ctx);
if let Some(mut stream) = stream { if let Some(mut stream) = stream {
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {

View File

@ -1,8 +1,13 @@
use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory};
use async_graphql::*;
use spin::Mutex;
use std::sync::Arc; use std::sync::Arc;
use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo};
use async_graphql::parser::types::ExecutableDocument;
use async_graphql::*;
use async_graphql_value::ConstValue;
use futures_util::stream::Stream;
use futures_util::StreamExt;
use spin::Mutex;
#[tokio::test] #[tokio::test]
pub async fn test_extension_ctx() { pub async fn test_extension_ctx() {
#[derive(Default, Clone)] #[derive(Default, Clone)]
@ -12,8 +17,18 @@ pub async fn test_extension_ctx() {
#[Object] #[Object]
impl Query { impl Query {
async fn value(&self) -> bool { async fn value(&self, ctx: &Context<'_>) -> i32 {
true *ctx.data_unchecked::<MyData>().0.lock()
}
}
struct Subscription;
#[Subscription]
impl Subscription {
async fn value(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let data = *ctx.data_unchecked::<MyData>().0.lock();
futures_util::stream::once(async move { data })
} }
} }
@ -27,7 +42,9 @@ pub async fn test_extension_ctx() {
_query_source: &str, _query_source: &str,
_variables: &Variables, _variables: &Variables,
) { ) {
*ctx.data_unchecked::<MyData>().0.lock() = 100; if let Ok(data) = ctx.data::<MyData>() {
*data.0.lock() = 100;
}
} }
} }
@ -46,9 +63,17 @@ pub async fn test_extension_ctx() {
.data(data.clone()) .data(data.clone())
.extension(MyExtension) .extension(MyExtension)
.finish(); .finish();
assert_eq!(
schema.execute("{ value }").await.into_result().unwrap(); schema
assert_eq!(*data.0.lock(), 100); .execute("{ value }")
.await
.into_result()
.unwrap()
.data,
value! ({
"value": 100
})
);
} }
// data in request // data in request
@ -58,11 +83,219 @@ pub async fn test_extension_ctx() {
.extension(MyExtension) .extension(MyExtension)
.finish(); .finish();
schema assert_eq!(
.execute(Request::new("{ value }").data(data.clone())) schema
.execute(Request::new("{ value }").data(data.clone()))
.await
.into_result()
.unwrap()
.data,
value! ({
"value": 100
})
);
}
// data in session
{
let schema = Schema::build(Query, EmptyMutation, Subscription)
.extension(MyExtension)
.finish();
let mut data = Data::default();
data.insert(MyData::default());
let mut stream = schema
.execute_stream_with_session_data(
Request::new("subscription { value }"),
Arc::new(data),
)
.boxed();
assert_eq!(
stream.next().await.unwrap().into_result().unwrap().data,
value! ({
"value": 100
})
);
}
}
#[tokio::test]
pub async fn test_extension_call_order() {
struct MyExtensionImpl {
calls: Arc<Mutex<Vec<&'static str>>>,
}
#[async_trait::async_trait]
#[allow(unused_variables)]
impl Extension for MyExtensionImpl {
fn name(&self) -> Option<&'static str> {
Some("test")
}
fn start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("start");
}
fn end(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("end");
}
async fn prepare_request(
&mut self,
ctx: &ExtensionContext<'_>,
request: Request,
) -> ServerResult<Request> {
self.calls.lock().push("prepare_request");
Ok(request)
}
fn parse_start(
&mut self,
ctx: &ExtensionContext<'_>,
query_source: &str,
variables: &Variables,
) {
self.calls.lock().push("parse_start");
}
fn parse_end(&mut self, ctx: &ExtensionContext<'_>, document: &ExecutableDocument) {
self.calls.lock().push("parse_end");
}
fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("validation_start");
}
fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {
self.calls.lock().push("validation_end");
}
fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("execution_start");
}
fn execution_end(&mut self, ctx: &ExtensionContext<'_>) {
self.calls.lock().push("execution_end");
}
fn resolve_start(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_start");
}
fn resolve_end(&mut self, ctx: &ExtensionContext<'_>, info: &ResolveInfo<'_>) {
self.calls.lock().push("resolve_end");
}
fn result(&mut self, ctx: &ExtensionContext<'_>) -> Option<ConstValue> {
self.calls.lock().push("result");
None
}
}
struct MyExtension {
calls: Arc<Mutex<Vec<&'static str>>>,
}
impl ExtensionFactory for MyExtension {
fn create(&self) -> Box<dyn Extension> {
Box::new(MyExtensionImpl {
calls: self.calls.clone(),
})
}
}
struct Query;
#[Object]
impl Query {
async fn value1(&self) -> i32 {
10
}
async fn value2(&self) -> i32 {
10
}
}
struct Subscription;
#[Subscription]
impl Subscription {
async fn value(&self) -> impl Stream<Item = i32> {
futures_util::stream::iter(vec![1, 2, 3])
}
}
{
let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
let schema = Schema::build(Query, EmptyMutation, Subscription)
.extension(MyExtension {
calls: calls.clone(),
})
.finish();
let _ = schema
.execute("{ value1 value2 }")
.await .await
.into_result() .into_result()
.unwrap(); .unwrap();
assert_eq!(*data.0.lock(), 100); let calls = calls.lock();
assert_eq!(
&*calls,
&vec![
"start",
"prepare_request",
"parse_start",
"parse_end",
"validation_start",
"validation_end",
"execution_start",
"resolve_start",
"resolve_end",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"end",
]
);
}
{
let calls: Arc<Mutex<Vec<&'static str>>> = Default::default();
let schema = Schema::build(Query, EmptyMutation, Subscription)
.extension(MyExtension {
calls: calls.clone(),
})
.finish();
let mut stream = schema.execute_stream("subscription { value }").boxed();
while let Some(_) = stream.next().await {}
let calls = calls.lock();
assert_eq!(
&*calls,
&vec![
"start",
"prepare_request",
"parse_start",
"parse_end",
"validation_start",
"validation_end",
"execution_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"execution_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"execution_start",
"resolve_start",
"resolve_end",
"execution_end",
"result",
"end",
]
);
} }
} }