diff --git a/CHANGELOG.md b/CHANGELOG.md index fa8fab37..e59d82e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed -- Rework `Extension`, now fully supports asynchronous, better to use than before, and can achieve more features. - Because it contains a lot of changes _(if you don't have a custom extension, it will not cause the existing code to fail to compile)_, the major version will be updated to `3.0.0`. +- Rework `Extension`, now fully supports asynchronous, better to use than before, and can achieve more features, it contains a lot of changes. _(if you don't have a custom extension, it will not cause the existing code to fail to compile)_ ## [2.7.4] 2021-04-02 diff --git a/src/extensions/analyzer.rs b/src/extensions/analyzer.rs index 409a5e47..dc05e458 100644 --- a/src/extensions/analyzer.rs +++ b/src/extensions/analyzer.rs @@ -2,7 +2,9 @@ use std::sync::Arc; use futures_util::lock::Mutex; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension}; +use crate::extensions::{ + Extension, ExtensionContext, ExtensionFactory, NextRequest, NextValidation, +}; use crate::{value, Response, ServerError, ValidationResult}; /// Analyzer extension @@ -23,8 +25,8 @@ struct AnalyzerExtension { #[async_trait::async_trait] impl Extension for AnalyzerExtension { - async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { - let mut resp = next.request(ctx).await; + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { + let mut resp = next.run(ctx).await; let validation_result = self.validation_result.lock().await.take(); if let Some(validation_result) = validation_result { resp = resp.extension( @@ -41,9 +43,9 @@ impl Extension for AnalyzerExtension { async fn validation( &self, ctx: &ExtensionContext<'_>, - next: NextExtension<'_>, + next: NextValidation<'_>, ) -> Result> { - let res = next.validation(ctx).await?; + let res = next.run(ctx).await?; *self.validation_result.lock().await = Some(res); Ok(res) } diff --git a/src/extensions/apollo_persisted_queries.rs b/src/extensions/apollo_persisted_queries.rs index 98d60c94..d892e70f 100644 --- a/src/extensions/apollo_persisted_queries.rs +++ b/src/extensions/apollo_persisted_queries.rs @@ -6,7 +6,7 @@ use futures_util::lock::Mutex; use serde::Deserialize; use sha2::{Digest, Sha256}; -use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextExtension}; +use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest}; use crate::{from_value, Request, ServerError, ServerResult}; #[derive(Deserialize)] @@ -81,7 +81,7 @@ impl Extension for ApolloPersistedQueriesExtension { &self, ctx: &ExtensionContext<'_>, mut request: Request, - next: NextExtension<'_>, + next: NextPrepareRequest<'_>, ) -> ServerResult { let res = if let Some(value) = request.extensions.remove("persistedQuery") { let persisted_query: PersistedQuery = from_value(value).map_err(|_| { @@ -112,7 +112,7 @@ impl Extension for ApolloPersistedQueriesExtension { } else { Ok(request) }; - next.prepare_request(ctx, res?).await + next.run(ctx, res?).await } } diff --git a/src/extensions/apollo_tracing.rs b/src/extensions/apollo_tracing.rs index cb60fd97..0a075313 100644 --- a/src/extensions/apollo_tracing.rs +++ b/src/extensions/apollo_tracing.rs @@ -6,7 +6,7 @@ use serde::ser::SerializeMap; use serde::{Serialize, Serializer}; use crate::extensions::{ - Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, + Extension, ExtensionContext, ExtensionFactory, NextExecute, NextResolve, ResolveInfo, }; use crate::{value, Response, ServerResult, Value}; @@ -70,9 +70,9 @@ struct ApolloTracingExtension { #[async_trait::async_trait] impl Extension for ApolloTracingExtension { - async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExecute<'_>) -> Response { self.inner.lock().await.start_time = Utc::now(); - let resp = next.execute(ctx).await; + let resp = next.run(ctx).await; let mut inner = self.inner.lock().await; inner.end_time = Utc::now(); @@ -97,7 +97,7 @@ impl Extension for ApolloTracingExtension { &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { let path = info.path_node.to_string_vec(); let field_name = info.path_node.field_name().to_string(); @@ -108,7 +108,7 @@ impl Extension for ApolloTracingExtension { .num_nanoseconds() .unwrap(); - let res = next.resolve(ctx, info).await; + let res = next.run(ctx, info).await; let end_time = Utc::now(); self.inner.lock().await.resolves.push(ResolveState { diff --git a/src/extensions/logger.rs b/src/extensions/logger.rs index 53c12a0d..b10d7209 100644 --- a/src/extensions/logger.rs +++ b/src/extensions/logger.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use futures_util::lock::Mutex; use crate::extensions::{ - Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, + Extension, ExtensionContext, ExtensionFactory, NextParseQuery, NextResolve, ResolveInfo, }; use crate::parser::types::{ExecutableDocument, OperationType, Selection}; use crate::{PathSegment, ServerError, ServerResult, Value, Variables}; @@ -42,12 +42,12 @@ impl Extension for LoggerExtension { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { let mut inner = self.inner.lock().await; inner.query = query.replace(char::is_whitespace, ""); inner.variables = variables.clone(); - let document = next.parse_query(ctx, query, variables).await?; + let document = next.run(ctx, query, variables).await?; let is_schema = document .operations .iter() @@ -61,13 +61,13 @@ impl Extension for LoggerExtension { &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { let enabled = self.inner.lock().await.enabled; if enabled { let path = info.path_node.to_string(); log::trace!(target: "async-graphql", "[ResolveStart] path: \"{}\"", path); - let res = next.resolve(ctx, info).await; + let res = next.run(ctx, info).await; if let Err(err) = &res { let inner = self.inner.lock().await; log::error!( @@ -79,7 +79,7 @@ impl Extension for LoggerExtension { log::trace!(target: "async-graphql", "[ResolveEnd] path: \"{}\"", path); res } else { - next.resolve(ctx, info).await + next.run(ctx, info).await } } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index e9ac0cf4..a792b7cc 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -27,7 +27,6 @@ use std::future::Future; use std::sync::Arc; use futures_util::stream::BoxStream; -use futures_util::stream::StreamExt; use crate::parser::types::ExecutableDocument; use crate::{ @@ -107,173 +106,180 @@ type ExecuteFut<'a> = &'a mut (dyn Future + Send + Unpin); type ResolveFut<'a> = &'a mut (dyn Future>> + Send + Unpin); -/// The remainder of a extension chain. -pub struct NextExtension<'a> { +/// The remainder of a extension chain for request. +pub struct NextRequest<'a> { chain: &'a [Arc], - request_fut: Option>, - parse_query_fut: Option>, - validation_fut: Option>, - execute_fut: Option>, - resolve_fut: Option>, + request_fut: RequestFut<'a>, } -impl<'a> NextExtension<'a> { - #[inline] - pub(crate) fn new(chain: &'a [Arc]) -> Self { - Self { - chain, - request_fut: None, - parse_query_fut: None, - validation_fut: None, - execute_fut: None, - resolve_fut: None, - } - } - - #[inline] - pub(crate) fn with_chain(self, chain: &'a [Arc]) -> Self { - Self { chain, ..self } - } - - #[inline] - pub(crate) fn with_request(self, fut: RequestFut<'a>) -> Self { - Self { - request_fut: Some(fut), - ..self - } - } - - #[inline] - pub(crate) fn with_parse_query(self, fut: ParseFut<'a>) -> Self { - Self { - parse_query_fut: Some(fut), - ..self - } - } - - #[inline] - pub(crate) fn with_validation(self, fut: ValidationFut<'a>) -> Self { - Self { - validation_fut: Some(fut), - ..self - } - } - - #[inline] - pub(crate) fn with_execute(self, fut: ExecuteFut<'a>) -> Self { - Self { - execute_fut: Some(fut), - ..self - } - } - - #[inline] - pub(crate) fn with_resolve(self, fut: ResolveFut<'a>) -> Self { - Self { - resolve_fut: Some(fut), - ..self - } - } - +impl<'a> NextRequest<'a> { /// Call the [Extension::request] function of next extension. - pub async fn request(mut self, ctx: &ExtensionContext<'_>) -> Response { + pub async fn run(self, ctx: &ExtensionContext<'_>) -> Response { if let Some((first, next)) = self.chain.split_first() { - first.request(ctx, self.with_chain(next)).await - } else { - self.request_fut - .take() - .expect("You definitely called the wrong function.") + first + .request( + ctx, + NextRequest { + chain: next, + request_fut: self.request_fut, + }, + ) .await + } else { + self.request_fut.await } } +} +/// The remainder of a extension chain for subscribe. +pub struct NextSubscribe<'a> { + chain: &'a [Arc], +} + +impl<'a> NextSubscribe<'a> { /// Call the [Extension::subscribe] function of next extension. - pub fn subscribe<'s>( + pub fn run<'s>( self, ctx: &ExtensionContext<'_>, stream: BoxStream<'s, Response>, ) -> BoxStream<'s, Response> { if let Some((first, next)) = self.chain.split_first() { - first.subscribe(ctx, stream, self.with_chain(next)) + first.subscribe(ctx, stream, NextSubscribe { chain: next }) } else { stream } } +} +/// The remainder of a extension chain for subscribe. +pub struct NextPrepareRequest<'a> { + chain: &'a [Arc], +} + +impl<'a> NextPrepareRequest<'a> { /// Call the [Extension::prepare_request] function of next extension. - pub async fn prepare_request( - self, - ctx: &ExtensionContext<'_>, - request: Request, - ) -> ServerResult { + pub async fn run(self, ctx: &ExtensionContext<'_>, request: Request) -> ServerResult { if let Some((first, next)) = self.chain.split_first() { first - .prepare_request(ctx, request, self.with_chain(next)) + .prepare_request(ctx, request, NextPrepareRequest { chain: next }) .await } else { Ok(request) } } +} +/// The remainder of a extension chain for parse query. +pub struct NextParseQuery<'a> { + chain: &'a [Arc], + parse_query_fut: ParseFut<'a>, +} + +impl<'a> NextParseQuery<'a> { /// Call the [Extension::parse_query] function of next extension. - pub async fn parse_query( - mut self, + pub async fn run( + self, ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, ) -> ServerResult { if let Some((first, next)) = self.chain.split_first() { first - .parse_query(ctx, query, variables, self.with_chain(next)) + .parse_query( + ctx, + query, + variables, + NextParseQuery { + chain: next, + parse_query_fut: self.parse_query_fut, + }, + ) .await } else { - self.parse_query_fut - .take() - .expect("You definitely called the wrong function.") - .await + self.parse_query_fut.await } } +} +/// The remainder of a extension chain for validation. +pub struct NextValidation<'a> { + chain: &'a [Arc], + validation_fut: ValidationFut<'a>, +} + +impl<'a> NextValidation<'a> { /// Call the [Extension::validation] function of next extension. - pub async fn validation( - mut self, + pub async fn run( + self, ctx: &ExtensionContext<'_>, ) -> Result> { if let Some((first, next)) = self.chain.split_first() { - first.validation(ctx, self.with_chain(next)).await - } else { - self.validation_fut - .take() - .expect("You definitely called the wrong function.") + first + .validation( + ctx, + NextValidation { + chain: next, + validation_fut: self.validation_fut, + }, + ) .await + } else { + self.validation_fut.await } } +} +/// The remainder of a extension chain for execute. +pub struct NextExecute<'a> { + chain: &'a [Arc], + execute_fut: ExecuteFut<'a>, +} + +impl<'a> NextExecute<'a> { /// Call the [Extension::execute] function of next extension. - pub async fn execute(mut self, ctx: &ExtensionContext<'_>) -> Response { + pub async fn run(self, ctx: &ExtensionContext<'_>) -> Response { if let Some((first, next)) = self.chain.split_first() { - first.execute(ctx, self.with_chain(next)).await - } else { - self.execute_fut - .take() - .expect("You definitely called the wrong function.") + first + .execute( + ctx, + NextExecute { + chain: next, + execute_fut: self.execute_fut, + }, + ) .await + } else { + self.execute_fut.await } } +} +/// The remainder of a extension chain for resolve. +pub struct NextResolve<'a> { + chain: &'a [Arc], + resolve_fut: ResolveFut<'a>, +} + +impl<'a> NextResolve<'a> { /// Call the [Extension::resolve] function of next extension. - pub async fn resolve( - mut self, + pub async fn run( + self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, ) -> ServerResult> { if let Some((first, next)) = self.chain.split_first() { - first.resolve(ctx, info, self.with_chain(next)).await - } else { - self.resolve_fut - .take() - .expect("You definitely called the wrong function.") + first + .resolve( + ctx, + info, + NextResolve { + chain: next, + resolve_fut: self.resolve_fut, + }, + ) .await + } else { + self.resolve_fut.await } } } @@ -283,8 +289,8 @@ impl<'a> NextExtension<'a> { #[allow(unused_variables)] pub trait Extension: Sync + Send + 'static { /// Called at start query/mutation request. - async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { - next.request(ctx).await + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { + next.run(ctx).await } /// Called at subscribe request. @@ -292,9 +298,9 @@ pub trait Extension: Sync + Send + 'static { &self, ctx: &ExtensionContext<'_>, stream: BoxStream<'s, Response>, - next: NextExtension<'_>, + next: NextSubscribe<'_>, ) -> BoxStream<'s, Response> { - next.subscribe(ctx, stream) + next.run(ctx, stream) } /// Called at prepare request. @@ -302,9 +308,9 @@ pub trait Extension: Sync + Send + 'static { &self, ctx: &ExtensionContext<'_>, request: Request, - next: NextExtension<'_>, + next: NextPrepareRequest<'_>, ) -> ServerResult { - next.prepare_request(ctx, request).await + next.run(ctx, request).await } /// Called at parse query. @@ -313,23 +319,23 @@ pub trait Extension: Sync + Send + 'static { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { - next.parse_query(ctx, query, variables).await + next.run(ctx, query, variables).await } /// Called at validation query. async fn validation( &self, ctx: &ExtensionContext<'_>, - next: NextExtension<'_>, + next: NextValidation<'_>, ) -> Result> { - next.validation(ctx).await + next.run(ctx).await } /// Called at execute query. - async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { - next.execute(ctx).await + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExecute<'_>) -> Response { + next.run(ctx).await } /// Called at resolve field. @@ -337,9 +343,9 @@ pub trait Extension: Sync + Send + 'static { &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { - next.resolve(ctx, info).await + next.run(ctx, info).await } } @@ -375,6 +381,7 @@ impl Extensions { } } + #[inline] pub fn attach_query_data(&mut self, data: Arc) { self.query_data = Some(data); } @@ -394,30 +401,25 @@ impl Extensions { } pub async fn request(&self, request_fut: RequestFut<'_>) -> Response { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions).with_request(request_fut); - next.request(&self.create_context()).await - } else { - request_fut.await - } + let next = NextRequest { + chain: &self.extensions, + request_fut, + }; + next.run(&self.create_context()).await } pub fn subscribe<'s>(&self, stream: BoxStream<'s, Response>) -> BoxStream<'s, Response> { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions); - next.subscribe(&self.create_context(), stream) - } else { - stream.boxed() - } + let next = NextSubscribe { + chain: &self.extensions, + }; + next.run(&self.create_context(), stream) } pub async fn prepare_request(&self, request: Request) -> ServerResult { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions); - next.prepare_request(&self.create_context(), request).await - } else { - Ok(request) - } + let next = NextPrepareRequest { + chain: &self.extensions, + }; + next.run(&self.create_context(), request).await } pub async fn parse_query( @@ -426,34 +428,30 @@ impl Extensions { variables: &Variables, parse_query_fut: ParseFut<'_>, ) -> ServerResult { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions).with_parse_query(parse_query_fut); - next.parse_query(&self.create_context(), query, variables) - .await - } else { - parse_query_fut.await - } + let next = NextParseQuery { + chain: &self.extensions, + parse_query_fut, + }; + next.run(&self.create_context(), query, variables).await } pub async fn validation( &self, validation_fut: ValidationFut<'_>, ) -> Result> { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions).with_validation(validation_fut); - next.validation(&self.create_context()).await - } else { - validation_fut.await - } + let next = NextValidation { + chain: &self.extensions, + validation_fut, + }; + next.run(&self.create_context()).await } pub async fn execute(&self, execute_fut: ExecuteFut<'_>) -> Response { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions).with_execute(execute_fut); - next.execute(&self.create_context()).await - } else { - execute_fut.await - } + let next = NextExecute { + chain: &self.extensions, + execute_fut, + }; + next.run(&self.create_context()).await } pub async fn resolve( @@ -461,11 +459,10 @@ impl Extensions { info: ResolveInfo<'_>, resolve_fut: ResolveFut<'_>, ) -> ServerResult> { - if !self.extensions.is_empty() { - let next = NextExtension::new(&self.extensions).with_resolve(resolve_fut); - next.resolve(&self.create_context(), info).await - } else { - resolve_fut.await - } + let next = NextResolve { + chain: &self.extensions, + resolve_fut, + }; + next.run(&self.create_context(), info).await } } diff --git a/src/extensions/opentelemetry.rs b/src/extensions/opentelemetry.rs index 57be9917..61e8bcf2 100644 --- a/src/extensions/opentelemetry.rs +++ b/src/extensions/opentelemetry.rs @@ -8,7 +8,8 @@ use opentelemetry::trace::{FutureExt, SpanKind, TraceContextExt, Tracer}; use opentelemetry::{Context as OpenTelemetryContext, Key}; use crate::extensions::{ - Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, + Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextRequest, + NextResolve, NextSubscribe, NextValidation, ResolveInfo, }; use crate::{Response, ServerError, ServerResult, ValidationResult, Value}; @@ -52,8 +53,8 @@ struct OpenTelemetryExtension { #[async_trait::async_trait] impl Extension for OpenTelemetryExtension { - async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { - next.request(ctx) + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { + next.run(ctx) .with_context(OpenTelemetryContext::current_with_span( self.tracer .span_builder("request") @@ -67,10 +68,10 @@ impl Extension for OpenTelemetryExtension { &self, ctx: &ExtensionContext<'_>, stream: BoxStream<'s, Response>, - next: NextExtension<'_>, + next: NextSubscribe<'_>, ) -> BoxStream<'s, Response> { Box::pin( - next.subscribe(ctx, stream) + next.run(ctx, stream) .with_context(OpenTelemetryContext::current_with_span( self.tracer .span_builder("subscribe") @@ -85,7 +86,7 @@ impl Extension for OpenTelemetryExtension { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { let attributes = vec![ KEY_SOURCE.string(query.to_string()), @@ -97,7 +98,7 @@ impl Extension for OpenTelemetryExtension { .with_kind(SpanKind::Server) .with_attributes(attributes) .start(&*self.tracer); - next.parse_query(ctx, query, variables) + next.run(ctx, query, variables) .with_context(OpenTelemetryContext::current_with_span(span)) .await } @@ -105,14 +106,14 @@ impl Extension for OpenTelemetryExtension { async fn validation( &self, ctx: &ExtensionContext<'_>, - next: NextExtension<'_>, + next: NextValidation<'_>, ) -> Result> { let span = self .tracer .span_builder("validation") .with_kind(SpanKind::Server) .start(&*self.tracer); - next.validation(ctx) + next.run(ctx) .with_context(OpenTelemetryContext::current_with_span(span)) .map_ok(|res| { let current_cx = OpenTelemetryContext::current(); @@ -124,13 +125,13 @@ impl Extension for OpenTelemetryExtension { .await } - async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExecute<'_>) -> Response { let span = self .tracer .span_builder("execute") .with_kind(SpanKind::Server) .start(&*self.tracer); - next.execute(ctx) + next.run(ctx) .with_context(OpenTelemetryContext::current_with_span(span)) .await } @@ -139,7 +140,7 @@ impl Extension for OpenTelemetryExtension { &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { let attributes = vec![ KEY_PARENT_TYPE.string(info.parent_type.to_string()), @@ -151,7 +152,7 @@ impl Extension for OpenTelemetryExtension { .with_kind(SpanKind::Server) .with_attributes(attributes) .start(&*self.tracer); - next.resolve(ctx, info) + next.run(ctx, info) .with_context(OpenTelemetryContext::current_with_span(span)) .map_err(|err| { let current_cx = OpenTelemetryContext::current(); diff --git a/src/extensions/tracing.rs b/src/extensions/tracing.rs index 8eb35b38..666e58b9 100644 --- a/src/extensions/tracing.rs +++ b/src/extensions/tracing.rs @@ -6,7 +6,8 @@ use tracing_futures::Instrument; use tracinglib::{span, Level}; use crate::extensions::{ - Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, + Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextRequest, + NextResolve, NextSubscribe, NextValidation, ResolveInfo, }; use crate::parser::types::ExecutableDocument; use crate::{Response, ServerError, ServerResult, ValidationResult, Value, Variables}; @@ -50,8 +51,8 @@ struct TracingExtension; #[async_trait::async_trait] impl Extension for TracingExtension { - async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { - next.request(ctx) + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { + next.run(ctx) .instrument(span!( target: "async_graphql::graphql", Level::INFO, @@ -64,9 +65,9 @@ impl Extension for TracingExtension { &self, ctx: &ExtensionContext<'_>, stream: BoxStream<'s, Response>, - next: NextExtension<'_>, + next: NextSubscribe<'_>, ) -> BoxStream<'s, Response> { - Box::pin(next.subscribe(ctx, stream).instrument(span!( + Box::pin(next.run(ctx, stream).instrument(span!( target: "async_graphql::graphql", Level::INFO, "subscribe", @@ -78,7 +79,7 @@ impl Extension for TracingExtension { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { let span = span!( target: "async_graphql::graphql", @@ -87,38 +88,36 @@ impl Extension for TracingExtension { source = query, variables = %serde_json::to_string(&variables).unwrap(), ); - next.parse_query(ctx, query, variables) - .instrument(span) - .await + next.run(ctx, query, variables).instrument(span).await } async fn validation( &self, ctx: &ExtensionContext<'_>, - next: NextExtension<'_>, + next: NextValidation<'_>, ) -> Result> { let span = span!( target: "async_graphql::graphql", Level::INFO, "validation" ); - next.validation(ctx).instrument(span).await + next.run(ctx).instrument(span).await } - async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExecute<'_>) -> Response { let span = span!( target: "async_graphql::graphql", Level::INFO, "execute" ); - next.execute(ctx).instrument(span).await + next.run(ctx).instrument(span).await } async fn resolve( &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { let span = span!( target: "async_graphql::graphql", @@ -128,7 +127,7 @@ impl Extension for TracingExtension { parent_type = %info.parent_type, return_type = %info.return_type, ); - next.resolve(ctx, info) + next.run(ctx, info) .instrument(span) .map_err(|err| { tracinglib::error!(target: "async_graphql::graphql", error = %err.message); diff --git a/tests/extension.rs b/tests/extension.rs index 93007045..449331c1 100644 --- a/tests/extension.rs +++ b/tests/extension.rs @@ -1,7 +1,8 @@ use std::sync::Arc; use async_graphql::extensions::{ - Extension, ExtensionContext, ExtensionFactory, NextExtension, ResolveInfo, + Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextPrepareRequest, + NextRequest, NextResolve, NextSubscribe, NextValidation, ResolveInfo, }; use async_graphql::futures_util::stream::BoxStream; use async_graphql::parser::types::ExecutableDocument; @@ -44,12 +45,12 @@ pub async fn test_extension_ctx() { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { if let Ok(data) = ctx.data::() { *data.0.lock().await = 100; } - next.parse_query(ctx, query, variables).await + next.run(ctx, query, variables).await } } @@ -131,9 +132,9 @@ pub async fn test_extension_call_order() { #[async_trait::async_trait] #[allow(unused_variables)] impl Extension for MyExtensionImpl { - async fn request(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { self.calls.lock().await.push("request_start"); - let res = next.request(ctx).await; + let res = next.run(ctx).await; self.calls.lock().await.push("request_end"); res } @@ -142,27 +143,29 @@ pub async fn test_extension_call_order() { &self, ctx: &ExtensionContext<'_>, mut stream: BoxStream<'s, Response>, - next: NextExtension<'_>, + next: NextSubscribe<'_>, ) -> BoxStream<'s, Response> { let calls = self.calls.clone(); - let stream = async_stream::stream! { - calls.lock().await.push("subscribe_start"); - while let Some(item) = stream.next().await { - yield item; - } - calls.lock().await.push("subscribe_end"); - }; - Box::pin(stream) + next.run( + ctx, + Box::pin(async_stream::stream! { + calls.lock().await.push("subscribe_start"); + while let Some(item) = stream.next().await { + yield item; + } + calls.lock().await.push("subscribe_end"); + }), + ) } async fn prepare_request( &self, ctx: &ExtensionContext<'_>, request: Request, - next: NextExtension<'_>, + next: NextPrepareRequest<'_>, ) -> ServerResult { self.calls.lock().await.push("prepare_request_start"); - let res = next.prepare_request(ctx, request).await; + let res = next.run(ctx, request).await; self.calls.lock().await.push("prepare_request_end"); res } @@ -172,10 +175,10 @@ pub async fn test_extension_call_order() { ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, - next: NextExtension<'_>, + next: NextParseQuery<'_>, ) -> ServerResult { self.calls.lock().await.push("parse_query_start"); - let res = next.parse_query(ctx, query, variables).await; + let res = next.run(ctx, query, variables).await; self.calls.lock().await.push("parse_query_end"); res } @@ -183,17 +186,17 @@ pub async fn test_extension_call_order() { async fn validation( &self, ctx: &ExtensionContext<'_>, - next: NextExtension<'_>, + next: NextValidation<'_>, ) -> Result> { self.calls.lock().await.push("validation_start"); - let res = next.validation(ctx).await; + let res = next.run(ctx).await; self.calls.lock().await.push("validation_end"); res } - async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExtension<'_>) -> Response { + async fn execute(&self, ctx: &ExtensionContext<'_>, next: NextExecute<'_>) -> Response { self.calls.lock().await.push("execute_start"); - let res = next.execute(ctx).await; + let res = next.run(ctx).await; self.calls.lock().await.push("execute_end"); res } @@ -202,10 +205,10 @@ pub async fn test_extension_call_order() { &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, - next: NextExtension<'_>, + next: NextResolve<'_>, ) -> ServerResult> { self.calls.lock().await.push("resolve_start"); - let res = next.resolve(ctx, info).await; + let res = next.run(ctx, info).await; self.calls.lock().await.push("resolve_end"); res }