use std::sync::Arc; use async_graphql::{ extensions::{ Extension, ExtensionContext, ExtensionFactory, NextExecute, NextParseQuery, NextPrepareRequest, NextRequest, NextResolve, NextSubscribe, NextValidation, ResolveInfo, }, futures_util::stream::BoxStream, parser::types::ExecutableDocument, *, }; use async_graphql_value::ConstValue; use futures_util::{lock::Mutex, stream::Stream, StreamExt}; #[tokio::test] pub async fn test_extension_ctx() { #[derive(Default, Clone)] struct MyData(Arc>); struct Query; #[Object] impl Query { async fn value(&self, ctx: &Context<'_>) -> i32 { *ctx.data_unchecked::().0.lock().await } } struct Subscription; #[Subscription] impl Subscription { async fn value(&self, ctx: &Context<'_>) -> impl Stream { let data = *ctx.data_unchecked::().0.lock().await; futures_util::stream::once(async move { data }) } } struct MyExtensionImpl; #[async_trait::async_trait] impl Extension for MyExtensionImpl { async fn parse_query( &self, ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, next: NextParseQuery<'_>, ) -> ServerResult { if let Ok(data) = ctx.data::() { *data.0.lock().await = 100; } next.run(ctx, query, variables).await } } struct MyExtension; impl ExtensionFactory for MyExtension { fn create(&self) -> Arc { Arc::new(MyExtensionImpl) } } // data in schema { let data = MyData::default(); let schema = Schema::build(Query, EmptyMutation, EmptySubscription) .data(data.clone()) .extension(MyExtension) .finish(); assert_eq!( schema .execute("{ value }") .await .into_result() .unwrap() .data, value! ({ "value": 100 }) ); } // data in request { let data = MyData::default(); let schema = Schema::build(Query, EmptyMutation, EmptySubscription) .extension(MyExtension) .finish(); assert_eq!( schema .execute(Request::new("{ value }").data(data.clone())) .await .into_result() .unwrap() .data, value! ({ "value": 100 }) ); } // data in session { let schema = Schema::build(Query, EmptyMutation, Subscription) .extension(MyExtension) .finish(); let mut data = Data::default(); data.insert(MyData::default()); let mut stream = schema.execute_stream_with_session_data( Request::new("subscription { value }"), Arc::new(data), ); assert_eq!( stream.next().await.unwrap().into_result().unwrap().data, value! ({ "value": 100 }) ); } } #[tokio::test] pub async fn test_extension_call_order() { struct MyExtensionImpl { calls: Arc>>, } #[async_trait::async_trait] #[allow(unused_variables)] impl Extension for MyExtensionImpl { async fn request(&self, ctx: &ExtensionContext<'_>, next: NextRequest<'_>) -> Response { self.calls.lock().await.push("request_start"); let res = next.run(ctx).await; self.calls.lock().await.push("request_end"); res } fn subscribe<'s>( &self, ctx: &ExtensionContext<'_>, mut stream: BoxStream<'s, Response>, next: NextSubscribe<'_>, ) -> BoxStream<'s, Response> { let calls = self.calls.clone(); next.run( ctx, Box::pin(async_stream::stream! { calls.lock().await.push("subscribe_start"); while let Some(item) = stream.next().await { yield item; } calls.lock().await.push("subscribe_end"); }), ) } async fn prepare_request( &self, ctx: &ExtensionContext<'_>, request: Request, next: NextPrepareRequest<'_>, ) -> ServerResult { self.calls.lock().await.push("prepare_request_start"); let res = next.run(ctx, request).await; self.calls.lock().await.push("prepare_request_end"); res } async fn parse_query( &self, ctx: &ExtensionContext<'_>, query: &str, variables: &Variables, next: NextParseQuery<'_>, ) -> ServerResult { self.calls.lock().await.push("parse_query_start"); let res = next.run(ctx, query, variables).await; self.calls.lock().await.push("parse_query_end"); res } async fn validation( &self, ctx: &ExtensionContext<'_>, next: NextValidation<'_>, ) -> Result> { self.calls.lock().await.push("validation_start"); let res = next.run(ctx).await; self.calls.lock().await.push("validation_end"); res } async fn execute( &self, ctx: &ExtensionContext<'_>, operation_name: Option<&str>, next: NextExecute<'_>, ) -> Response { assert_eq!(operation_name, Some("Abc")); self.calls.lock().await.push("execute_start"); let res = next.run(ctx, operation_name).await; self.calls.lock().await.push("execute_end"); res } async fn resolve( &self, ctx: &ExtensionContext<'_>, info: ResolveInfo<'_>, next: NextResolve<'_>, ) -> ServerResult> { self.calls.lock().await.push("resolve_start"); let res = next.run(ctx, info).await; self.calls.lock().await.push("resolve_end"); res } } struct MyExtension { calls: Arc>>, } impl ExtensionFactory for MyExtension { fn create(&self) -> Arc { Arc::new(MyExtensionImpl { calls: self.calls.clone(), }) } } struct Query; #[Object] impl Query { async fn value1(&self) -> i32 { 10 } async fn value2(&self) -> i32 { 10 } } struct Subscription; #[Subscription] impl Subscription { async fn value(&self) -> impl Stream { futures_util::stream::iter(vec![1, 2, 3]) } } { let calls: Arc>> = Default::default(); let schema = Schema::build(Query, EmptyMutation, Subscription) .extension(MyExtension { calls: calls.clone(), }) .finish(); let _ = schema .execute("query Abc { value1 value2 }") .await .into_result() .unwrap(); let calls = calls.lock().await; assert_eq!( &*calls, &vec![ "request_start", "prepare_request_start", "prepare_request_end", "parse_query_start", "parse_query_end", "validation_start", "validation_end", "execute_start", "resolve_start", "resolve_end", "resolve_start", "resolve_end", "execute_end", "request_end", ] ); } { let calls: Arc>> = Default::default(); let schema = Schema::build(Query, EmptyMutation, Subscription) .extension(MyExtension { calls: calls.clone(), }) .finish(); let mut stream = schema.execute_stream("subscription Abc { value }"); while stream.next().await.is_some() {} let calls = calls.lock().await; assert_eq!( &*calls, &vec![ "subscribe_start", "prepare_request_start", "prepare_request_end", "parse_query_start", "parse_query_end", "validation_start", "validation_end", // push 1 "execute_start", "resolve_start", "resolve_end", "execute_end", // push 2 "execute_start", "resolve_start", "resolve_end", "execute_end", // push 3 "execute_start", "resolve_start", "resolve_end", "execute_end", // end "subscribe_end", ] ); } }