diff --git a/tests/guard.rs b/tests/guard.rs index 3fe1cab4..bd2b88b6 100644 --- a/tests/guard.rs +++ b/tests/guard.rs @@ -489,3 +489,58 @@ pub async fn test_guard_race_operator() { }] ); } + +#[tokio::test] +pub async fn test_guard_use_params() { + struct EqGuard { + expect: i32, + actual: i32, + } + + #[async_trait::async_trait] + impl Guard for EqGuard { + async fn check(&self, _ctx: &Context<'_>) -> Result<()> { + if self.expect != self.actual { + Err("Forbidden".into()) + } else { + Ok(()) + } + } + } + + struct Query; + + #[Object] + impl Query { + #[graphql(guard(EqGuard(expect = "100", actual = "@value")))] + async fn get(&self, value: i32) -> i32 { + value + } + } + + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + + assert_eq!( + schema + .execute(Request::new("{ get(value: 100) }")) + .await + .into_result() + .unwrap() + .data, + value!({"get": 100}) + ); + + assert_eq!( + schema + .execute(Request::new("{ get(value: 99) }")) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: "Forbidden".to_string(), + locations: vec![Pos { line: 1, column: 3 }], + path: vec![PathSegment::Field("get".to_owned())], + extensions: None, + }] + ); +}