diff --git a/src/schema.rs b/src/schema.rs index 5c454141..9d993fba 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -341,21 +341,15 @@ where .collect::>() .into(); - let request = extensions - .prepare_request( - &ExtensionContext { - schema_data: &self.env.data, - query_data: &Default::default(), - }, - request, - ) - .await?; - + let mut request = request; + let data = std::mem::take(&mut request.data); let ctx_extension = ExtensionContext { schema_data: &self.env.data, - query_data: &request.data, + query_data: &data, }; + let request = extensions.prepare_request(&ctx_extension, request).await?; + extensions.parse_start(&ctx_extension, &request.query, &request.variables); let document = parse_query(&request.query) .map_err(Into::::into) @@ -427,7 +421,7 @@ where operation, fragments: document.fragments, uploads: request.uploads, - ctx_data: Arc::new(request.data), + ctx_data: Arc::new(data), }; Ok((env, cache_control)) } diff --git a/tests/extension.rs b/tests/extension.rs new file mode 100644 index 00000000..f0e05186 --- /dev/null +++ b/tests/extension.rs @@ -0,0 +1,68 @@ +use async_graphql::extensions::{Extension, ExtensionContext, ExtensionFactory}; +use async_graphql::*; +use spin::Mutex; +use std::sync::Arc; + +#[async_std::test] +pub async fn test_extension_ctx() { + #[derive(Default, Clone)] + struct MyData(Arc>); + + struct Query; + + #[Object] + impl Query { + async fn value(&self) -> bool { + true + } + } + + struct MyExtensionImpl; + + #[async_trait::async_trait] + impl Extension for MyExtensionImpl { + fn parse_start( + &mut self, + ctx: &ExtensionContext<'_>, + _query_source: &str, + _variables: &Variables, + ) { + *ctx.data_unchecked::().0.lock() = 100; + } + } + + struct MyExtension; + + impl ExtensionFactory for MyExtension { + fn create(&self) -> Box { + Box::new(MyExtensionImpl) + } + } + + // data in schema + { + let data = MyData::default(); + let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .data(data.clone()) + .extension(MyExtension) + .finish(); + + schema.execute("{ value }").await.into_result().unwrap(); + assert_eq!(*data.0.lock(), 100); + } + + // data in request + { + let data = MyData::default(); + let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .extension(MyExtension) + .finish(); + + schema + .execute(Request::new("{ value }").data(data.clone())) + .await + .into_result() + .unwrap(); + assert_eq!(*data.0.lock(), 100); + } +}