diff --git a/async-graphql-tide/tests/graphql.rs b/async-graphql-tide/tests/graphql.rs index 37144957..545e3fb0 100644 --- a/async-graphql-tide/tests/graphql.rs +++ b/async-graphql-tide/tests/graphql.rs @@ -43,7 +43,10 @@ fn quickstart() -> Result<()> { let client = Task::>::spawn(async move { Timer::after(Duration::from_millis(300)).await; - let resp = reqwest::Client::new() + let resp = reqwest::Client::builder() + .no_proxy() + .build() + .unwrap() .post(format!("http://{}", listen_addr).as_str()) .body(r#"{"query":"{ add(a: 10, b: 20) }"}"#) .header(reqwest::header::CONTENT_TYPE, "application/json") @@ -57,7 +60,10 @@ fn quickstart() -> Result<()> { assert_eq!(string, json!({"data": {"add": 30}}).to_string()); // - let resp = reqwest::Client::new() + let resp = reqwest::Client::builder() + .no_proxy() + .build() + .unwrap() .get( format!( "http://{}?query=%7B%20add%28a%3A%2010%2C%20b%3A%2020%29%20%7D", @@ -133,7 +139,10 @@ fn hello() -> Result<()> { let client = Task::>::spawn(async move { Timer::after(Duration::from_millis(300)).await; - let resp = reqwest::Client::new() + let resp = reqwest::Client::builder() + .no_proxy() + .build() + .unwrap() .post(format!("http://{}", listen_addr).as_str()) .body(r#"{"query":"{ hello }"}"#) .header(reqwest::header::CONTENT_TYPE, "application/json") @@ -147,7 +156,10 @@ fn hello() -> Result<()> { assert_eq!(string, json!({"data":{"hello":"Hello, Foo!"}}).to_string()); - let resp = reqwest::Client::new() + let resp = reqwest::Client::builder() + .no_proxy() + .build() + .unwrap() .post(format!("http://{}", listen_addr).as_str()) .body(r#"{"query":"{ hello }"}"#) .header(reqwest::header::CONTENT_TYPE, "application/json") @@ -230,7 +242,10 @@ fn upload() -> Result<()> { .text("map", r#"{ "0": ["variables.file"] }"#) .part("0", reqwest::multipart::Part::stream("test").file_name("test.txt").mime_str("text/plain")?); - let resp = reqwest::Client::new() + let resp = reqwest::Client::builder() + .no_proxy() + .build() + .unwrap() .post(format!("http://{}", listen_addr).as_str()) .multipart(form) .send() diff --git a/src/schema.rs b/src/schema.rs index bcd0d8af..789a6f87 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -336,8 +336,13 @@ where cache_control, complexity, depth, - } = check_rules(&self.env.registry, &document, self.validation_mode) - .log_error(&extensions)?; + } = check_rules( + &self.env.registry, + &document, + Some(&variables), + self.validation_mode, + ) + .log_error(&extensions)?; extensions.lock().validation_end(); // check limit diff --git a/src/validation/mod.rs b/src/validation/mod.rs index 3ece10a3..0998f67c 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -9,7 +9,7 @@ mod test_harness; use crate::parser::query::Document; use crate::registry::Registry; -use crate::{CacheControl, Error, Result}; +use crate::{CacheControl, Error, Result, Variables}; use visitor::{visit, VisitorContext, VisitorNil}; pub struct CheckResult { @@ -31,9 +31,10 @@ pub enum ValidationMode { pub fn check_rules( registry: &Registry, doc: &Document, + variables: Option<&Variables>, mode: ValidationMode, ) -> Result { - let mut ctx = VisitorContext::new(registry, doc); + let mut ctx = VisitorContext::new(registry, doc, variables); let mut cache_control = CacheControl::default(); let mut complexity = 0; let mut depth = 0; diff --git a/src/validation/rules/arguments_of_correct_type.rs b/src/validation/rules/arguments_of_correct_type.rs index 2e269869..06f572be 100644 --- a/src/validation/rules/arguments_of_correct_type.rs +++ b/src/validation/rules/arguments_of_correct_type.rs @@ -43,17 +43,27 @@ impl<'a> Visitor<'a> for ArgumentsOfCorrectType<'a> { .and_then(|args| args.get(name.as_str()).map(|input| input)) { if let Some(validator) = &arg.validator { - if let Some(reason) = validator.is_valid(value) { - ctx.report_error( - vec![name.position()], - format!("Invalid value for argument \"{}\", {}", arg.name, reason), - ); - return; + let value = match &value.node { + Value::Variable(var_name) => { + ctx.variables.and_then(|variables| variables.get(var_name)) + } + _ => Some(&value.node), + }; + + if let Some(value) = value { + if let Some(reason) = validator.is_valid(value) { + ctx.report_error( + vec![name.position()], + format!("Invalid value for argument \"{}\", {}", arg.name, reason), + ); + return; + } } } if let Some(reason) = is_valid_input_value( ctx.registry, + ctx.variables, &arg.ty, value, QueryPathNode { diff --git a/src/validation/rules/default_values_of_correct_type.rs b/src/validation/rules/default_values_of_correct_type.rs index 2dd47d61..4ca5613e 100644 --- a/src/validation/rules/default_values_of_correct_type.rs +++ b/src/validation/rules/default_values_of_correct_type.rs @@ -20,6 +20,7 @@ impl<'a> Visitor<'a> for DefaultValuesOfCorrectType { )); } else if let Some(reason) = is_valid_input_value( ctx.registry, + ctx.variables, &variable_definition.var_type.to_string(), value, QueryPathNode { diff --git a/src/validation/test_harness.rs b/src/validation/test_harness.rs index 27519434..960fcd2f 100644 --- a/src/validation/test_harness.rs +++ b/src/validation/test_harness.rs @@ -340,7 +340,7 @@ where { let schema = &*TEST_HARNESS; let registry = &schema.env.registry; - let mut ctx = VisitorContext::new(registry, doc); + let mut ctx = VisitorContext::new(registry, doc, None); let mut visitor = factory(); visit(&mut visitor, &mut ctx, doc); if !ctx.errors.is_empty() { diff --git a/src/validation/utils.rs b/src/validation/utils.rs index 5a5422df..a2974bd5 100644 --- a/src/validation/utils.rs +++ b/src/validation/utils.rs @@ -1,6 +1,6 @@ use crate::context::QueryPathNode; use crate::parser::query::OperationDefinition; -use crate::{registry, Pos, QueryPathSegment, Value}; +use crate::{registry, Pos, QueryPathSegment, Value, Variables}; use std::collections::HashSet; #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -53,6 +53,7 @@ pub fn operation_name(operation_definition: &OperationDefinition) -> (Option<&st pub fn is_valid_input_value( registry: ®istry::Registry, + variables: Option<&Variables>, type_name: &str, value: &Value, path_node: QueryPathNode, @@ -67,13 +68,14 @@ pub fn is_valid_input_value( &path_node, format!("expected type \"{}\"", type_name), )), - _ => is_valid_input_value(registry, type_name, value, path_node), + _ => is_valid_input_value(registry, variables, type_name, value, path_node), }, registry::MetaTypeName::List(type_name) => match value { Value::List(elems) => { for (idx, elem) in elems.iter().enumerate() { if let Some(reason) = is_valid_input_value( registry, + variables, type_name, elem, QueryPathNode { @@ -86,7 +88,7 @@ pub fn is_valid_input_value( } None } - _ => is_valid_input_value(registry, type_name, value, path_node), + _ => is_valid_input_value(registry, variables, type_name, value, path_node), }, registry::MetaTypeName::Named(type_name) => { if let Value::Null = value { @@ -96,11 +98,21 @@ pub fn is_valid_input_value( if let Some(ty) = registry.types.get(type_name) { match ty { registry::MetaType::Scalar { is_valid, .. } => { - if !is_valid(value) { - Some(valid_error( - &path_node, - format!("expected type \"{}\"", type_name), - )) + let value = match value { + Value::Variable(var_name) => { + variables.and_then(|variables| variables.get(var_name)) + } + _ => Some(value), + }; + if let Some(value) = value { + if !is_valid(value) { + Some(valid_error( + &path_node, + format!("expected type \"{}\"", type_name), + )) + } else { + None + } } else { None } @@ -136,19 +148,28 @@ pub fn is_valid_input_value( input_names.remove(field.name); if let Some(value) = values.get(field.name) { if let Some(validator) = &field.validator { - if let Some(reason) = validator.is_valid(value) { - return Some(valid_error( - &QueryPathNode { - parent: Some(&path_node), - segment: QueryPathSegment::Name(field.name), - }, - reason, - )); + let value = match value { + Value::Variable(var_name) => variables + .and_then(|variables| variables.get(var_name)), + _ => Some(value), + }; + + if let Some(value) = value { + if let Some(reason) = validator.is_valid(value) { + return Some(valid_error( + &QueryPathNode { + parent: Some(&path_node), + segment: QueryPathSegment::Name(field.name), + }, + reason, + )); + } } } if let Some(reason) = is_valid_input_value( registry, + variables, &field.ty, value, QueryPathNode { diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs index a7574b80..d5a89458 100644 --- a/src/validation/visitor.rs +++ b/src/validation/visitor.rs @@ -4,11 +4,12 @@ use crate::parser::query::{ OperationDefinition, Selection, SelectionSet, TypeCondition, VariableDefinition, }; use crate::registry::{self, MetaType, MetaTypeName}; -use crate::{Pos, Positioned, Value}; +use crate::{Pos, Positioned, Value, Variables}; use std::collections::HashMap; pub struct VisitorContext<'a> { pub registry: &'a registry::Registry, + pub variables: Option<&'a Variables>, pub errors: Vec, type_stack: Vec>, input_type: Vec>>, @@ -16,9 +17,14 @@ pub struct VisitorContext<'a> { } impl<'a> VisitorContext<'a> { - pub fn new(registry: &'a registry::Registry, doc: &'a Document) -> Self { + pub fn new( + registry: &'a registry::Registry, + doc: &'a Document, + variables: Option<&'a Variables>, + ) -> Self { Self { registry, + variables, errors: Default::default(), type_stack: Default::default(), input_type: Default::default(), diff --git a/tests/input_validators.rs b/tests/input_validators.rs index 0169fe36..5f0e3780 100644 --- a/tests/input_validators.rs +++ b/tests/input_validators.rs @@ -1586,3 +1586,123 @@ pub async fn test_input_validator_operator_and() { } } } + +#[async_std::test] +pub async fn test_input_validator_variable() { + struct QueryRoot; + + #[InputObject] + struct InputMaxLength { + #[field(validator(StringMinLength(length = "6")))] + pub id: String, + } + + #[Object] + impl QueryRoot { + async fn field_parameter( + &self, + #[arg(validator(StringMinLength(length = "6")))] _id: String, + ) -> bool { + true + } + + async fn input_object(&self, _input: InputMaxLength) -> bool { + true + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let test_cases = [ + "abc", + "acbce", + "abcdef", + "abcdefghi", + "abcdefghijkl", + "abcdefghijklmnop", + ]; + + let validator_length = 6; + for case in &test_cases { + let mut variables = Variables::default(); + variables.insert("id".to_string(), Value::String(case.to_string())); + + let field_query = "query($id: String!) {fieldParameter(id: $id)}"; + let object_query = "query($id: String!) {inputObject(input: {id: $id})}"; + let case_length = case.len(); + + if case_length < validator_length { + let should_fail_msg = format!( + "StringMinValue case {} should have failed, but did not", + case + ); + + let field_error_msg = format!( + "Invalid value for argument \"id\", the value length is {}, must be greater than or equal to {}", + case_length, validator_length + ); + let object_error_msg = format!( + "Invalid value for argument \"input.id\", the value length is {}, must be greater than or equal to {}", + case_length, validator_length + ); + + assert_eq!( + QueryBuilder::new(field_query) + .variables(variables.clone()) + .execute(&schema) + .await + .expect_err(&should_fail_msg[..]), + Error::Rule { + errors: vec!(RuleError { + locations: vec!(Pos { + line: 1, + column: 37 + }), + message: field_error_msg + }) + } + ); + + assert_eq!( + QueryBuilder::new(object_query) + .variables(variables.clone()) + .execute(&schema) + .await + .expect_err(&should_fail_msg[..]), + Error::Rule { + errors: vec!(RuleError { + locations: vec!(Pos { + line: 1, + column: 34 + }), + message: object_error_msg + }) + } + ); + } else { + let error_msg = format!("Schema returned error with test_string = {}", case); + assert_eq!( + QueryBuilder::new(field_query) + .variables(variables.clone()) + .execute(&schema) + .await + .expect(&error_msg[..]) + .data, + serde_json::json!({"fieldParameter": true}), + "Failed to validate {} with StringMinLength", + case + ); + + assert_eq!( + QueryBuilder::new(object_query) + .variables(variables.clone()) + .execute(&schema) + .await + .expect(&error_msg[..]) + .data, + serde_json::json!({"inputObject": true}), + "Failed to validate {} with StringMinLength", + case + ); + } + } +}