From ec5a466bcb27a917629cc2f4d22c58a3004db010 Mon Sep 17 00:00:00 2001 From: Aurelien Foucault Date: Tue, 6 Oct 2020 00:10:15 +0200 Subject: [PATCH] Add chain and race operator --- derive/src/utils.rs | 46 +++++++- src/guard.rs | 4 +- tests/guard.rs | 253 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 298 insertions(+), 5 deletions(-) diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 796a9307..2a6a1035 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -121,10 +121,8 @@ pub fn generate_guards( crate_name: &TokenStream, args: &Meta, ) -> GeneratorResult> { - println!("{:#?}\n", args); match args { Meta::List(args) => { - println!("args = {:#?}\n", args.path); match args.path.get_ident() { Some(ident) => match ident.to_string().as_str() { "guard" => { @@ -189,6 +187,50 @@ pub fn generate_guards( quote! { #crate_name::guard::GuardExt::or(#first_rule, #second_rule) }, )) } + "chain" => { + if args.nested.len() < 2 { + return Err(Error::new_spanned( + args, + "chain operator need at least 1 operand.", + ) + .into()); + } + let mut guards: Option = None; + for arg in &args.nested { + if let NestedMeta::Meta(rule) = &arg { + let guard = generate_guards(crate_name, rule)?; + if guards.is_none() { + guards = guard; + } else { + guards = + Some(quote! { #crate_name::guard::GuardExt::and(#guard, #guards) }); + } + } + } + Ok(guards) + } + "race" => { + if args.nested.len() < 2 { + return Err(Error::new_spanned( + args, + "race operator need at least 1 operand.", + ) + .into()); + } + let mut guards: Option = None; + for arg in &args.nested { + if let NestedMeta::Meta(rule) = &arg { + let guard = generate_guards(crate_name, rule)?; + if guards.is_none() { + guards = guard; + } else { + guards = + Some(quote! { #crate_name::guard::GuardExt::or(#guard, #guards) }); + } + } + } + Ok(guards) + } _ => { let ty = &args.path; let mut params = Vec::new(); diff --git a/src/guard.rs b/src/guard.rs index bcc25e1e..33f423a6 100644 --- a/src/guard.rs +++ b/src/guard.rs @@ -35,8 +35,8 @@ pub struct And(A, B); #[async_trait::async_trait] impl Guard for And { async fn check(&self, ctx: &Context<'_>) -> Result<()> { - let second_result = self.1.check(ctx).await; - self.0.check(ctx).await.and(second_result) + self.0.check(ctx).await?; + self.1.check(ctx).await } } diff --git a/tests/guard.rs b/tests/guard.rs index 0a33e5b7..3b958498 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -40,6 +40,23 @@ impl Guard for UserGuard { } } +struct Age(i32); + +struct AgeGuard { + age: i32, +} + +#[async_trait::async_trait] +impl Guard for AgeGuard { + async fn check(&self, ctx: &Context<'_>) -> Result<()> { + if ctx.data_opt::().map(|name| &name.0) == Some(&self.age) { + Ok(()) + } else { + Err("Forbidden".into()) + } + } +} + #[async_std::test] pub async fn test_guard_simple_rule() { #[derive(SimpleObject)] @@ -48,7 +65,17 @@ pub async fn test_guard_simple_rule() { value: i32, } - let schema = Schema::new(Query { value: 10 }, EmptyMutation, EmptySubscription); + struct Subscription; + + #[Subscription] + impl Subscription { + #[graphql(guard(RoleGuard(role = "Role::Admin")))] + async fn values(&self) -> impl Stream { + futures::stream::iter(vec![1, 2, 3]) + } + } + + let schema = Schema::new(Query { value: 10 }, EmptyMutation, Subscription); let query = "{ value }"; assert_eq!( @@ -73,6 +100,38 @@ pub async fn test_guard_simple_rule() { extensions: None, }] ); + + assert_eq!( + schema + .execute_stream(Request::new("subscription { values }").data(Role::Admin)) + .map(|item| item.data) + .collect::>() + .await, + vec![ + serde_json::json! ({"values": 1}), + serde_json::json! ({"values": 2}), + serde_json::json! ({"values": 3}) + ] + ); + + assert_eq!( + schema + .execute_stream(Request::new("subscription { values }").data(Role::Guest)) + .boxed() + .next() + .await + .unwrap() + .errors, + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { + line: 1, + column: 16 + }], + path: vec![PathSegment::Field("values".to_owned())], + extensions: None, + }] + ); } #[async_std::test] @@ -227,3 +286,195 @@ pub async fn test_guard_or_operator() { }] ); } + +#[async_std::test] +pub async fn test_guard_chain_operator() { + #[derive(SimpleObject)] + struct Query { + #[graphql(guard(chain(RoleGuard(role = "Role::Admin"), UserGuard(username = r#""test""#), AgeGuard(age = r#"21"#))))] + value: i32, + } + + let schema = Schema::new(Query { value: 10 }, EmptyMutation, EmptySubscription); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Admin) + .data(Username("test".to_string())) + .data(Age(21)) + ) + .await + .data, + serde_json::json!({"value": 10}) + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Guest) + .data(Username("test".to_string())) + .data(Age(21)) + ) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("value".to_owned())], + extensions: None, + }] + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Admin) + .data(Username("test1".to_string())) + .data(Age(21)) + ) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("value".to_owned())], + extensions: None, + }] + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Admin) + .data(Username("test".to_string())) + .data(Age(22)) + ) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("value".to_owned())], + extensions: None, + }] + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Guest) + .data(Username("test1".to_string())) + .data(Age(22)) + ) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("value".to_owned())], + extensions: None, + }] + ); +} + +#[async_std::test] +pub async fn test_guard_race_operator() { + #[derive(SimpleObject)] + struct Query { + #[graphql(guard(race(RoleGuard(role = "Role::Admin"), UserGuard(username = r#""test""#), AgeGuard(age = r#"21"#))))] + value: i32, + } + + let schema = Schema::new(Query { value: 10 }, EmptyMutation, EmptySubscription); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Admin) + .data(Username("test".to_string())) + .data(Age(21)) + ) + .await + .data, + serde_json::json!({"value": 10}) + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Guest) + .data(Username("test".to_string())) + .data(Age(22)) + ) + .await + .data, + serde_json::json!({"value": 10}) + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Admin) + .data(Username("test1".to_string())) + .data(Age(22)) + ) + .await + .data, + serde_json::json!({"value": 10}) + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Guest) + .data(Username("test1".to_string())) + .data(Age(21)) + ) + .await + .data, + serde_json::json!({"value": 10}) + ); + + let query = "{ value }"; + assert_eq!( + schema + .execute( + Request::new(query) + .data(Role::Guest) + .data(Username("test1".to_string())) + .data(Age(22)) + ) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("value".to_owned())], + extensions: None, + }] + ); +} \ No newline at end of file