From 331e65a72b768b5b2d73b107a47aa9170db7aa73 Mon Sep 17 00:00:00 2001 From: sunli Date: Sat, 2 May 2020 07:57:34 +0800 Subject: [PATCH] Add field guard --- async-graphql-derive/src/args.rs | 38 ++- async-graphql-derive/src/interface.rs | 52 ++-- async-graphql-derive/src/object.rs | 140 ++++++--- async-graphql-derive/src/simple_object.rs | 20 +- async-graphql-derive/src/subscription.rs | 73 +++-- async-graphql-derive/src/utils.rs | 60 +++- docs/zh-CN/src/input_value_validators.md | 6 +- src/error.rs | 6 +- src/guard.rs | 33 ++ src/lib.rs | 9 +- src/model/enum_value.rs | 10 +- src/query.rs | 1 + src/validation/test_harness.rs | 1 + src/validators/mod.rs | 6 +- tests/federation.rs | 16 +- tests/guard.rs | 362 ++++++++++++++++++++++ 16 files changed, 688 insertions(+), 145 deletions(-) create mode 100644 src/guard.rs create mode 100644 tests/guard.rs diff --git a/async-graphql-derive/src/args.rs b/async-graphql-derive/src/args.rs index 91b0a3a7..4ddb26c1 100644 --- a/async-graphql-derive/src/args.rs +++ b/async-graphql-derive/src/args.rs @@ -1,4 +1,4 @@ -use crate::utils::{parse_validator, parse_value}; +use crate::utils::{parse_guards, parse_validator, parse_value}; use graphql_parser::query::Value; use proc_macro2::TokenStream; use quote::quote; @@ -215,10 +215,11 @@ pub struct Field { pub provides: Option, pub requires: Option, pub is_ref: bool, + pub guard: Option, } impl Field { - pub fn parse(attrs: &[Attribute]) -> Result> { + pub fn parse(crate_name: &TokenStream, attrs: &[Attribute]) -> Result> { let mut name = None; let mut desc = None; let mut deprecation = None; @@ -227,10 +228,12 @@ impl Field { let mut provides = None; let mut requires = None; let mut is_ref = false; + let mut guard = None; for attr in attrs { match attr.parse_meta()? { Meta::List(ls) if ls.path.is_ident("field") => { + guard = parse_guards(crate_name, &ls)?; for meta in &ls.nested { match meta { NestedMeta::Meta(Meta::Path(p)) if p.is_ident("skip") => { @@ -312,6 +315,7 @@ impl Field { provides, requires, is_ref, + guard, })) } } @@ -665,7 +669,6 @@ pub struct InterfaceField { pub ty: Type, pub args: Vec, pub deprecation: Option, - pub context: bool, pub external: bool, pub provides: Option, pub requires: Option, @@ -678,16 +681,12 @@ impl InterfaceField { let mut ty = None; let mut args = Vec::new(); let mut deprecation = None; - let mut context = false; let mut external = false; let mut provides = None; let mut requires = None; for meta in &ls.nested { match meta { - NestedMeta::Meta(Meta::Path(p)) if p.is_ident("context") => { - context = true; - } NestedMeta::Meta(Meta::Path(p)) if p.is_ident("external") => { external = true; } @@ -773,7 +772,6 @@ impl InterfaceField { ty: ty.unwrap(), args, deprecation, - context, external, requires, provides, @@ -887,3 +885,27 @@ impl Scalar { Ok(Self { internal }) } } + +#[derive(Debug)] +pub struct Entity { + pub guard: Option, +} + +impl Entity { + pub fn parse(crate_name: &TokenStream, attrs: &[Attribute]) -> Result> { + for attr in attrs { + match attr.parse_meta()? { + Meta::List(ls) if ls.path.is_ident("entity") => { + let guard = parse_guards(crate_name, &ls)?; + return Ok(Some(Self { guard })); + } + Meta::Path(p) if p.is_ident("entity") => { + return Ok(Some(Self { guard: None })); + } + _ => {} + } + } + + Ok(None) + } +} diff --git a/async-graphql-derive/src/interface.rs b/async-graphql-derive/src/interface.rs index 3a6368fc..bd99731e 100644 --- a/async-graphql-derive/src/interface.rs +++ b/async-graphql-derive/src/interface.rs @@ -90,7 +90,6 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result ty, args, deprecation, - context, external, provides, requires, @@ -112,10 +111,8 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result None => quote! { None }, }; - if *context { - decl_params.push(quote! { ctx: &'ctx #crate_name::Context<'ctx> }); - use_params.push(quote! { ctx }); - } + decl_params.push(quote! { ctx: &'ctx #crate_name::Context<'ctx> }); + use_params.push(quote! { ctx }); for InterfaceFieldArgument { name, @@ -168,21 +165,6 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result }); } - let ctx_lifetime = if *context { - quote! { <'ctx> } - } else { - quote! {} - }; - - methods.push(quote! { - #[inline] - async fn #method_name #ctx_lifetime(&self, #(#decl_params),*) -> #ty { - match self { - #(#calls,)* - } - } - }); - let desc = desc .as_ref() .map(|s| quote! {Some(#s)}) @@ -192,8 +174,21 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result .map(|s| quote! {Some(#s)}) .unwrap_or_else(|| quote! {None}); - let ty = OutputType::parse(ty)?; - let schema_ty = ty.value_type(); + let oty = OutputType::parse(ty)?; + let ty = match oty { + OutputType::Value(ty) => ty, + OutputType::Result(_, ty) => ty, + }; + let schema_ty = oty.value_type(); + + methods.push(quote! { + #[inline] + async fn #method_name <'ctx>(&self, #(#decl_params),*) -> #crate_name::FieldResult<#ty> { + match self { + #(#calls,)* + } + } + }); schema_fields.push(quote! { fields.insert(#name.to_string(), #crate_name::registry::Field { @@ -213,16 +208,9 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result }); }); - let resolve_obj = match &ty { - OutputType::Value(_) => quote! { - self.#method_name(#(#use_params),*).await - }, - OutputType::Result(_, _) => { - quote! { - self.#method_name(#(#use_params),*).await. - map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))? - } - } + let resolve_obj = quote! { + self.#method_name(#(#use_params),*).await. + map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))? }; resolvers.push(quote! { diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index dc31e40b..56536963 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -3,8 +3,9 @@ use crate::output_type::OutputType; use crate::utils::{build_value_repr, check_reserved_name, get_crate_name}; use inflector::Inflector; use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; use quote::quote; -use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeReference}; +use syn::{Block, Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeReference}; pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result { let crate_name = get_crate_name(object_args.internal); @@ -43,14 +44,15 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< for item in &mut item_impl.items { if let ImplItem::Method(method) = item { - if method.attrs.iter().any(|attr| attr.path.is_ident("entity")) { + if let Some(entity) = args::Entity::parse(&crate_name, &method.attrs)? { let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, ReturnType::Default => { return Err(Error::new_spanned(&method.sig.output, "Missing type")) } }; - let mut arg_ctx = false; + let mut create_ctx = true; + let mut arg_ctx = Ident::new("ctx", Span::call_site()); let mut args = Vec::new(); for (idx, arg) in method.sig.inputs.iter_mut().enumerate() { @@ -72,13 +74,13 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< match (&*pat.pat, &*pat.ty) { (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { args.push(( - arg_ident, - arg_ty, + arg_ident.clone(), + arg_ty.clone(), args::Argument::parse(&crate_name, &pat.attrs)?, )); pat.attrs.clear(); } - (_, Type::Reference(TypeReference { elem, .. })) => { + (arg, Type::Reference(TypeReference { elem, .. })) => { if let Type::Path(path) = elem.as_ref() { if idx != 1 || path.path.segments.last().unwrap().ident != "Context" @@ -87,8 +89,21 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< arg, "The Context must be the second argument.", )); + } else { + create_ctx = false; + match arg { + Pat::Wild(_) => { + pat.pat = Box::new( + syn::parse2::(quote! { #arg_ctx }) + .unwrap(), + ); + } + Pat::Ident(arg_ident) => { + arg_ctx = arg_ident.ident.clone(); + } + _ => {} + } } - arg_ctx = true; } } _ => return Err(Error::new_spanned(arg, "Invalid argument type.")), @@ -96,6 +111,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } } + if create_ctx { + let arg = syn::parse2::(quote! { #arg_ctx: &#crate_name::Context<'_> }) + .unwrap(); + method.sig.inputs.insert(1, arg); + } + let entity_type = ty.value_type(); let mut key_pat = Vec::new(); let mut key_getter = Vec::new(); @@ -131,19 +152,24 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< ); let field_ident = &method.sig.ident; - let ctx_param = if arg_ctx { - quote! { &ctx, } - } else { - quote! {} - }; - let do_find = match &ty { - OutputType::Value(_) => quote! { - self.#field_ident(#ctx_param #(#use_keys),*).await - }, - OutputType::Result(_, _) => { - quote! { self.#field_ident(#ctx_param #(#use_keys),*).await? } - } - }; + if let OutputType::Value(inner_ty) = &ty { + let block = &method.block; + method.block = + syn::parse2::(quote!({ Ok(#block) })).expect("invalid block"); + method.sig.output = syn::parse2::( + quote! { -> #crate_name::FieldResult<#inner_ty> }, + ) + .expect("invalid result type"); + } + let do_find = quote! { self.#field_ident(ctx, #(#use_keys),*).await.map_err(|err| err.into_error(pos))? }; + + if let Some(guard) = &entity.guard { + method.block.stmts.insert( + 0, + syn::parse2(quote! { #guard.check(#arg_ctx).await?; }) + .expect("invalid guard"), + ); + } find_entities.push(( args.len(), @@ -166,12 +192,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< .map(|(idx, _)| idx) .unwrap(), ); - } else if let Some(field) = args::Field::parse(&method.attrs)? { + } else if let Some(field) = args::Field::parse(&crate_name, &method.attrs)? { if method.sig.asyncness.is_none() { - return Err(Error::new_spanned( - &method.sig.output, - "Must be asynchronous", - )); + return Err(Error::new_spanned(&method, "Must be asynchronous")); } let field_name = field @@ -214,7 +237,8 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } }; - let mut arg_ctx = false; + let mut create_ctx = true; + let mut arg_ctx = Ident::new("ctx", Span::call_site()); let mut args = Vec::new(); for (idx, arg) in method.sig.inputs.iter_mut().enumerate() { @@ -236,13 +260,13 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< match (&*pat.pat, &*pat.ty) { (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { args.push(( - arg_ident, - arg_ty, + arg_ident.clone(), + arg_ty.clone(), args::Argument::parse(&crate_name, &pat.attrs)?, )); pat.attrs.clear(); } - (_, Type::Reference(TypeReference { elem, .. })) => { + (arg, Type::Reference(TypeReference { elem, .. })) => { if let Type::Path(path) = elem.as_ref() { if idx != 1 || path.path.segments.last().unwrap().ident != "Context" @@ -252,7 +276,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< "The Context must be the second argument.", )); } - arg_ctx = true; + + create_ctx = false; + match arg { + Pat::Wild(_) => { + pat.pat = Box::new( + syn::parse2::(quote! { #arg_ctx }).unwrap(), + ); + } + Pat::Ident(arg_ident) => { + arg_ctx = arg_ident.ident.clone(); + } + _ => {} + } } } _ => return Err(Error::new_spanned(arg, "Invalid argument type.")), @@ -260,6 +296,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } } + if create_ctx { + let arg = syn::parse2::(quote! { #arg_ctx: &#crate_name::Context<'_> }) + .unwrap(); + method.sig.inputs.insert(1, arg); + } + let mut schema_args = Vec::new(); let mut use_params = Vec::new(); let mut get_params = Vec::new(); @@ -335,24 +377,20 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); }); - let ctx_param = if arg_ctx { - quote! { &ctx, } - } else { - quote! {} - }; - let field_ident = &method.sig.ident; - let resolve_obj = match &ty { - OutputType::Value(_) => quote! { - self.#field_ident(#ctx_param #(#use_params),*).await - }, - OutputType::Result(_, _) => { - quote! { - { - let res:#crate_name::FieldResult<_> = self.#field_ident(#ctx_param #(#use_params),*).await; - res.map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))? - } - } + if let OutputType::Value(inner_ty) = &ty { + let block = &method.block; + method.block = + syn::parse2::(quote!({ Ok(#block) })).expect("invalid block"); + method.sig.output = syn::parse2::( + quote! { -> #crate_name::FieldResult<#inner_ty> }, + ) + .expect("invalid result type"); + } + let resolve_obj = quote! { + { + let res:#crate_name::FieldResult<_> = self.#field_ident(ctx, #(#use_params),*).await; + res.map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))? } }; @@ -364,6 +402,14 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } }); + if let Some(guard) = field.guard { + method.block.stmts.insert( + 0, + syn::parse2(quote! { #guard.check(#arg_ctx).await?; }) + .expect("invalid guard"), + ); + } + if let Some((idx, _)) = method .attrs .iter() diff --git a/async-graphql-derive/src/simple_object.rs b/async-graphql-derive/src/simple_object.rs index d2935939..1192942f 100644 --- a/async-graphql-derive/src/simple_object.rs +++ b/async-graphql-derive/src/simple_object.rs @@ -38,7 +38,7 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result Result &#ty { - &self.#ident + #vis async fn #ident(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::FieldResult<&#ty> { + #guard + Ok(&self.#ident) } }); } else { getters.push(quote! { #[inline] - #vis async fn #ident(&self) -> #ty { - self.#ident.clone() + #vis async fn #ident(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::FieldResult<#ty> { + #guard + Ok(self.#ident.clone()) } }); } resolvers.push(quote! { if ctx.name.as_str() == #field_name { + let res = self.#ident(ctx).await.map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))?; let ctx_obj = ctx.with_selection_set(&ctx.selection_set); - return #crate_name::OutputValueType::resolve(&self.#ident, &ctx_obj, ctx.position).await; + return #crate_name::OutputValueType::resolve(&res, &ctx_obj, ctx.position).await; } }); } diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index 5f85532c..3f54ee0e 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -3,9 +3,11 @@ use crate::output_type::OutputType; use crate::utils::{build_value_repr, check_reserved_name, get_crate_name}; use inflector::Inflector; use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; use quote::quote; use syn::{ - Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait, TypeReference, + Block, Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait, + TypeReference, }; pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result { @@ -41,7 +43,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< for item in &mut item_impl.items { if let ImplItem::Method(method) = item { - if let Some(field) = args::Field::parse(&method.attrs)? { + if let Some(field) = args::Field::parse(&crate_name, &method.attrs)? { let ident = &method.sig.ident; let field_name = field .name @@ -60,7 +62,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< if method.sig.asyncness.is_none() { return Err(Error::new_spanned( - &method.sig.asyncness, + &method, "The subscription stream function must be asynchronous", )); } @@ -72,7 +74,8 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } }; - let mut arg_ctx = false; + let mut create_ctx = true; + let mut arg_ctx = Ident::new("ctx", Span::call_site()); let mut args = Vec::new(); for (idx, arg) in method.sig.inputs.iter_mut().enumerate() { @@ -94,13 +97,13 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< match (&*pat.pat, &*pat.ty) { (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { args.push(( - arg_ident, - arg_ty, + arg_ident.clone(), + arg_ty.clone(), args::Argument::parse(&crate_name, &pat.attrs)?, )); pat.attrs.clear(); } - (_, Type::Reference(TypeReference { elem, .. })) => { + (arg, Type::Reference(TypeReference { elem, .. })) => { if let Type::Path(path) = elem.as_ref() { if idx != 1 || path.path.segments.last().unwrap().ident != "Context" @@ -109,8 +112,21 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< arg, "The Context must be the second argument.", )); + } else { + create_ctx = false; + match arg { + Pat::Wild(_) => { + pat.pat = Box::new( + syn::parse2::(quote! { #arg_ctx }) + .unwrap(), + ); + } + Pat::Ident(arg_ident) => { + arg_ctx = arg_ident.ident.clone(); + } + _ => {} + } } - arg_ctx = true; } } _ => { @@ -122,6 +138,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } } + if create_ctx { + let arg = syn::parse2::(quote! { #arg_ctx: &#crate_name::Context<'_> }) + .unwrap(); + method.sig.inputs.insert(1, arg); + } + let mut schema_args = Vec::new(); let mut use_params = Vec::new(); let mut get_params = Vec::new(); @@ -184,6 +206,16 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< quote! { #res_ty } }; + if let OutputType::Value(inner_ty) = &ty { + let block = &method.block; + method.block = + syn::parse2::(quote!({ Ok(#block) })).expect("invalid block"); + method.sig.output = syn::parse2::( + quote! { -> #crate_name::FieldResult<#inner_ty> }, + ) + .expect("invalid result type"); + } + schema_fields.push(quote! { fields.insert(#field_name.to_string(), #crate_name::registry::Field { name: #field_name.to_string(), @@ -202,23 +234,18 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); }); - let ctx_param = if arg_ctx { - quote! { &ctx, } - } else { - quote! {} + let create_field_stream = quote! { + #crate_name::futures::stream::StreamExt::fuse(self.#ident(ctx, #(#use_params),*).await. + map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))?) }; - let create_field_stream = match &ty { - OutputType::Value(_) => quote! { - #crate_name::futures::stream::StreamExt::fuse(self.#ident(#ctx_param #(#use_params),*).await) - }, - OutputType::Result(_, _) => { - quote! { - #crate_name::futures::stream::StreamExt::fuse(self.#ident(#ctx_param #(#use_params),*).await. - map_err(|err| err.into_error_with_path(ctx.position, ctx.path_node.as_ref().unwrap().to_json()))?) - } - } - }; + if let Some(guard) = &field.guard { + method.block.stmts.insert( + 0, + syn::parse2(quote! { #guard.check(#arg_ctx).await?; }) + .expect("invalid guard"), + ); + } create_stream.push(quote! { if ctx.name.as_str() == #field_name { diff --git a/async-graphql-derive/src/utils.rs b/async-graphql-derive/src/utils.rs index 17cccb21..add62cdb 100644 --- a/async-graphql-derive/src/utils.rs +++ b/async-graphql-derive/src/utils.rs @@ -2,7 +2,7 @@ use graphql_parser::parse_query; use graphql_parser::query::{Definition, OperationDefinition, ParseError, Query, Value}; use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{Error, Ident, Meta, MetaList, NestedMeta, Result}; +use syn::{Error, Expr, Ident, Lit, Meta, MetaList, NestedMeta, Result}; pub fn get_crate_name(internal: bool) -> TokenStream { if internal { @@ -134,8 +134,15 @@ fn parse_nested_validator( for item in &ls.nested { if let NestedMeta::Meta(Meta::NameValue(nv)) = item { let name = &nv.path; - let value = &nv.lit; - params.push(quote! { #name: #value }); + if let Lit::Str(value) = &nv.lit { + let expr = syn::parse_str::(&value.value())?; + params.push(quote! { #name: #expr }); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Value must be string literal", + )); + } } else { return Err(Error::new_spanned( nested_meta, @@ -174,3 +181,50 @@ pub fn parse_validator(crate_name: &TokenStream, args: &MetaList) -> Result Result> { + for arg in &args.nested { + if let NestedMeta::Meta(Meta::List(ls)) = arg { + if ls.path.is_ident("guard") { + let mut guards = None; + + for item in &ls.nested { + if let NestedMeta::Meta(Meta::List(ls)) = item { + let ty = &ls.path; + let mut params = Vec::new(); + for attr in &ls.nested { + if let NestedMeta::Meta(Meta::NameValue(nv)) = attr { + let name = &nv.path; + if let Lit::Str(value) = &nv.lit { + let expr = syn::parse_str::(&value.value())?; + params.push(quote! { #name: #expr }); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Value must be string literal", + )); + } + } else { + return Err(Error::new_spanned(attr, "Invalid property for guard")); + } + } + + let guard = quote! { #ty { #(#params),* } }; + if guards.is_none() { + guards = Some(guard); + } else { + guards = + Some(quote! { #crate_name::guard::GuardExt::and(#guard, #guards) }); + } + } else { + return Err(Error::new_spanned(item, "Invalid guard")); + } + } + + return Ok(guards); + } + } + } + + Ok(None) +} diff --git a/docs/zh-CN/src/input_value_validators.md b/docs/zh-CN/src/input_value_validators.md index 8e356cbf..9a43f77b 100644 --- a/docs/zh-CN/src/input_value_validators.md +++ b/docs/zh-CN/src/input_value_validators.md @@ -14,7 +14,7 @@ struct Query; #[Object] impl Query { - async fn input(#[arg(validator(or(Email, MAC(colon = false))))] a: String) { + async fn input(#[arg(validator(or(Email, MAC(colon = "false"))))] a: String) { } } ``` @@ -31,8 +31,8 @@ struct Query; impl Query { async fn input(#[validator( or( - and(IntGreaterThan(value = 10), IntLessThan(value = 100)), - IntEqual(value = 0) + and(IntGreaterThan(value = "10"), IntLessThan(value = "100")), + IntEqual(value = "0") ))] a: String) { } { } diff --git a/src/error.rs b/src/error.rs index fd894b0e..e75c73aa 100644 --- a/src/error.rs +++ b/src/error.rs @@ -123,7 +123,7 @@ where } /// Error for query -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] #[allow(missing_docs)] pub enum QueryError { #[error("Not supported.")] @@ -248,7 +248,7 @@ impl QueryError { } } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct RuleError { pub locations: Vec, pub message: String, @@ -311,7 +311,7 @@ pub enum ParseRequestError { } #[allow(missing_docs)] -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum Error { #[error("Parse error: {message}")] Parse { diff --git a/src/guard.rs b/src/guard.rs new file mode 100644 index 00000000..2ba97d08 --- /dev/null +++ b/src/guard.rs @@ -0,0 +1,33 @@ +//! Field guards + +use crate::{Context, FieldResult}; + +/// Field guard +/// +/// Guard is a precondition for a field that is resolved if `Ok(()` is returned, otherwise an error is returned. +#[async_trait::async_trait] +pub trait Guard { + #[allow(missing_docs)] + async fn check(&self, ctx: &Context<'_>) -> FieldResult<()>; +} + +/// An extension trait for `Guard` +pub trait GuardExt: Guard + Sized { + /// Merge the two guards. + fn and(self, other: R) -> GuardAnd { + GuardAnd(self, other) + } +} + +impl GuardExt for T {} + +/// Guard for `GuardExt::and` +pub struct GuardAnd(A, B); + +#[async_trait::async_trait] +impl Guard for GuardAnd { + async fn check(&self, ctx: &Context<'_>) -> FieldResult<()> { + self.0.check(ctx).await?; + self.1.check(ctx).await + } +} diff --git a/src/lib.rs b/src/lib.rs index 291d7327..9afd9dbc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -90,6 +90,7 @@ mod types; mod validation; pub mod extensions; +pub mod guard; pub mod validators; #[doc(hidden)] @@ -167,6 +168,7 @@ pub use types::{EnumItem, EnumType}; /// | external | Mark a field as owned by another service. This allows service A to use fields from service B while also knowing at runtime the types of that field. | bool | Y | /// | provides | Annotate the expected returned fieldset from a field on a base type that is guaranteed to be selectable by the gateway. | string | Y | /// | requires | Annotate the required input fieldset from a base type for a resolver. It is used to develop a query plan where the required fields may not be needed by the client, but the service may need additional information from other services. | string | Y | +/// | guard | Field of guard | [`Guard`](guard/trait.Guard.html) | Y | /// /// # Field argument parameters /// @@ -274,6 +276,7 @@ pub use async_graphql_derive::Object; /// | external | Mark a field as owned by another service. This allows service A to use fields from service B while also knowing at runtime the types of that field. | bool | Y | /// | provides | Annotate the expected returned fieldset from a field on a base type that is guaranteed to be selectable by the gateway. | string | Y | /// | requires | Annotate the required input fieldset from a base type for a resolver. It is used to develop a query plan where the required fields may not be needed by the client, but the service may need additional information from other services. | string | Y | +/// | guard | Field of guard | [`Guard`](guard/trait.Guard.html) | Y | /// /// # Examples /// @@ -421,7 +424,6 @@ pub use async_graphql_derive::InputObject; /// | name | Field name | string | N | /// | type | Field type | string | N | /// | desc | Field description | string | Y | -/// | context | Method with the context | string | Y | /// | deprecation | Field deprecation reason | string | Y | /// | args | Field arguments | | Y | /// @@ -447,8 +449,6 @@ pub use async_graphql_derive::InputObject; /// /// The type, name, and parameters of the interface field must exactly match the type that implements the interface, /// The internal implementation is a forward of the function call. -/// You can specify the field function name that implements the interface type through the 'method' property, -/// or you can specify that the field function has a context parameter through the 'context' attribute. /// /// ```rust /// use async_graphql::*; @@ -476,7 +476,7 @@ pub use async_graphql_derive::InputObject; /// } /// /// #[Interface( -/// field(name = "value_a", type = "&'ctx str", context), +/// field(name = "value_a", type = "&'ctx str"), /// field(name = "value_b", type = "&i32"), /// field(name = "value_c", type = "i32", /// arg(name = "a", type = "i32"), @@ -541,6 +541,7 @@ pub use async_graphql_derive::Union; /// | name | Field name | string | Y | /// | desc | Field description | string | Y | /// | deprecation | Field deprecation reason | string | Y | +/// | guard | Field of guard | [`Guard`](guard/trait.Guard.html) | Y | /// /// # Field argument parameters /// diff --git a/src/model/enum_value.rs b/src/model/enum_value.rs index 172ef4f2..859ddacc 100644 --- a/src/model/enum_value.rs +++ b/src/model/enum_value.rs @@ -1,4 +1,4 @@ -use crate::{registry, Context}; +use crate::registry; use async_graphql_derive::Object; pub struct __EnumValue<'a> { @@ -11,19 +11,19 @@ pub struct __EnumValue<'a> { desc = "One possible value for a given Enum. Enum values are unique values, not a placeholder for a string or numeric value. However an Enum value is returned in a JSON response as a string." )] impl<'a> __EnumValue<'a> { - async fn name(&self, _: &Context<'_>) -> String { + async fn name(&self) -> String { self.value.name.to_string() } - async fn description(&self, _: &Context<'_>) -> Option { + async fn description(&self) -> Option { self.value.description.map(|s| s.to_string()) } - async fn is_deprecated(&self, _: &Context<'_>) -> bool { + async fn is_deprecated(&self) -> bool { self.value.deprecation.is_some() } - async fn deprecation_reason(&self, _: &Context<'_>) -> Option { + async fn deprecation_reason(&self) -> Option { self.value.deprecation.map(|s| s.to_string()) } } diff --git a/src/query.rs b/src/query.rs index 2a21da92..b19929bd 100644 --- a/src/query.rs +++ b/src/query.rs @@ -45,6 +45,7 @@ pub trait IntoQueryBuilder: Sized { } /// Query response +#[derive(Debug)] pub struct QueryResponse { /// Data of query result pub data: serde_json::Value, diff --git a/src/validation/test_harness.rs b/src/validation/test_harness.rs index 1ba8fe1f..4a84074f 100644 --- a/src/validation/test_harness.rs +++ b/src/validation/test_harness.rs @@ -1,5 +1,6 @@ #![allow(unused_variables)] #![allow(dead_code)] +#![allow(unreachable_code)] use crate::validation::visitor::{visit, Visitor, VisitorContext}; use crate::*; diff --git a/src/validators/mod.rs b/src/validators/mod.rs index 274df9c6..cc88f757 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -28,12 +28,12 @@ pub use string_validators::{Email, StringMaxLength, StringMinLength, MAC}; /// } /// /// // Input is email or MAC address -/// async fn value2(&self, #[arg(validator(or(Email, MAC(colon = false))))] email_or_mac: String) -> i32 { +/// async fn value2(&self, #[arg(validator(or(Email, MAC(colon = "false"))))] email_or_mac: String) -> i32 { /// unimplemented!() /// } /// /// // Input is integer between 100 and 200 -/// async fn value3(&self, #[arg(validator(IntRange(min = 100, max = 200)))] value: i32) -> i32 { +/// async fn value3(&self, #[arg(validator(IntRange(min = "100", max = "200")))] value: i32) -> i32 { /// unimplemented!() /// } /// } @@ -68,7 +68,7 @@ pub trait InputValueValidatorExt: InputValueValidator + Sized { impl InputValueValidatorExt for I {} -/// Invalidator for `InputValueValidator::and` +/// Invalidator for `InputValueValidatorExt::and` pub struct And(A, B); impl InputValueValidator for And diff --git a/tests/federation.rs b/tests/federation.rs index f15b8962..94031730 100644 --- a/tests/federation.rs +++ b/tests/federation.rs @@ -1,3 +1,5 @@ +#![allow(unreachable_code)] + use async_graphql::*; struct User { @@ -68,16 +70,14 @@ impl QueryRoot { #[async_std::test] pub async fn test_federation() { let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); - let query = format!( - r#"{{ - _entities(representations: [{{__typename: "Product", upc: "B00005N5PF"}}]) {{ + let query = r#"{ + _entities(representations: [{__typename: "Product", upc: "B00005N5PF"}]) { __typename - ... on Product {{ + ... on Product { upc - }} - }} - }}"# - ); + } + } + }"#; assert_eq!( schema.execute(&query).await.unwrap().data, serde_json::json!({ diff --git a/tests/guard.rs b/tests/guard.rs new file mode 100644 index 00000000..04ce71fe --- /dev/null +++ b/tests/guard.rs @@ -0,0 +1,362 @@ +use async_graphql::guard::Guard; +use async_graphql::*; +use futures::{Stream, StreamExt}; +use std::sync::Arc; + +#[async_std::test] +pub async fn test_enum_type() { + #[derive(Eq, PartialEq, Copy, Clone)] + enum Role { + Admin, + Guest, + } + + struct RoleGuard { + role: Role, + } + + #[async_trait::async_trait] + impl Guard for RoleGuard { + async fn check(&self, ctx: &Context<'_>) -> FieldResult<()> { + if *ctx.data::() == self.role { + Ok(()) + } else { + Err("Forbidden".into()) + } + } + } + + #[SimpleObject] + struct MyObj { + #[field(guard(RoleGuard(role = "Role::Admin")))] + value: i32, + } + + struct Query; + + #[Object] + impl Query { + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn value1(&self) -> FieldResult { + Ok(1) + } + + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn value2(&self, ctx1: &Context<'_>) -> FieldResult { + Ok(2) + } + + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn value3(&self, _: &Context<'_>) -> i32 { + 3 + } + + async fn obj(&self) -> MyObj { + MyObj { value: 99 } + } + + #[entity(guard(RoleGuard(role = "Role::Admin")))] + async fn find_obj1(&self, value: i32) -> FieldResult { + Ok(MyObj { value }) + } + + #[entity(guard(RoleGuard(role = "Role::Admin")))] + #[allow(unused_variables)] + async fn find_obj2(&self, ctx1: &Context<'_>, value: i32, n: i32) -> FieldResult { + Ok(MyObj { value }) + } + + #[entity(guard(RoleGuard(role = "Role::Admin")))] + #[allow(unused_variables)] + async fn find_obj3(&self, _: &Context<'_>, value: i32, a: i32, b: i32) -> MyObj { + MyObj { value } + } + } + + struct Subscription; + + #[Subscription] + impl Subscription { + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn values1(&self) -> FieldResult> { + Ok(futures::stream::iter(vec![1, 2, 3])) + } + + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn values2(&self, ctx1: &Context<'_>) -> FieldResult> { + Ok(futures::stream::iter(vec![1, 2, 3])) + } + + #[field(guard(RoleGuard(role = "Role::Admin")))] + async fn values3(&self, _: &Context<'_>) -> impl Stream { + futures::stream::iter(vec![1, 2, 3]) + } + } + + let schema = Schema::new(Query, EmptyMutation, Subscription); + + let query = "{ obj { value } }"; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "obj": {"value": 99} + }) + ); + + let query = "{ obj { value } }"; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "obj": {"value": 99} + }) + ); + + let query = "{ obj { value } }"; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Guest) + .execute(&schema) + .await + .unwrap_err(), + Error::Query { + pos: Pos { line: 1, column: 9 }, + path: Some(serde_json::json!(["obj", "value"])), + err: QueryError::FieldError { + err: "Forbidden".to_string(), + extended_error: None, + }, + } + ); + + let query = "{ value1 value2 value3 }"; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "value1": 1, + "value2": 2, + "value3": 3, + }) + ); + + let query = "{ value1 }"; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Guest) + .execute(&schema) + .await + .unwrap_err(), + Error::Query { + pos: Pos { line: 1, column: 3 }, + path: Some(serde_json::json!(["value1"])), + err: QueryError::FieldError { + err: "Forbidden".to_string(), + extended_error: None, + }, + } + ); + + assert_eq!( + schema + .create_subscription_stream( + "subscription { values1 }", + None, + Variables::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(Role::Admin); + data + })), + ) + .await + .unwrap() + .collect::>() + .await, + vec![ + serde_json::json! ({"values1": 1}), + serde_json::json! ({"values1": 2}), + serde_json::json! ({"values1": 3}) + ] + ); + + assert_eq!( + schema + .create_subscription_stream( + "subscription { values2 }", + None, + Variables::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(Role::Admin); + data + })), + ) + .await + .unwrap() + .collect::>() + .await, + vec![ + serde_json::json! ({"values2": 1}), + serde_json::json! ({"values2": 2}), + serde_json::json! ({"values2": 3}) + ] + ); + + assert_eq!( + schema + .create_subscription_stream( + "subscription { values3 }", + None, + Variables::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(Role::Admin); + data + })), + ) + .await + .unwrap() + .collect::>() + .await, + vec![ + serde_json::json! ({"values3": 1}), + serde_json::json! ({"values3": 2}), + serde_json::json! ({"values3": 3}) + ] + ); + + assert_eq!( + schema + .create_subscription_stream( + "subscription { values1 }", + None, + Variables::default(), + Some(Arc::new({ + let mut data = Data::default(); + data.insert(Role::Guest); + data + })), + ) + .await + .err() + .unwrap(), + Error::Query { + pos: Pos { + line: 1, + column: 16 + }, + path: Some(serde_json::json!(["values1"])), + err: QueryError::FieldError { + err: "Forbidden".to_string(), + extended_error: None, + }, + } + ); + + let query = r#"{ + _entities(representations: [{__typename: "MyObj", value: 1}]) { + __typename + ... on MyObj { + value + } + } + }"#; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "_entities": [ + {"__typename": "MyObj", "value": 1}, + ] + }) + ); + + let query = r#"{ + _entities(representations: [{__typename: "MyObj", value: 1, n: 1}]) { + __typename + ... on MyObj { + value + } + } + }"#; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "_entities": [ + {"__typename": "MyObj", "value": 1}, + ] + }) + ); + + let query = r#"{ + _entities(representations: [{__typename: "MyObj", value: 1, a: 1, b: 2}]) { + __typename + ... on MyObj { + value + } + } + }"#; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Admin) + .execute(&schema) + .await + .unwrap() + .data, + serde_json::json!({ + "_entities": [ + {"__typename": "MyObj", "value": 1}, + ] + }) + ); + + let query = r#"{ + _entities(representations: [{__typename: "MyObj", value: 1}]) { + __typename + ... on MyObj { + value + } + } + }"#; + assert_eq!( + QueryBuilder::new(query) + .data(Role::Guest) + .execute(&schema) + .await + .unwrap_err(), + Error::Query { + pos: Pos { + line: 2, + column: 13 + }, + path: None, + err: QueryError::FieldError { + err: "Forbidden".to_string(), + extended_error: None, + }, + } + ); +}