Fix incorrect variable substitution. #126

This commit is contained in:
Sunli 2020-05-31 11:54:07 +08:00
parent 7f62b7d92c
commit 267620d08e
3 changed files with 129 additions and 20 deletions

View File

@ -431,26 +431,25 @@ impl<'a, T> ContextBase<'a, T> {
.into_error(pos)) .into_error(pos))
} }
fn resolve_input_value(&self, mut value: Value, pos: Pos) -> Result<Value> { fn resolve_input_value(&self, value: &mut Value, pos: Pos) -> Result<()> {
match value { match value {
Value::Variable(var_name) => self.var_value(&var_name, pos), Value::Variable(var_name) => {
*value = self.var_value(&var_name, pos)?;
Ok(())
}
Value::List(ref mut ls) => { Value::List(ref mut ls) => {
for value in ls { for value in ls {
if let Value::Variable(var_name) = value { self.resolve_input_value(value, pos)?;
*value = self.var_value(&var_name, pos)?;
}
} }
Ok(value) Ok(())
} }
Value::Object(ref mut obj) => { Value::Object(ref mut obj) => {
for value in obj.values_mut() { for value in obj.values_mut() {
if let Value::Variable(var_name) = value { self.resolve_input_value(value, pos)?;
*value = self.var_value(&var_name, pos)?;
}
} }
Ok(value) Ok(())
} }
_ => Ok(value), _ => Ok(()),
} }
} }
@ -459,9 +458,9 @@ impl<'a, T> ContextBase<'a, T> {
for directive in directives { for directive in directives {
if directive.name.node == "skip" { if directive.name.node == "skip" {
if let Some(value) = directive.get_argument("if") { if let Some(value) = directive.get_argument("if") {
match InputValueType::parse(Some( let mut inner_value = value.clone_inner();
self.resolve_input_value(value.clone_inner(), value.position())?, self.resolve_input_value(&mut inner_value, value.pos)?;
)) { match InputValueType::parse(Some(inner_value)) {
Ok(true) => return Ok(true), Ok(true) => return Ok(true),
Ok(false) => {} Ok(false) => {}
Err(err) => { Err(err) => {
@ -478,9 +477,9 @@ impl<'a, T> ContextBase<'a, T> {
} }
} else if directive.name.node == "include" { } else if directive.name.node == "include" {
if let Some(value) = directive.get_argument("if") { if let Some(value) = directive.get_argument("if") {
match InputValueType::parse(Some( let mut inner_value = value.clone_inner();
self.resolve_input_value(value.clone_inner(), value.position())?, self.resolve_input_value(&mut inner_value, value.pos)?;
)) { match InputValueType::parse(Some(inner_value)) {
Ok(false) => return Ok(true), Ok(false) => return Ok(true),
Ok(true) => {} Ok(true) => {}
Err(err) => { Err(err) => {
@ -547,11 +546,16 @@ impl<'a> ContextBase<'a, &'a Positioned<Field>> {
.as_ref() .as_ref()
.map(|value| value.position()) .map(|value| value.position())
.unwrap_or_default(); .unwrap_or_default();
let resolved_value = match value { let value = match value {
Some(value) => Some(self.resolve_input_value(value.into_inner(), pos)?), Some(value) => {
let mut new_value = value.into_inner();
self.resolve_input_value(&mut new_value, pos)?;
Some(new_value)
}
None => None, None => None,
}; };
match InputValueType::parse(resolved_value) {
match InputValueType::parse(value) {
Ok(res) => Ok(res), Ok(res) => Ok(res),
Err(err) => Err(err.into_error(pos, T::qualified_type_name())), Err(err) => Err(err.into_error(pos, T::qualified_type_name())),
} }

View File

@ -91,6 +91,7 @@
#![warn(missing_docs)] #![warn(missing_docs)]
#![allow(clippy::needless_doctest_main)] #![allow(clippy::needless_doctest_main)]
#![allow(clippy::needless_lifetimes)] #![allow(clippy::needless_lifetimes)]
#![allow(clippy::trivially_copy_pass_by_ref)]
#![recursion_limit = "256"] #![recursion_limit = "256"]
#![forbid(unsafe_code)] #![forbid(unsafe_code)]

View File

@ -70,3 +70,107 @@ pub async fn test_variable_default_value() {
}) })
); );
} }
#[async_std::test]
pub async fn test_variable_in_input_object() {
#[InputObject]
struct MyInput {
value: i32,
}
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn test(&self, input: MyInput) -> i32 {
input.value
}
async fn test2(&self, input: Vec<MyInput>) -> i32 {
input.iter().map(|item| item.value).sum()
}
}
struct MutationRoot;
#[Object]
impl MutationRoot {
async fn test(&self, input: MyInput) -> i32 {
input.value
}
}
let schema = Schema::new(QueryRoot, MutationRoot, EmptySubscription);
// test query
{
let query = r#"
query TestQuery($value: Int!) {
test(input: {value: $value })
}"#;
let resp = QueryBuilder::new(query)
.variables(
Variables::parse_from_json(serde_json::json!({
"value": 10,
}))
.unwrap(),
)
.execute(&schema)
.await
.unwrap();
assert_eq!(
resp.data,
serde_json::json!({
"test": 10,
})
);
}
// test query2
{
let query = r#"
query TestQuery($value: Int!) {
test2(input: [{value: $value }, {value: $value }])
}"#;
let resp = QueryBuilder::new(query)
.variables(
Variables::parse_from_json(serde_json::json!({
"value": 3,
}))
.unwrap(),
)
.execute(&schema)
.await
.unwrap();
assert_eq!(
resp.data,
serde_json::json!({
"test2": 6,
})
);
}
// test mutation
{
let query = r#"
mutation TestMutation($value: Int!) {
test(input: {value: $value })
}"#;
let resp = QueryBuilder::new(query)
.variables(
Variables::parse_from_json(serde_json::json!({
"value": 10,
}))
.unwrap(),
)
.execute(&schema)
.await
.unwrap();
assert_eq!(
resp.data,
serde_json::json!({
"test": 10,
})
);
}
}