Add PostGuard. #129

This commit is contained in:
Sunli 2020-06-03 14:50:06 +08:00
parent 97f268a55e
commit 746544412b
8 changed files with 416 additions and 79 deletions

View File

@ -1,4 +1,7 @@
use crate::utils::{get_rustdoc, parse_default, parse_default_with, parse_guards, parse_validator};
use crate::utils::{
get_rustdoc, parse_default, parse_default_with, parse_guards, parse_post_guards,
parse_validator,
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{Attribute, AttributeArgs, Error, Lit, Meta, MetaList, NestedMeta, Result, Type};
@ -196,6 +199,7 @@ pub struct Field {
pub requires: Option<String>,
pub is_ref: bool,
pub guard: Option<TokenStream>,
pub post_guard: Option<TokenStream>,
}
impl Field {
@ -209,11 +213,13 @@ impl Field {
let mut requires = None;
let mut is_ref = false;
let mut guard = None;
let mut post_guard = None;
for attr in attrs {
match attr.parse_meta()? {
Meta::List(ls) if ls.path.is_ident("field") => {
guard = parse_guards(crate_name, &ls)?;
post_guard = parse_post_guards(crate_name, &ls)?;
for meta in &ls.nested {
match meta {
NestedMeta::Meta(Meta::Path(p)) if p.is_ident("skip") => {
@ -300,6 +306,7 @@ impl Field {
requires,
is_ref,
guard,
post_guard,
}))
}
}
@ -881,20 +888,17 @@ impl Scalar {
}
}
pub struct Entity {
pub guard: Option<TokenStream>,
}
pub struct Entity {}
impl Entity {
pub fn parse(crate_name: &TokenStream, attrs: &[Attribute]) -> Result<Option<Self>> {
pub fn parse(_crate_name: &TokenStream, attrs: &[Attribute]) -> Result<Option<Self>> {
for attr in attrs {
match attr.parse_meta()? {
Meta::List(ls) if ls.path.is_ident("entity") => {
let guard = parse_guards(crate_name, &ls)?;
return Ok(Some(Self { guard }));
return Ok(Some(Self {}));
}
Meta::Path(p) if p.is_ident("entity") => {
return Ok(Some(Self { guard: None }));
return Ok(Some(Self {}));
}
_ => {}
}

View File

@ -1,6 +1,6 @@
use crate::args;
use crate::output_type::OutputType;
use crate::utils::{check_reserved_name, get_crate_name, get_rustdoc};
use crate::utils::{check_reserved_name, get_crate_name, get_param_getter_ident, get_rustdoc};
use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
@ -44,7 +44,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
for item in &mut item_impl.items {
if let ImplItem::Method(method) = item {
if let Some(entity) = args::Entity::parse(&crate_name, &method.attrs)? {
if args::Entity::parse(&crate_name, &method.attrs)?.is_some() {
if method.sig.asyncness.is_none() {
return Err(Error::new_spanned(&method, "Must be asynchronous"));
}
@ -159,16 +159,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}
let do_find = quote! { self.#field_ident(ctx, #(#use_keys),*).await.map_err(|err| err.into_error(ctx.position()))? };
let guard = entity.guard.map(
|guard| quote! { #guard.check(ctx).await.map_err(|err| err.into_error(ctx.position()))?; },
);
find_entities.push((
args.len(),
quote! {
if typename == &<#entity_type as #crate_name::Type>::type_name() {
if let (#(#key_pat),*) = (#(#key_getter),*) {
#guard
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
return #crate_name::OutputValueType::resolve(&#do_find, &ctx_obj, ctx.item).await;
}
@ -328,8 +323,10 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
Some(default) => quote! { Some(|| -> #ty { #default }) },
None => quote! { None },
};
let param_getter_name = get_param_getter_ident(&ident.ident.to_string());
get_params.push(quote! {
let #ident: #ty = ctx.param_value(#name, #default)?;
let #param_getter_name = || -> #crate_name::Result<#ty> { ctx.param_value(#name, #default) };
let #ident: #ty = #param_getter_name()?;
});
}
@ -381,6 +378,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#guard.check(ctx).await
.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
});
let post_guard = field
.post_guard
.map(|guard| quote! {
#guard.check(ctx, &res).await
.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
});
resolvers.push(quote! {
if ctx.name.node == #field_name {
@ -388,7 +391,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#(#get_params)*
#guard
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
return OutputValueType::resolve(&#resolve_obj, &ctx_obj, ctx.item).await;
let res = #resolve_obj;
#post_guard
return OutputValueType::resolve(&res, &ctx_obj, ctx.item).await;
}
});
@ -462,6 +467,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}.into_error(ctx.position()))
}
#[allow(unused_variables)]
async fn find_entity(&self, ctx: &#crate_name::Context<'_>, params: &#crate_name::Value) -> #crate_name::Result<#crate_name::serde_json::Value> {
let params = match params {
#crate_name::Value::Object(params) => params,

View File

@ -95,6 +95,9 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result<T
let guard = field
.guard
.map(|guard| quote! { #guard.check(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?; });
let post_guard = field
.post_guard
.map(|guard| quote! { #guard.check(ctx, &res).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?; });
if field.is_ref {
getters.push(quote! {
@ -119,6 +122,7 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result<T
#guard
let res = self.#ident(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
let ctx_obj = ctx.with_selection_set(&ctx.selection_set);
#post_guard
return #crate_name::OutputValueType::resolve(&res, &ctx_obj, ctx.item).await;
}
});

View File

@ -1,6 +1,6 @@
use crate::args;
use crate::output_type::OutputType;
use crate::utils::{check_reserved_name, get_crate_name, get_rustdoc};
use crate::utils::{check_reserved_name, get_crate_name, get_param_getter_ident, get_rustdoc};
use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
@ -176,7 +176,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
Some(default) => quote! { Some(|| -> #ty { #default }) },
None => quote! { None },
};
let param_getter_name = get_param_getter_ident(&ident.ident.to_string());
get_params.push(quote! {
let #param_getter_name = || -> #crate_name::Result<#ty> { ctx.param_value(#name, #default) };
let #ident: #ty = ctx.param_value(#name, #default)?;
});
}
@ -229,6 +231,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
let guard = field.guard.map(|guard| quote! {
#guard.check(ctx).await.map_err(|err| err.into_error_with_path(ctx.position(), ctx.path_node.as_ref().unwrap().to_json()))?;
});
if field.post_guard.is_some() {
return Err(Error::new_spanned(
method,
"The subscription field does not support post guard",
));
}
create_stream.push(quote! {
if ctx.name.node == #field_name {

View File

@ -133,8 +133,8 @@ pub fn parse_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<
if let Lit::Str(value) = &nv.lit {
let value_str = value.value();
if value_str.starts_with('@') {
let id = Ident::new(&value_str[1..], value.span());
params.push(quote! { #name: &#id });
let getter_name = get_param_getter_ident(&value_str[1..]);
params.push(quote! { #name: #getter_name()? });
} else {
let expr = syn::parse_str::<Expr>(&value_str)?;
params.push(quote! { #name: (#expr).into() });
@ -170,6 +170,60 @@ pub fn parse_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<
Ok(None)
}
pub fn parse_post_guards(crate_name: &TokenStream, args: &MetaList) -> Result<Option<TokenStream>> {
for arg in &args.nested {
if let NestedMeta::Meta(Meta::List(ls)) = arg {
if ls.path.is_ident("post_guard") {
let mut guards = None;
for item in &ls.nested {
if let NestedMeta::Meta(Meta::List(ls)) = item {
let ty = &ls.path;
let mut params = Vec::new();
for attr in &ls.nested {
if let NestedMeta::Meta(Meta::NameValue(nv)) = attr {
let name = &nv.path;
if let Lit::Str(value) = &nv.lit {
let value_str = value.value();
if value_str.starts_with('@') {
let getter_name = get_param_getter_ident(&value_str[1..]);
params.push(quote! { #name: #getter_name()? });
} else {
let expr = syn::parse_str::<Expr>(&value_str)?;
params.push(quote! { #name: (#expr).into() });
}
} else {
return Err(Error::new_spanned(
&nv.lit,
"Value must be string literal",
));
}
} else {
return Err(Error::new_spanned(attr, "Invalid property for guard"));
}
}
let guard = quote! { #ty { #(#params),* } };
if guards.is_none() {
guards = Some(guard);
} else {
guards = Some(
quote! { #crate_name::guard::PostGuardExt::and(#guard, #guards) },
);
}
} else {
return Err(Error::new_spanned(item, "Invalid guard"));
}
}
return Ok(guards);
}
}
}
Ok(None)
}
pub fn get_rustdoc(attrs: &[Attribute]) -> Result<Option<String>> {
let mut full_docs = String::new();
for attr in attrs {
@ -231,3 +285,7 @@ pub fn parse_default_with(lit: &Lit) -> Result<TokenStream> {
))
}
}
pub fn get_param_getter_ident(name: &str) -> Ident {
Ident::new(&format!("__{}_getter", name), Span::call_site())
}

View File

@ -1,10 +1,11 @@
//! Field guards
use crate::{Context, FieldResult};
use serde::export::PhantomData;
/// Field guard
///
/// Guard is a precondition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
/// Guard is a pre-condition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
#[async_trait::async_trait]
pub trait Guard {
#[allow(missing_docs)]
@ -31,3 +32,35 @@ impl<A: Guard + Send + Sync, B: Guard + Send + Sync> Guard for And<A, B> {
self.1.check(ctx).await
}
}
/// Field post guard
///
/// Guard is a post-condition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned.
#[async_trait::async_trait]
pub trait PostGuard<T: Send + Sync> {
#[allow(missing_docs)]
async fn check(&self, ctx: &Context<'_>, result: &T) -> FieldResult<()>;
}
/// An extension trait for `PostGuard<T>`
pub trait PostGuardExt<T: Send + Sync>: PostGuard<T> + Sized {
/// Merge the two guards.
fn and<R: PostGuard<T>>(self, other: R) -> PostAnd<T, Self, R> {
PostAnd(self, other, PhantomData)
}
}
impl<T: PostGuard<R>, R: Send + Sync> PostGuardExt<R> for T {}
/// PostGuard for `GuardExt<T>::and`
pub struct PostAnd<T: Send + Sync, A: PostGuard<T>, B: PostGuard<T>>(A, B, PhantomData<T>);
#[async_trait::async_trait]
impl<T: Send + Sync, A: PostGuard<T> + Send + Sync, B: PostGuard<T> + Send + Sync> PostGuard<T>
for PostAnd<T, A, B>
{
async fn check(&self, ctx: &Context<'_>, result: &T) -> FieldResult<()> {
self.0.check(ctx, result).await?;
self.1.check(ctx, result).await
}
}

View File

@ -61,11 +61,6 @@ pub async fn test_guard() {
async fn obj(&self) -> MyObj {
MyObj { value: 99 }
}
#[entity(guard(RoleGuard(role = "Role::Admin")))]
async fn find_obj(&self, value: i32) -> MyObj {
MyObj { value }
}
}
struct Subscription;
@ -190,55 +185,6 @@ pub async fn test_guard() {
},
}
);
let query = r#"{
_entities(representations: [{__typename: "MyObj", value: 1}]) {
__typename
... on MyObj {
value
}
}
}"#;
assert_eq!(
QueryBuilder::new(query)
.data(Role::Admin)
.execute(&schema)
.await
.unwrap()
.data,
serde_json::json!({
"_entities": [
{"__typename": "MyObj", "value": 1},
]
})
);
let query = r#"{
_entities(representations: [{__typename: "MyObj", value: 1}]) {
__typename
... on MyObj {
value
}
}
}"#;
assert_eq!(
QueryBuilder::new(query)
.data(Role::Guest)
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos {
line: 2,
column: 13
},
path: None,
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
}
#[async_std::test]
@ -320,14 +266,14 @@ pub async fn test_multiple_guards() {
#[async_std::test]
pub async fn test_guard_forward_arguments() {
struct UserGuard<'a> {
id: &'a ID,
struct UserGuard {
id: ID,
}
#[async_trait::async_trait]
impl<'a> Guard for UserGuard<'a> {
impl Guard for UserGuard {
async fn check(&self, ctx: &Context<'_>) -> FieldResult<()> {
if ctx.data_opt::<ID>() != Some(self.id) {
if ctx.data_opt::<ID>() != Some(&self.id) {
Err("Forbidden".into())
} else {
Ok(())

278
tests/post_guard.rs Normal file
View File

@ -0,0 +1,278 @@
use async_graphql::guard::PostGuard;
use async_graphql::*;
#[derive(Eq, PartialEq, Copy, Clone)]
enum Role {
Admin,
Guest,
}
struct RoleGuard {
role: Role,
}
#[async_trait::async_trait]
impl PostGuard<i32> for RoleGuard {
async fn check(&self, ctx: &Context<'_>, _result: &i32) -> FieldResult<()> {
if ctx.data_opt::<Role>() == Some(&self.role) {
Ok(())
} else {
Err("Forbidden".into())
}
}
}
#[SimpleObject]
struct MyObj {
#[field(post_guard(UserGuard(username = r#""test""#, value = "88")))]
value: i32,
}
struct Username(String);
struct UserGuard {
value: i32,
username: String,
}
#[async_trait::async_trait]
impl PostGuard<i32> for UserGuard {
async fn check(&self, ctx: &Context<'_>, result: &i32) -> FieldResult<()> {
assert_eq!(*result, self.value);
if ctx.data_opt::<Username>().as_ref().map(|s| s.0.as_str()) == Some(&self.username) {
Ok(())
} else {
Err("Forbidden".into())
}
}
}
#[async_trait::async_trait]
impl PostGuard<MyObj> for UserGuard {
async fn check(&self, ctx: &Context<'_>, result: &MyObj) -> FieldResult<()> {
assert_eq!(result.value, self.value);
if ctx.data_opt::<Username>().as_ref().map(|s| s.0.as_str()) == Some(&self.username) {
Ok(())
} else {
Err("Forbidden".into())
}
}
}
#[async_std::test]
pub async fn test_post_guard() {
struct Query;
#[Object]
impl Query {
#[field(post_guard(UserGuard(username = r#""test""#, value = "99")))]
async fn value(&self) -> i32 {
99
}
async fn obj(&self) -> MyObj {
MyObj { value: 88 }
}
}
let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Username("test".to_string()))
.execute(&schema)
.await
.unwrap()
.data,
serde_json::json!({
"value": 99
})
);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Username("test1".to_string()))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 3 },
path: Some(serde_json::json!(["value"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
let query = "{ obj { value } }";
assert_eq!(
QueryBuilder::new(query)
.data(Username("test".to_string()))
.execute(&schema)
.await
.unwrap()
.data,
serde_json::json!({
"obj": { "value": 88 }
})
);
let query = "{ obj { value } }";
assert_eq!(
QueryBuilder::new(query)
.data(Username("test1".to_string()))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 9 },
path: Some(serde_json::json!(["obj", "value"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
}
#[async_std::test]
pub async fn test_multiple_post_guards() {
#[SimpleObject]
struct Query {
#[field(post_guard(
RoleGuard(role = "Role::Admin"),
UserGuard(username = r#""test""#, value = "10")
))]
value: i32,
}
let schema = Schema::new(Query { value: 10 }, EmptyMutation, EmptySubscription);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Role::Admin)
.data(Username("test".to_string()))
.execute(&schema)
.await
.unwrap()
.data,
serde_json::json!({"value": 10})
);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Role::Guest)
.data(Username("test".to_string()))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 3 },
path: Some(serde_json::json!(["value"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Role::Admin)
.data(Username("test1".to_string()))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 3 },
path: Some(serde_json::json!(["value"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
let query = "{ value }";
assert_eq!(
QueryBuilder::new(query)
.data(Role::Guest)
.data(Username("test1".to_string()))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 3 },
path: Some(serde_json::json!(["value"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
}
#[async_std::test]
pub async fn test_post_guard_forward_arguments() {
struct UserGuard {
id: ID,
}
#[async_trait::async_trait]
impl PostGuard<ID> for UserGuard {
async fn check(&self, ctx: &Context<'_>, result: &ID) -> FieldResult<()> {
assert_eq!(result.as_str(), "haha");
if ctx.data_opt::<ID>() != Some(&self.id) {
Err("Forbidden".into())
} else {
Ok(())
}
}
}
struct QueryRoot;
#[Object]
impl QueryRoot {
#[field(post_guard(UserGuard(id = "@_id")))]
async fn user(&self, _id: ID) -> ID {
"haha".into()
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let query = r#"{ user(id: "abc") }"#;
assert_eq!(
QueryBuilder::new(query)
.data(ID::from("abc"))
.execute(&schema)
.await
.unwrap()
.data,
serde_json::json!({"user": "haha"})
);
let query = r#"{ user(id: "abc") }"#;
assert_eq!(
QueryBuilder::new(query)
.data(ID::from("aaa"))
.execute(&schema)
.await
.unwrap_err(),
Error::Query {
pos: Pos { line: 1, column: 3 },
path: Some(serde_json::json!(["user"])),
err: QueryError::FieldError {
err: "Forbidden".to_string(),
extended_error: None,
},
}
);
}