From d932d30209c7be2f080c10ff262f8a9406dce90d Mon Sep 17 00:00:00 2001 From: Sunli Date: Sun, 31 May 2020 11:54:07 +0800 Subject: [PATCH] Fix incorrect variable substitution. #126 --- src/context.rs | 44 ++++++++++--------- src/lib.rs | 1 + tests/variables.rs | 104 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 129 insertions(+), 20 deletions(-) diff --git a/src/context.rs b/src/context.rs index c97aa8b7..216b5f42 100644 --- a/src/context.rs +++ b/src/context.rs @@ -431,26 +431,25 @@ impl<'a, T> ContextBase<'a, T> { .into_error(pos)) } - fn resolve_input_value(&self, mut value: Value, pos: Pos) -> Result { + fn resolve_input_value(&self, value: &mut Value, pos: Pos) -> Result<()> { match value { - Value::Variable(var_name) => self.var_value(&var_name, pos), + Value::Variable(var_name) => { + *value = self.var_value(&var_name, pos)?; + Ok(()) + } Value::List(ref mut ls) => { for value in ls { - if let Value::Variable(var_name) = value { - *value = self.var_value(&var_name, pos)?; - } + self.resolve_input_value(value, pos)?; } - Ok(value) + Ok(()) } Value::Object(ref mut obj) => { for value in obj.values_mut() { - if let Value::Variable(var_name) = value { - *value = self.var_value(&var_name, pos)?; - } + self.resolve_input_value(value, pos)?; } - Ok(value) + Ok(()) } - _ => Ok(value), + _ => Ok(()), } } @@ -459,9 +458,9 @@ impl<'a, T> ContextBase<'a, T> { for directive in directives { if directive.name.node == "skip" { if let Some(value) = directive.get_argument("if") { - match InputValueType::parse(Some( - self.resolve_input_value(value.clone_inner(), value.position())?, - )) { + let mut inner_value = value.clone_inner(); + self.resolve_input_value(&mut inner_value, value.pos)?; + match InputValueType::parse(Some(inner_value)) { Ok(true) => return Ok(true), Ok(false) => {} Err(err) => { @@ -478,9 +477,9 @@ impl<'a, T> ContextBase<'a, T> { } } else if directive.name.node == "include" { if let Some(value) = directive.get_argument("if") { - match InputValueType::parse(Some( - self.resolve_input_value(value.clone_inner(), value.position())?, - )) { + let mut inner_value = value.clone_inner(); + self.resolve_input_value(&mut inner_value, value.pos)?; + match InputValueType::parse(Some(inner_value)) { Ok(false) => return Ok(true), Ok(true) => {} Err(err) => { @@ -547,11 +546,16 @@ impl<'a> ContextBase<'a, &'a Positioned> { .as_ref() .map(|value| value.position()) .unwrap_or_default(); - let resolved_value = match value { - Some(value) => Some(self.resolve_input_value(value.into_inner(), pos)?), + let value = match value { + Some(value) => { + let mut new_value = value.into_inner(); + self.resolve_input_value(&mut new_value, pos)?; + Some(new_value) + } None => None, }; - match InputValueType::parse(resolved_value) { + + match InputValueType::parse(value) { Ok(res) => Ok(res), Err(err) => Err(err.into_error(pos, T::qualified_type_name())), } diff --git a/src/lib.rs b/src/lib.rs index 0b33f1cf..8c4492fe 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -91,6 +91,7 @@ #![warn(missing_docs)] #![allow(clippy::needless_doctest_main)] #![allow(clippy::needless_lifetimes)] +#![allow(clippy::trivially_copy_pass_by_ref)] #![recursion_limit = "256"] #![forbid(unsafe_code)] diff --git a/tests/variables.rs b/tests/variables.rs index 466599b9..f25e7a49 100644 --- a/tests/variables.rs +++ b/tests/variables.rs @@ -70,3 +70,107 @@ pub async fn test_variable_default_value() { }) ); } + +#[async_std::test] +pub async fn test_variable_in_input_object() { + #[InputObject] + struct MyInput { + value: i32, + } + + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn test(&self, input: MyInput) -> i32 { + input.value + } + + async fn test2(&self, input: Vec) -> i32 { + input.iter().map(|item| item.value).sum() + } + } + + struct MutationRoot; + + #[Object] + impl MutationRoot { + async fn test(&self, input: MyInput) -> i32 { + input.value + } + } + + let schema = Schema::new(QueryRoot, MutationRoot, EmptySubscription); + + // test query + { + let query = r#" + query TestQuery($value: Int!) { + test(input: {value: $value }) + }"#; + let resp = QueryBuilder::new(query) + .variables( + Variables::parse_from_json(serde_json::json!({ + "value": 10, + })) + .unwrap(), + ) + .execute(&schema) + .await + .unwrap(); + assert_eq!( + resp.data, + serde_json::json!({ + "test": 10, + }) + ); + } + + // test query2 + { + let query = r#" + query TestQuery($value: Int!) { + test2(input: [{value: $value }, {value: $value }]) + }"#; + let resp = QueryBuilder::new(query) + .variables( + Variables::parse_from_json(serde_json::json!({ + "value": 3, + })) + .unwrap(), + ) + .execute(&schema) + .await + .unwrap(); + assert_eq!( + resp.data, + serde_json::json!({ + "test2": 6, + }) + ); + } + + // test mutation + { + let query = r#" + mutation TestMutation($value: Int!) { + test(input: {value: $value }) + }"#; + let resp = QueryBuilder::new(query) + .variables( + Variables::parse_from_json(serde_json::json!({ + "value": 10, + })) + .unwrap(), + ) + .execute(&schema) + .await + .unwrap(); + assert_eq!( + resp.data, + serde_json::json!({ + "test": 10, + }) + ); + } +}