Rework validators 2

This commit is contained in:
Sunli 2021-11-15 09:12:13 +08:00
parent 8750d8d34b
commit 958b534b77
18 changed files with 258 additions and 52 deletions

View File

@ -17,7 +17,7 @@ default = []
apollo_tracing = ["chrono"]
apollo_persisted_queries = ["lru", "sha2"]
unblock = ["blocking"]
string_number = ["num-traits"]
string_number = []
dataloader = ["futures-timer", "futures-channel", "lru"]
tracing = ["tracinglib", "tracing-futures"]
decimal = ["rust_decimal"]
@ -46,6 +46,7 @@ multer = "2.0.0"
tempfile = "3.2.0"
bytes = { version = "1.0.1", features = ["serde"] }
mime = "0.3.15"
num-traits = "0.2.14"
# Feature optional dependencies
bson = { version = "2.0.0", optional = true, features = ["chrono-0_4"] }
@ -64,7 +65,6 @@ rust_decimal = { version = "1.14.3", optional = true }
# Non-feature optional dependencies
blocking = { version = "1.0.2", optional = true }
lru = { version = "0.6.5", optional = true }
num-traits = { version = "0.2.14", optional = true }
sha2 = { version = "0.9.3", optional = true }
futures-timer = { version = "3.0.2", optional = true }
futures-channel = { version = "0.3.13", optional = true }

View File

