From f24c6fceffdf97383cc15f6ed28959d243079b00 Mon Sep 17 00:00:00 2001 From: sunli Date: Sat, 9 May 2020 18:34:57 +0800 Subject: [PATCH] Add ability to forward field arguments to guard #59 --- async-graphql-derive/src/object.rs | 2 +- async-graphql-derive/src/subscription.rs | 5 +- async-graphql-derive/src/utils.rs | 10 +++- tests/guard.rs | 58 ++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 6 deletions(-) diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index 44e9e2bd..10996926 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -387,8 +387,8 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< resolvers.push(quote! { if ctx.name.as_str() == #field_name { use #crate_name::OutputValueType; - #guard #(#get_params)* + #guard let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve(&#resolve_obj, &ctx_obj, ctx.position()).await; } diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index d4ab6550..825ebfcf 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -238,10 +238,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< if ctx.name.as_str() == #field_name { use #crate_name::futures::stream::{StreamExt, TryStreamExt}; - #guard - - let field_name = std::sync::Arc::new(ctx.result_name().to_string()); #(#get_params)* + #guard + let field_name = std::sync::Arc::new(ctx.result_name().to_string()); let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone()); let schema = schema.clone(); let pos = ctx.position(); diff --git a/async-graphql-derive/src/utils.rs b/async-graphql-derive/src/utils.rs index d1860fed..22fbd00e 100644 --- a/async-graphql-derive/src/utils.rs +++ b/async-graphql-derive/src/utils.rs @@ -196,8 +196,14 @@ pub fn parse_guards(crate_name: &TokenStream, args: &MetaList) -> Result(&value.value())?; - params.push(quote! { #name: #expr.into() }); + let value_str = value.value(); + if value_str.starts_with('@') { + let id = Ident::new(&value_str[1..], value.span()); + params.push(quote! { #name: &#id }); + } else { + let expr = syn::parse_str::(&value_str)?; + params.push(quote! { #name: #expr.into() }); + } } else { return Err(Error::new_spanned( &nv.lit, diff --git a/tests/guard.rs b/tests/guard.rs index fb6efbe9..666ff882 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -317,3 +317,61 @@ pub async fn test_multiple_guards() { } ); } + +#[async_std::test] +pub async fn test_guard_forward_arguments() { + struct UserGuard<'a> { + id: &'a ID, + } + + #[async_trait::async_trait] + impl<'a> Guard for UserGuard<'a> { + async fn check(&self, ctx: &Context<'_>) -> FieldResult<()> { + if ctx.data_opt::() != Some(self.id) { + Err("Forbidden".into()) + } else { + Ok(()) + } + } + } + + struct QueryRoot; + + #[Object] + impl QueryRoot { + #[field(guard(UserGuard(id = "@id")))] + async fn user(&self, id: ID) -> ID { + id + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + + let query = r#"{ user(id: "abc") }"#; + assert_eq!( + QueryBuilder::new(query) + .data(ID::from("abc")) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({"user": "abc"}) + ); + + let query = r#"{ user(id: "abc") }"#; + assert_eq!( + QueryBuilder::new(query) + .data(ID::from("aaa")) + .execute(&schema) + .await + .unwrap_err(), + Error::Query { + pos: Pos { line: 1, column: 3 }, + path: Some(serde_json::json!(["user"])), + err: QueryError::FieldError { + err: "Forbidden".to_string(), + extended_error: None, + }, + } + ); +}