@ -230,17 +230,15 @@ pub fn generate(
let validators = validator.clone().unwrap_or_default().create_validators(
&crate_name,
quote!(&#ident),
quote!(),
Some(quote!(.map_err(|err| err.into_server_error(__pos)))),
);
// We're generating a new identifier,
// so remove the 'r#` prefix if present
let param_getter_name = get_param_getter_ident(&ident.ident.unraw().to_string());
get_params.push(quote! {
#[allow(non_snake_case)]
let #param_getter_name = || -> #crate_name::ServerResult<#ty> { ctx.param_value(#name, #default) };
let #param_getter_name = || { ctx.param_value::<#ty>(#name, #default) };
#[allow(non_snake_case)]
let #ident: #ty = #param_getter_name()?;
let (__pos, #ident) = #param_getter_name()?;
#validators
});
}
@ -340,7 +338,7 @@ pub fn generate(
let resolve_obj = quote! {
{
let res = self.#field_ident(ctx, #(#use_params),*).await;
res.map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err).into_server_error(ctx.item.pos))?
res.map_err(|err| ::std::convert::Into::<#crate_name::Error>::into(err).into_server_error(ctx.item.pos))
}
};
@ -359,11 +357,14 @@ pub fn generate(
resolvers.push(quote! {
#(#cfg_attrs)*
if ctx.item.node.name.node == #field_name {
#(#get_params)*
#guard
let f = async move {
#(#get_params)*
#guard
#resolve_obj
};
let obj = f.await.map_err(|err| ctx.set_error_path(err))?;
let ctx_obj = ctx.with_selection_set(&ctx.item.node.selection_set);
let res = #resolve_obj;
return #crate_name::OutputType::resolve(&res, &ctx_obj, ctx.item).await.map(::std::option::Option::Some);
return #crate_name::OutputType::resolve(&obj, &ctx_obj, ctx.item).await.map(::std::option::Option::Some);
}
});

View File

@ -82,7 +82,7 @@ pub fn generate(object_args: &args::InputObject) -> GeneratorResult<TokenStream>
.create_validators(
&crate_name,
quote!(&#ident),
quote!(.map_err(#crate_name::InputValueError::propagate)?),
Some(quote!(.map_err(#crate_name::InputValueError::propagate))),
);
if field.flatten {

View File

@ -204,7 +204,7 @@ pub fn generate(interface_args: &args::Interface) -> GeneratorResult<TokenStream
None => quote! { ::std::option::Option::None },
};
get_params.push(quote! {
let #ident: #ty = ctx.param_value(#name, #get_default)?;
let (_, #ident) = ctx.param_value::<#ty>(#name, #get_default)?;
});
let desc = desc

View File

@ -365,18 +365,16 @@ pub fn generate(
let validators = validator.clone().unwrap_or_default().create_validators(
&crate_name,
quote!(&#ident),
quote!(),
Some(quote!(.map_err(|err| err.into_server_error(__pos)))),
);
// We're generating a new identifier,
// so remove the 'r#` prefix if present
let param_getter_name =
get_param_getter_ident(&ident.ident.unraw().to_string());
get_params.push(quote! {
#[allow(non_snake_case)]
let #param_getter_name = || -> #crate_name::ServerResult<#ty> { ctx.param_value(#name, #default) };
#[allow(non_snake_case)]
let #ident: #ty = #param_getter_name()?;
#[allow(non_snake_case)]
let #param_getter_name = || ctx.param_value::<#ty>(#name, #default);
#[allow(non_snake_case, unused_variables)]
let (__pos, #ident) = #param_getter_name()?;
#validators
});
}

View File

@ -199,14 +199,15 @@ pub fn generate(
let validators = validator.clone().unwrap_or_default().create_validators(
&crate_name,
quote!(&#ident),
quote!(),
Some(quote!(.map_err(|err| err.into_server_error(__pos)))),
);
let param_getter_name = get_param_getter_ident(&ident.ident.to_string());
let param_getter_name = get_param_getter_ident(&ident.ident.unraw().to_string());
get_params.push(quote! {
#[allow(non_snake_case)]
let #param_getter_name = || -> #crate_name::ServerResult<#ty> { ctx.param_value(#name, #default) };
let #param_getter_name = || { ctx.param_value::<#ty>(#name, #default) };
#[allow(non_snake_case)]
let #ident: #ty = ctx.param_value(#name, #default)?;
let (__pos, #ident) = #param_getter_name()?;
#validators
});
}
@ -274,8 +275,8 @@ pub fn generate(
)
});
parse_args.push(quote! {
let #ident: #ty = __ctx.param_value(__variables_definition, __field, #name, #default)?;
});
let #ident: #ty = __ctx.param_value(__variables_definition, __field, #name, #default)?;
});
}
}
quote! {
@ -317,7 +318,7 @@ pub fn generate(
.map_err(|err| {
::std::convert::Into::<#crate_name::Error>::into(err).into_server_error(ctx.item.pos)
.with_path(::std::vec![#crate_name::PathSegment::Field(::std::borrow::ToOwned::to_owned(&*field_name))])
})?
})
};
let guard = match &field.guard {
@ -333,13 +334,18 @@ pub fn generate(
let stream_fn = quote! {
let field_name = ::std::clone::Clone::clone(&ctx.item.node.response_key().node);
let field = ::std::sync::Arc::new(::std::clone::Clone::clone(&ctx.item));
#(#get_params)*
#guard
let f = async {
#(#get_params)*
#guard
#create_field_stream
};
let stream = f.await.map_err(|err| ctx.set_error_path(err))?;
let pos = ctx.item.pos;
let schema_env = ::std::clone::Clone::clone(&ctx.schema_env);
let query_env = ::std::clone::Clone::clone(&ctx.query_env);
let stream = #crate_name::futures_util::stream::StreamExt::then(#create_field_stream, {
let stream = #crate_name::futures_util::stream::StreamExt::then(stream, {
let field_name = ::std::clone::Clone::clone(&field_name);
move |msg| {
let schema_env = ::std::clone::Clone::clone(&schema_env);

View File

@ -162,7 +162,9 @@ pub fn generate_guards(
let value_str = value.value();
if let Some(value_str) = value_str.strip_prefix('@') {
let getter_name = get_param_getter_ident(value_str);
params.push(quote! { #name: #getter_name()? });
params.push(
quote! { #name: #getter_name().map(|(_, value)| value)? },
);
} else {
let expr = syn::parse_str::<Expr>(&value_str)?;
params.push(quote! { #name: (#expr).into() });

View File

@ -27,52 +27,53 @@ impl Validators {
&self,
crate_name: &TokenStream,
value: TokenStream,
map_err: TokenStream,
map_err: Option<TokenStream>,
) -> TokenStream {
let mut codes = Vec::new();
if let Some(n) = &self.multiple_of {
codes.push(quote! {
#crate_name::validators::multiple_of(#value, #n) #map_err
#crate_name::validators::multiple_of(#value, #n)
});
}
if let Some(n) = &self.maximum {
codes.push(quote! {
#crate_name::validators::maximum(#value, #n) #map_err
#crate_name::validators::maximum(#value, #n)
});
}
if let Some(n) = &self.minimum {
codes.push(quote! {
#crate_name::validators::minimum(#value, #n) #map_err
#crate_name::validators::minimum(#value, #n)
});
}
if let Some(n) = &self.max_length {
codes.push(quote! {
#crate_name::validators::max_length(#value, #n) #map_err
#crate_name::validators::max_length(#value, #n)
});
}
if let Some(n) = &self.min_length {
codes.push(quote! {
#crate_name::validators::min_length(#value, #n) #map_err
#crate_name::validators::min_length(#value, #n)
});
}
if let Some(n) = &self.max_items {
codes.push(quote! {
#crate_name::validators::max_items(#value, #n) #map_err
#crate_name::validators::max_items(#value, #n)
});
}
if let Some(n) = &self.min_items {
codes.push(quote! {
#crate_name::validators::min_items(#value, #n) #map_err
#crate_name::validators::min_items(#value, #n)
});
}
let codes = codes.into_iter().map(|s| quote!(#s #map_err ?));
quote!(#(#codes;)*)
}
}

View File

@ -556,18 +556,20 @@ impl<'a> ContextBase<'a, &'a Positioned<Field>> {
&self,
name: &str,
default: Option<fn() -> T>,
) -> ServerResult<T> {
) -> ServerResult<(Pos, T)> {
let value = self.item.node.get_argument(name).cloned();
if value.is_none() {
if let Some(default) = default {
return Ok(default());
return Ok((Pos::default(), default()));
}
}
let (pos, value) = match value {
Some(value) => (value.pos, Some(self.resolve_input_value(value)?)),
None => (Pos::default(), None),
};
InputType::parse(value).map_err(|e| e.into_server_error(pos))
InputType::parse(value)
.map(|value| (pos, value))
.map_err(|e| e.into_server_error(pos))
}
/// Creates a uniform interface to inspect the forthcoming selections.

View File

@ -36,7 +36,7 @@ impl<T: ObjectType> ContainerType for QueryRoot<T> {
.await
.map(Some);
} else if ctx.item.node.name.node == "__type" {
let type_name: String = ctx.param_value("name", None)?;
let (_, type_name) = ctx.param_value::<String>("name", None)?;
let ctx_obj = ctx.with_selection_set(&ctx.item.node.selection_set);
let visible_types = ctx.schema_env.registry.find_visible_types(ctx);
return OutputType::resolve(
@ -56,7 +56,7 @@ impl<T: ObjectType> ContainerType for QueryRoot<T> {
if ctx.schema_env.registry.enable_federation || ctx.schema_env.registry.has_entities() {
if ctx.item.node.name.node == "_entities" {
let representations: Vec<Any> = ctx.param_value("representations", None)?;
let (_, representations) = ctx.param_value::<Vec<Any>>("representations", None)?;
let res = futures_util::future::try_join_all(representations.iter().map(
|item| async move {
self.inner.find_entity(ctx, &item.0).await?.ok_or_else(|| {

View File

@ -2,7 +2,7 @@ use std::ops::Deref;
use crate::{InputType, InputValueError};
pub async fn max_items<T: Deref<Target = [E]> + InputType, E>(
pub fn max_items<T: Deref<Target = [E]> + InputType, E>(
value: &T,
len: usize,
) -> Result<(), InputValueError<T>> {

View File

@ -1,6 +1,6 @@
use crate::{InputType, InputValueError};
pub async fn max_length<T: AsRef<str> + InputType>(
pub fn max_length<T: AsRef<str> + InputType>(
value: &T,
len: usize,
) -> Result<(), InputValueError<T>> {

View File

@ -2,7 +2,7 @@ use num_traits::AsPrimitive;
use crate::{InputType, InputValueError};
pub async fn maximum<T: AsPrimitive<f64> + InputType>(
pub fn maximum<T: AsPrimitive<f64> + InputType>(
value: &T,
n: f64,
) -> Result<(), InputValueError<T>> {

View File

@ -2,7 +2,7 @@ use std::ops::Deref;
use crate::{InputType, InputValueError};
pub async fn min_items<T: Deref<Target = [E]> + InputType, E>(
pub fn min_items<T: Deref<Target = [E]> + InputType, E>(
value: &T,
len: usize,
) -> Result<(), InputValueError<T>> {

View File

@ -1,6 +1,6 @@
use crate::{InputType, InputValueError};
pub async fn min_length<T: AsRef<str> + InputType>(
pub fn min_length<T: AsRef<str> + InputType>(
value: &T,
len: usize,
) -> Result<(), InputValueError<T>> {

View File

@ -2,7 +2,7 @@ use num_traits::AsPrimitive;
use crate::{InputType, InputValueError};
pub async fn minimum<T: AsPrimitive<f64> + InputType>(
pub fn minimum<T: AsPrimitive<f64> + InputType>(
value: &T,
n: f64,
) -> Result<(), InputValueError<T>> {

View File

@ -2,7 +2,7 @@ use num_traits::AsPrimitive;
use crate::{InputType, InputValueError};
pub async fn multiple_of<T: AsPrimitive<f64> + InputType>(
pub fn multiple_of<T: AsPrimitive<f64> + InputType>(
value: &T,
n: f64,
) -> Result<(), InputValueError<T>> {

196
tests/validators.rs Normal file
View File

@ -0,0 +1,196 @@
use async_graphql::*;
use futures_util::{Stream, StreamExt};
#[tokio::test]
pub async fn test_validator_on_object_field_args() {
struct Query;
#[Object]
impl Query {
async fn value(&self, #[graphql(validator(maximum = "10"))] n: i32) -> i32 {
n
}
}
let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
assert_eq!(
schema
.execute("{ value(n: 5) }")
.await
.into_result()
.unwrap()
.data,
value!({ "value": 5 })
);
assert_eq!(
schema
.execute("{ value(n: 11) }")
.await
.into_result()
.unwrap_err(),
vec![ServerError {
message: r#"Failed to parse "Int": the value is 11, must be less than or equal to 10"#
.to_string(),
source: None,
locations: vec![Pos {
line: 1,
column: 12
}],
path: vec![PathSegment::Field("value".to_string())],
extensions: None
}]
);
}
#[tokio::test]
pub async fn test_validator_on_input_object_field() {
#[derive(InputObject)]
struct MyInput {
#[graphql(validator(maximum = "10"))]
a: i32,
}
struct Query;
#[Object]
impl Query {
async fn value(&self, input: MyInput) -> i32 {
input.a
}
}
let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
assert_eq!(
schema
.execute("{ value(input: {a: 5}) }")
.await
.into_result()
.unwrap()
.data,
value!({ "value": 5 })
);
assert_eq!(
schema
.execute("{ value(input: {a: 11}) }")
.await
.into_result()
.unwrap_err(),
vec![ServerError {
message: r#"Failed to parse "Int": the value is 11, must be less than or equal to 10 (occurred while parsing "MyInput")"#
.to_string(),
source: None,
locations: vec![Pos {
line: 1,
column: 16
}],
path: vec![PathSegment::Field("value".to_string())],
extensions: None
}]
);
}
#[tokio::test]
pub async fn test_validator_on_complex_object_field_args() {
#[derive(SimpleObject)]
#[graphql(complex)]
struct Query {
a: i32,
}
#[ComplexObject]
impl Query {
async fn value(&self, #[graphql(validator(maximum = "10"))] n: i32) -> i32 {
n
}
}
let schema = Schema::new(Query { a: 10 }, EmptyMutation, EmptySubscription);
assert_eq!(
schema
.execute("{ value(n: 5) }")
.await
.into_result()
.unwrap()
.data,
value!({ "value": 5 })
);
assert_eq!(
schema
.execute("{ value(n: 11) }")
.await
.into_result()
.unwrap_err(),
vec![ServerError {
message: r#"Failed to parse "Int": the value is 11, must be less than or equal to 10"#
.to_string(),
source: None,
locations: vec![Pos {
line: 1,
column: 12
}],
path: vec![PathSegment::Field("value".to_string())],
extensions: None
}]
);
}
#[tokio::test]
pub async fn test_validator_on_subscription_field_args() {
struct Query;
#[Object]
impl Query {
async fn value(&self) -> i32 {
1
}
}
struct Subscription;
#[Subscription]
impl Subscription {
async fn value(
&self,
#[graphql(validator(maximum = "10"))] n: i32,
) -> impl Stream<Item = i32> {
futures_util::stream::iter(vec![n])
}
}
let schema = Schema::new(Query, EmptyMutation, Subscription);
assert_eq!(
schema
.execute_stream("subscription { value(n: 5) }")
.collect::<Vec<_>>()
.await
.remove(0)
.into_result()
.unwrap()
.data,
value!({ "value": 5 })
);
assert_eq!(
schema
.execute_stream("subscription { value(n: 11) }")
.collect::<Vec<_>>()
.await
.remove(0)
.into_result()
.unwrap_err(),
vec![ServerError {
message: r#"Failed to parse "Int": the value is 11, must be less than or equal to 10"#
.to_string(),
source: None,
locations: vec![Pos {
line: 1,
column: 25
}],
path: vec![PathSegment::Field("value".to_string())],
extensions: None
}]
);
}