diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 95eb9993..d7f958a2 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -17,7 +17,7 @@ proc-macro = true [dependencies] async-graphql-parser = { path = "../parser", version = "=2.1.2" } proc-macro2 = "1.0.6" -syn = { version = "1.0.44", features = ["full", "extra-traits", "visit-mut"] } +syn = { version = "1.0.44", features = ["full", "extra-traits", "visit-mut", "visit"] } quote = "1.0.3" Inflector = "0.11.4" proc-macro-crate = "0.1.4" diff --git a/derive/src/args.rs b/derive/src/args.rs index 27b8eca8..b2ec45e7 100644 --- a/derive/src/args.rs +++ b/derive/src/args.rs @@ -143,6 +143,29 @@ pub struct Object { pub visible: Option, } +pub enum ComplexityType { + Const(usize), + Fn(String), +} + +impl FromMeta for ComplexityType { + fn from_value(value: &Lit) -> darling::Result { + match value { + Lit::Int(n) => { + let n = n.base10_parse::().unwrap(); + if n < 0 { + return Err(darling::Error::custom( + "The complexity must be greater than or equal to 0.", + )); + } + Ok(ComplexityType::Const(n as usize)) + } + Lit::Str(s) => Ok(ComplexityType::Fn(s.value())), + _ => Err(darling::Error::unexpected_lit_type(value)), + } + } +} + #[derive(FromMeta, Default)] #[darling(default)] pub struct ObjectField { @@ -156,6 +179,7 @@ pub struct ObjectField { pub requires: Option, pub guard: Option, pub visible: Option, + pub complexity: Option, } #[derive(FromDeriveInput)] diff --git a/derive/src/interface.rs b/derive/src/interface.rs index 7b8cf8e3..d3d35155 100644 --- a/derive/src/interface.rs +++ b/derive/src/interface.rs @@ -278,6 +278,7 @@ pub fn generate(interface_args: &args::Interface) -> GeneratorResult { + quote! { ::std::option::Option::Some(#crate_name::registry::ComplexityType::Const(#n)) } + } + ComplexityType::Fn(s) => { + let (variables, expr) = parse_complexity_expr(s)?; + let mut parse_args = Vec::new(); + for variable in variables { + if let Some(( + ident, + ty, + args::Argument { + name, + default, + default_with, + .. + }, + )) = args + .iter() + .find(|(pat_ident, _, _)| pat_ident.ident == variable) + { + let default = generate_default(&default, &default_with)?; + let schema_default = default + .as_ref() + .map(|value| { + quote! { + ::std::option::Option::Some(::std::string::ToString::to_string( + &<#ty as #crate_name::InputType>::to_value(&#value) + )) + } + }) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let name = name.clone().unwrap_or_else(|| { + object_args.rename_args.rename( + ident.ident.unraw().to_string(), + RenameTarget::Argument, + ) + }); + parse_args.push(quote! { + let #ident: #ty = __ctx.param_value(__variables_definition, __field, #name, #schema_default)?; + }); + } + } + quote! { + Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| { + #(#parse_args)* + Ok(#expr) + })) + } + } + } + } else { + quote! { ::std::option::Option::None } + }; + schema_fields.push(quote! { #(#cfg_attrs)* fields.insert(::std::borrow::ToOwned::to_owned(#field_name), #crate_name::registry::MetaField { @@ -403,6 +463,7 @@ pub fn generate( provides: #provides, requires: #requires, visible: #visible, + compute_complexity: #complexity, }); }); @@ -547,3 +608,25 @@ pub fn generate( }; Ok(expanded.into()) } + +#[derive(Default)] +struct VisitComplexityExpr { + variables: HashSet, +} + +impl<'a> Visit<'a> for VisitComplexityExpr { + fn visit_expr_path(&mut self, i: &'a ExprPath) { + if let Some(ident) = i.path.get_ident() { + if ident != "child_complexity" { + self.variables.insert(ident.to_string()); + } + } + } +} + +fn parse_complexity_expr(s: &str) -> GeneratorResult<(HashSet, Expr)> { + let expr: Expr = syn::parse_str(s)?; + let mut visit = VisitComplexityExpr::default(); + visit.visit_expr(&expr); + Ok((visit.variables, expr)) +} diff --git a/derive/src/simple_object.rs b/derive/src/simple_object.rs index 47f5447d..71893e12 100644 --- a/derive/src/simple_object.rs +++ b/derive/src/simple_object.rs @@ -95,6 +95,7 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult i32 { + todo!() + } + + #[graphql(complexity = "count * child_complexity")] + async fn values(&self, count: u32) -> i32 { + todo!() + } +} +``` + +**Note: The complexity calculation is done in the validation phase and not in the execution phase, so you don't have to worry about the query + over-limit causing the execute only part of the query.** + diff --git a/docs/zh-CN/src/SUMMARY.md b/docs/zh-CN/src/SUMMARY.md index c7fd1cc6..448512e6 100644 --- a/docs/zh-CN/src/SUMMARY.md +++ b/docs/zh-CN/src/SUMMARY.md @@ -23,6 +23,7 @@ - [游标连接](cursor_connections.md) - [错误扩展](error_extensions.md) - [Apollo Tracing支持](apollo_tracing.md) + - [查询的深度和复杂度](depth_and_complexity.md) - [在内省中隐藏内容](visibility.md) - [集成到WebServer](integrations.md) - [Warp](integrations_to_warp.md) diff --git a/docs/zh-CN/src/depth_and_complexity.md b/docs/zh-CN/src/depth_and_complexity.md new file mode 100644 index 00000000..e895e991 --- /dev/null +++ b/docs/zh-CN/src/depth_and_complexity.md @@ -0,0 +1,106 @@ +# 查询的深度和复杂度 + +⚠️GraphQL提供了非常灵活的查询方法,但在客户端上滥用复杂的查询可能造成风险,限制查询语句的深度和复杂度可以减轻这种风险。 + +## 昂贵的查询 + +考虑一种允许列出博客文章的架构。每个博客帖子也与其他帖子相关。 + +```graphql +type Query { + posts(count: Int = 10): [Post!]! +} + +type Post { + title: String! + text: String! + related(count: Int = 10): [Post!]! +} +``` + +创建一个会引起很大响应的查询不是很困难: + +```graphql +{ + posts(count: 100) { + related(count: 100) { + related(count: 100) { + related(count: 100) { + title + } + } + } + } +} +``` + +响应的大小随`related`字段的每个其他级别呈指数增长。幸运的是,`Async-graphql`提供了一种防止此类查询的方法。 + +## 限制查询的深度 + +查询的深度是字段嵌套的层数,下面是一个深度为`3`的查询。 + +```graphql +{ + a { + b { + c + } + } +} +``` + +在创建`Schema`的时候可以限制深度,如果查询语句超过这个限制,则会出错并且返回`Query is nested too deep.`消息。 + +```rust +let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .limit_depth(5) // 限制最大深度为5 + .finish(); +``` + +## 限制查询的复杂度 + +复杂度是查询语句中字段的数量,每个字段的复杂度默认为`1`,下面是一个复杂度为`6`的查询。 + +```graphql +{ + a b c { + d { + e f + } + } +} +``` + +在创建`Schema`的时候可以限制复杂度,如果查询语句超过这个限制,则会出错并且返回`Query is too complex.`。 + +```rust +let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .limit_complexity(5) // 限制最大深度为5 + .finish(); +``` + +## 自定义字段的复杂度 + +针对非列表类型和列表类型的字段,有两种自定义复杂度的方法。 +下面的代码中,`value`字段的复杂度为`5`。而`values`字段的复杂度为`count * child_complexity`,`child_complexity`是一个特殊的变量,表示子 +查询的复杂度,`count`是字段的参数,这个公式用于计算`values`字段的复杂度。 + +```rust +struct Query; + +#[Object] +impl Query { + #[graphql(complexity = 5)] + async fn value(&self) -> i32 { + todo!() + } + + #[graphql(complexity = "count * child_complexity")] + async fn values(&self, count: u32) -> i32 { + todo!() + } +} +``` + +**注意: 计算复杂度是在验证阶段完成而不是在执行阶段,所以你不用担心超限的查询语句会导致查询只执行一部分。** diff --git a/integrations/tide/tests/graphql.rs b/integrations/tide/tests/graphql.rs index 62a3cb83..f685d173 100644 --- a/integrations/tide/tests/graphql.rs +++ b/integrations/tide/tests/graphql.rs @@ -47,19 +47,17 @@ async fn quickstart() -> Result<()> { assert_eq!(string, json!({"data": {"add": 30}}).to_string()); - // Note: This test fails due to a bug in tide or reqwest. I will open it again when the bug is fixed. - // - // let resp = client - // .get(listen_addr) - // .query(&[("query", "{ add(a: 10, b: 20) }")]) - // .send() - // .await?; - // - // assert_eq!(resp.status(), StatusCode::OK); - // let string = resp.text().await?; - // println!("via get {}", string); - // - // assert_eq!(string, json!({"data": {"add": 30}}).to_string()); + let resp = client + .get(listen_addr) + .query(&[("query", "{ add(a: 10, b: 20) }")]) + .send() + .await?; + + assert_eq!(resp.status(), StatusCode::OK); + let string = resp.text().await?; + println!("via get {}", string); + + assert_eq!(string, json!({"data": {"add": 30}}).to_string()); Ok(()) } diff --git a/integrations/tide/tests/test_utils.rs b/integrations/tide/tests/test_utils.rs index 79666a8b..d482899d 100644 --- a/integrations/tide/tests/test_utils.rs +++ b/integrations/tide/tests/test_utils.rs @@ -19,5 +19,5 @@ pub fn client() -> Client { } pub async fn wait_server_ready() { - async_std::task::sleep(Duration::from_millis(300)).await; + async_std::task::sleep(Duration::from_millis(500)).await; } diff --git a/src/lib.rs b/src/lib.rs index e5dc79b4..25baf587 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -220,7 +220,7 @@ pub use request::{BatchRequest, Request}; pub use resolver_utils::{ContainerType, EnumType, ScalarType}; pub use response::{BatchResponse, Response}; pub use schema::{Schema, SchemaBuilder, SchemaEnv}; -pub use validation::ValidationMode; +pub use validation::{ValidationMode, VisitorContext}; pub use context::*; #[doc(no_inline)] diff --git a/src/registry/mod.rs b/src/registry/mod.rs index dfe6010f..512d89d4 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -7,9 +7,11 @@ use std::sync::Arc; use indexmap::map::IndexMap; use indexmap::set::IndexSet; -use crate::parser::types::{BaseType as ParsedBaseType, Type as ParsedType}; +use crate::parser::types::{ + BaseType as ParsedBaseType, Field, Type as ParsedType, VariableDefinition, +}; use crate::validators::InputValueValidator; -use crate::{model, Any, Context, Type, Value}; +use crate::{model, Any, Context, Positioned, ServerResult, Type, Value, VisitorContext}; pub use cache_control::CacheControl; @@ -83,6 +85,14 @@ impl<'a> MetaTypeName<'a> { _ => false, } } + + pub fn is_list(&self) -> bool { + match self { + MetaTypeName::List(_) => true, + MetaTypeName::NonNull(ty) => MetaTypeName::create(ty).is_list(), + MetaTypeName::Named(name) => name.ends_with(']'), + } + } } #[derive(Clone)] @@ -95,6 +105,19 @@ pub struct MetaInputValue { pub visible: Option, } +type ComputeComplexityFn = fn( + &VisitorContext<'_>, + &[Positioned], + &Field, + usize, +) -> ServerResult; + +#[derive(Clone)] +pub enum ComplexityType { + Const(usize), + Fn(ComputeComplexityFn), +} + #[derive(Clone)] pub struct MetaField { pub name: String, @@ -107,6 +130,7 @@ pub struct MetaField { pub requires: Option<&'static str>, pub provides: Option<&'static str>, pub visible: Option, + pub compute_complexity: Option, } #[derive(Clone)] @@ -449,6 +473,7 @@ impl Registry { requires: None, provides: None, visible: None, + compute_complexity: None, }, ); fields @@ -477,6 +502,7 @@ impl Registry { requires: None, provides: None, visible: None, + compute_complexity: None, }, ); @@ -507,6 +533,7 @@ impl Registry { requires: None, provides: None, visible: None, + compute_complexity: None, }, ); } diff --git a/src/schema.rs b/src/schema.rs index 43882932..c82cb5d4 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -232,7 +232,7 @@ where } } - fn create_registry() -> Registry { + pub(crate) fn create_registry() -> Registry { let mut registry = Registry { types: Default::default(), directives: Default::default(), diff --git a/src/types/connection/connection_type.rs b/src/types/connection/connection_type.rs index 6ac20cd6..c126352f 100644 --- a/src/types/connection/connection_type.rs +++ b/src/types/connection/connection_type.rs @@ -160,6 +160,7 @@ where requires: None, provides: None, visible: None, + compute_complexity: None, }, ); @@ -178,6 +179,7 @@ where requires: None, provides: None, visible: None, + compute_complexity: None, }, ); diff --git a/src/types/connection/edge.rs b/src/types/connection/edge.rs index 263b756a..1e0bc325 100644 --- a/src/types/connection/edge.rs +++ b/src/types/connection/edge.rs @@ -79,6 +79,7 @@ where requires: None, provides: None, visible: None, + compute_complexity: None, }, ); @@ -95,6 +96,7 @@ where requires: None, provides: None, visible: None, + compute_complexity: None, }, ); diff --git a/src/types/external/integers.rs b/src/types/external/integers.rs index 995f481c..6341c0a3 100644 --- a/src/types/external/integers.rs +++ b/src/types/external/integers.rs @@ -147,7 +147,7 @@ impl ScalarType for u8 { } fn is_valid(value: &Value) -> bool { - matches!(value, Value::Number(n) if n.is_i64()) + matches!(value, Value::Number(n) if n.is_u64()) } fn to_value(&self) -> Value { @@ -178,7 +178,7 @@ impl ScalarType for u16 { } fn is_valid(value: &Value) -> bool { - matches!(value, Value::Number(n) if n.is_i64()) + matches!(value, Value::Number(n) if n.is_u64()) } fn to_value(&self) -> Value { @@ -209,7 +209,7 @@ impl ScalarType for u32 { } fn is_valid(value: &Value) -> bool { - matches!(value, Value::Number(n) if n.is_i64()) + matches!(value, Value::Number(n) if n.is_u64()) } fn to_value(&self) -> Value { @@ -240,10 +240,72 @@ impl ScalarType for u64 { } fn is_valid(value: &Value) -> bool { - matches!(value, Value::Number(n) if n.is_i64()) + matches!(value, Value::Number(n) if n.is_u64()) } fn to_value(&self) -> Value { Value::Number(Number::from(*self as u64)) } } + +/// The `Int` scalar type represents non-fractional whole numeric values. +#[Scalar(internal, name = "Int")] +impl ScalarType for usize { + fn parse(value: Value) -> InputValueResult { + match value { + Value::Number(n) => { + let n = n + .as_u64() + .ok_or_else(|| InputValueError::from("Invalid number"))?; + if n > Self::MAX as u64 { + return Err(InputValueError::from(format!( + "Only integers from {} to {} are accepted.", + 0, + Self::MAX + ))); + } + Ok(n as Self) + } + _ => Err(InputValueError::expected_type(value)), + } + } + + fn is_valid(value: &Value) -> bool { + matches!(value, Value::Number(n) if n.is_u64()) + } + + fn to_value(&self) -> Value { + Value::Number(Number::from(*self as u64)) + } +} + +/// The `Int` scalar type represents non-fractional whole numeric values. +#[Scalar(internal, name = "Int")] +impl ScalarType for isize { + fn parse(value: Value) -> InputValueResult { + match value { + Value::Number(n) => { + let n = n + .as_i64() + .ok_or_else(|| InputValueError::from("Invalid number"))?; + if n < Self::MIN as i64 || n > Self::MAX as i64 { + return Err(InputValueError::from(format!( + "Only integers from {} to {} are accepted.", + Self::MIN, + Self::MAX + ))); + } + Ok(n as Self) + } + _ => Err(InputValueError::expected_type(value)), + } + } + + fn is_valid(value: &Value) -> bool { + matches!(value, Value::Number(n) if n.is_i64()) + } + + fn to_value(&self) -> Value { + Value::Number(Number::from(*self as i64)) + } +} diff --git a/src/types/query_root.rs b/src/types/query_root.rs index 975f7dcd..35f2afbf 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -46,6 +46,7 @@ impl Type for QueryRoot { requires: None, provides: None, visible: None, + compute_complexity: None, }, ); @@ -76,6 +77,7 @@ impl Type for QueryRoot { requires: None, provides: None, visible: None, + compute_complexity: None, }, ); } diff --git a/src/validation/mod.rs b/src/validation/mod.rs index c0ac171e..eb2eb894 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -11,7 +11,9 @@ mod visitors; use crate::parser::types::ExecutableDocument; use crate::registry::Registry; use crate::{CacheControl, ServerError, Variables}; -use visitor::{visit, VisitorContext, VisitorNil}; +use visitor::{visit, VisitorNil}; + +pub use visitor::VisitorContext; pub struct CheckResult { pub cache_control: CacheControl, @@ -67,9 +69,7 @@ pub fn check_rules( .with(visitors::CacheControlCalculate { cache_control: &mut cache_control, }) - .with(visitors::ComplexityCalculate { - complexity: &mut complexity, - }) + .with(visitors::ComplexityCalculate::new(&mut complexity)) .with(visitors::DepthCalculate::new(&mut depth)); visit(&mut visitor, &mut ctx, doc); } @@ -80,9 +80,7 @@ pub fn check_rules( .with(visitors::CacheControlCalculate { cache_control: &mut cache_control, }) - .with(visitors::ComplexityCalculate { - complexity: &mut complexity, - }) + .with(visitors::ComplexityCalculate::new(&mut complexity)) .with(visitors::DepthCalculate::new(&mut depth)); visit(&mut visitor, &mut ctx, doc); } @@ -95,6 +93,6 @@ pub fn check_rules( Ok(CheckResult { cache_control, complexity, - depth: depth as usize, + depth, }) } diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs index f061454a..2c31b0e7 100644 --- a/src/validation/visitor.rs +++ b/src/validation/visitor.rs @@ -1,15 +1,17 @@ use std::collections::HashMap; use std::fmt::{self, Display, Formatter}; +use async_graphql_value::Value; + use crate::parser::types::{ Directive, ExecutableDocument, Field, FragmentDefinition, FragmentSpread, InlineFragment, OperationDefinition, OperationType, Selection, SelectionSet, TypeCondition, VariableDefinition, }; use crate::registry::{self, MetaType, MetaTypeName}; -use crate::{Name, Pos, Positioned, ServerError, Variables}; -use async_graphql_value::Value; +use crate::{InputType, Name, Pos, Positioned, ServerError, ServerResult, Variables}; -pub(crate) struct VisitorContext<'a> { +#[doc(hidden)] +pub struct VisitorContext<'a> { pub(crate) registry: &'a registry::Registry, pub(crate) variables: Option<&'a Variables>, pub(crate) errors: Vec, @@ -87,9 +89,68 @@ impl<'a> VisitorContext<'a> { pub(crate) fn fragment(&self, name: &str) -> Option<&'a Positioned> { self.fragments.get(name) } + + #[doc(hidden)] + pub fn param_value( + &self, + variable_definitions: &[Positioned], + field: &Field, + name: &str, + default: Option T>, + ) -> ServerResult { + let value = field.get_argument(name).cloned(); + + if value.is_none() { + if let Some(default) = default { + return Ok(default()); + } + } + + let (pos, value) = match value { + Some(value) => { + let pos = value.pos; + ( + pos, + Some(value.node.into_const_with(|name| { + variable_definitions + .iter() + .find(|def| def.node.name.node == name) + .and_then(|def| { + if let Some(variables) = self.variables { + variables + .0 + .get(&def.node.name.node) + .or_else(|| def.node.default_value()) + } else { + None + } + }) + .cloned() + .ok_or_else(|| { + ServerError::new(format!("Variable {} is not defined.", name)) + .at(pos) + }) + })?), + ) + } + None => (Pos::default(), None), + }; + + T::parse(value).map_err(|e| e.into_server_error().at(pos)) + } +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub(crate) enum VisitMode { + Normal, + Inline, } pub(crate) trait Visitor<'a> { + fn mode(&self) -> VisitMode { + VisitMode::Normal + } + fn enter_document(&mut self, _ctx: &mut VisitorContext<'a>, _doc: &'a ExecutableDocument) {} fn exit_document(&mut self, _ctx: &mut VisitorContext<'a>, _doc: &'a ExecutableDocument) {} @@ -687,10 +748,12 @@ fn visit_fragment_definition<'a, V: Visitor<'a>>( name: &'a Name, fragment: &'a Positioned, ) { - v.enter_fragment_definition(ctx, name, fragment); - visit_directives(v, ctx, &fragment.node.directives); - visit_selection_set(v, ctx, &fragment.node.selection_set); - v.exit_fragment_definition(ctx, name, fragment); + if v.mode() == VisitMode::Normal { + v.enter_fragment_definition(ctx, name, fragment); + visit_directives(v, ctx, &fragment.node.directives); + visit_selection_set(v, ctx, &fragment.node.selection_set); + v.exit_fragment_definition(ctx, name, fragment); + } } fn visit_fragment_spread<'a, V: Visitor<'a>>( @@ -700,6 +763,14 @@ fn visit_fragment_spread<'a, V: Visitor<'a>>( ) { v.enter_fragment_spread(ctx, fragment_spread); visit_directives(v, ctx, &fragment_spread.node.directives); + if v.mode() == VisitMode::Inline { + if let Some(fragment) = ctx + .fragments + .get(fragment_spread.node.fragment_name.node.as_str()) + { + visit_selection_set(v, ctx, &fragment.node.selection_set); + } + } v.exit_fragment_spread(ctx, fragment_spread); } diff --git a/src/validation/visitors/cache_control.rs b/src/validation/visitors/cache_control.rs index edd6d086..255c939f 100644 --- a/src/validation/visitors/cache_control.rs +++ b/src/validation/visitors/cache_control.rs @@ -1,6 +1,6 @@ use crate::parser::types::{Field, SelectionSet}; use crate::registry::MetaType; -use crate::validation::visitor::{Visitor, VisitorContext}; +use crate::validation::visitor::{VisitMode, Visitor, VisitorContext}; use crate::{CacheControl, Positioned}; pub struct CacheControlCalculate<'a> { @@ -8,6 +8,10 @@ pub struct CacheControlCalculate<'a> { } impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> { + fn mode(&self) -> VisitMode { + VisitMode::Inline + } + fn enter_selection_set( &mut self, ctx: &mut VisitorContext<'_>, diff --git a/src/validation/visitors/complexity.rs b/src/validation/visitors/complexity.rs index e86934de..5c1b3cd2 100644 --- a/src/validation/visitors/complexity.rs +++ b/src/validation/visitors/complexity.rs @@ -1,13 +1,254 @@ use crate::parser::types::Field; -use crate::validation::visitor::{Visitor, VisitorContext}; +use crate::registry::{ComplexityType, MetaType, MetaTypeName}; +use crate::validation::visitor::{VisitMode, Visitor, VisitorContext}; use crate::Positioned; +use async_graphql_parser::types::{ExecutableDocument, OperationDefinition, VariableDefinition}; +use async_graphql_value::Name; -pub struct ComplexityCalculate<'a> { +pub struct ComplexityCalculate<'ctx, 'a> { pub complexity: &'a mut usize, + pub complexity_stack: Vec, + pub variable_definition: Option<&'ctx [Positioned]>, } -impl<'ctx, 'a> Visitor<'ctx> for ComplexityCalculate<'a> { - fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned) { - *self.complexity += 1; +impl<'ctx, 'a> ComplexityCalculate<'ctx, 'a> { + pub fn new(complexity: &'a mut usize) -> Self { + Self { + complexity, + complexity_stack: Default::default(), + variable_definition: None, + } + } +} + +impl<'ctx, 'a> Visitor<'ctx> for ComplexityCalculate<'ctx, 'a> { + fn mode(&self) -> VisitMode { + VisitMode::Inline + } + + fn enter_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) { + self.complexity_stack.push(0); + } + + fn exit_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) { + *self.complexity = self.complexity_stack.pop().unwrap(); + } + + fn enter_operation_definition( + &mut self, + _ctx: &mut VisitorContext<'ctx>, + _name: Option<&'ctx Name>, + operation_definition: &'ctx Positioned, + ) { + self.variable_definition = Some(&operation_definition.node.variable_definitions); + } + + fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned) { + self.complexity_stack.push(0); + } + + fn exit_field(&mut self, ctx: &mut VisitorContext<'ctx>, field: &'ctx Positioned) { + let children_complex = self.complexity_stack.pop().unwrap(); + + if let Some(ty) = ctx.parent_type() { + if let MetaType::Object { fields, .. } = ty { + if let Some(meta_field) = fields.get(MetaTypeName::concrete_typename( + field.node.name.node.as_str(), + )) { + if let Some(compute_complexity) = &meta_field.compute_complexity { + match compute_complexity { + ComplexityType::Const(n) => { + *self.complexity_stack.last_mut().unwrap() += n; + } + ComplexityType::Fn(f) => { + if MetaTypeName::create(&meta_field.ty).is_list() { + match f( + ctx, + self.variable_definition.unwrap(), + &field.node, + children_complex, + ) { + Ok(n) => { + *self.complexity_stack.last_mut().unwrap() += n; + } + Err(err) => { + ctx.report_error(vec![field.pos], err.to_string()) + } + } + } + } + } + + return; + } + } + } + } + + *self.complexity_stack.last_mut().unwrap() += 1 + children_complex; + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::parse_query; + use crate::validation::{visit, VisitorContext}; + use crate::{EmptyMutation, EmptySubscription, Object, Schema}; + + struct Query; + + #[derive(Copy, Clone)] + struct MyObj; + + #[Object(internal)] + impl MyObj { + async fn a(&self) -> i32 { + 1 + } + + async fn b(&self) -> i32 { + 2 + } + + async fn c(&self) -> MyObj { + MyObj + } + } + + #[Object(internal)] + impl Query { + async fn value(&self) -> i32 { + 1 + } + + async fn obj(&self) -> MyObj { + MyObj + } + + #[graphql(complexity = "count * child_complexity")] + async fn objs(&self, count: usize) -> Vec { + vec![MyObj; count as usize] + } + + #[graphql(complexity = 3)] + async fn d(&self) -> MyObj { + MyObj + } + } + + fn check_complex(query: &str, expect_complex: usize) { + let registry = Schema::::create_registry(); + let doc = parse_query(query).unwrap(); + let mut ctx = VisitorContext::new(®istry, &doc, None); + let mut complex = 0; + let mut complex_calculate = ComplexityCalculate::new(&mut complex); + visit(&mut complex_calculate, &mut ctx, &doc); + assert_eq!(complex, expect_complex); + } + + #[test] + fn complex() { + check_complex( + r#" + { + value #1 + }"#, + 1, + ); + + check_complex( + r#" + { + value #1 + d #3 + }"#, + 4, + ); + + check_complex( + r#" + { + value obj { #2 + a b #2 + } + }"#, + 4, + ); + + check_complex( + r#" + { + value obj { #2 + a b obj { #3 + a b obj { #3 + a #1 + } + } + } + }"#, + 9, + ); + + check_complex( + r#" + fragment A on MyObj { + a b ... A2 #2 + } + + fragment A2 on MyObj { + obj { # 1 + a # 1 + } + } + + query { + obj { # 1 + ... A + } + }"#, + 5, + ); + + check_complex( + r#" + { + obj { # 1 + ... on MyObj { + a b #2 + ... on MyObj { + obj { #1 + a #1 + } + } + } + } + }"#, + 5, + ); + + check_complex( + r#" + { + objs(count: 10) { + a b + } + }"#, + 20, + ); + + check_complex( + r#" + fragment A on MyObj { + a b + } + + query { + objs(count: 10) { + ... A + } + }"#, + 20, + ); } } diff --git a/src/validation/visitors/depth.rs b/src/validation/visitors/depth.rs index b75a901d..fdb886aa 100644 --- a/src/validation/visitors/depth.rs +++ b/src/validation/visitors/depth.rs @@ -1,69 +1,151 @@ -use crate::parser::types::{FragmentSpread, InlineFragment, SelectionSet}; -use crate::validation::visitor::{Visitor, VisitorContext}; +use crate::validation::visitor::{VisitMode, Visitor, VisitorContext}; use crate::Positioned; +use async_graphql_parser::types::Field; pub struct DepthCalculate<'a> { - max_depth: &'a mut i32, - current_depth: i32, + max_depth: &'a mut usize, + current_depth: usize, } impl<'a> DepthCalculate<'a> { - pub fn new(max_depth: &'a mut i32) -> Self { - *max_depth = -1; + pub fn new(max_depth: &'a mut usize) -> Self { Self { max_depth, - current_depth: -1, + current_depth: 0, } } } impl<'ctx, 'a> Visitor<'ctx> for DepthCalculate<'a> { - fn enter_selection_set( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _selection_set: &'ctx Positioned, - ) { + fn mode(&self) -> VisitMode { + VisitMode::Inline + } + + fn enter_field(&mut self, _ctx: &mut VisitorContext<'ctx>, _field: &'ctx Positioned) { self.current_depth += 1; *self.max_depth = (*self.max_depth).max(self.current_depth); } - fn exit_selection_set( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _selection_set: &'ctx Positioned, - ) { + fn exit_field(&mut self, _ctx: &mut VisitorContext<'ctx>, _field: &'ctx Positioned) { self.current_depth -= 1; } - - fn enter_fragment_spread( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _fragment_spread: &'ctx Positioned, - ) { - self.current_depth -= 1; - } - - fn exit_fragment_spread( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _fragment_spread: &'ctx Positioned, - ) { - self.current_depth += 1; - } - - fn enter_inline_fragment( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _inline_fragment: &'ctx Positioned, - ) { - self.current_depth -= 1; - } - - fn exit_inline_fragment( - &mut self, - _ctx: &mut VisitorContext<'ctx>, - _inline_fragment: &'ctx Positioned, - ) { - self.current_depth += 1; - } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::parse_query; + use crate::validation::{visit, VisitorContext}; + use crate::{EmptyMutation, EmptySubscription, Object, Schema}; + + struct Query; + + struct MyObj; + + #[Object(internal)] + impl MyObj { + async fn a(&self) -> i32 { + 1 + } + + async fn b(&self) -> i32 { + 2 + } + + async fn c(&self) -> MyObj { + MyObj + } + } + + #[Object(internal)] + impl Query { + async fn value(&self) -> i32 { + 1 + } + + async fn obj(&self) -> MyObj { + MyObj + } + } + + fn check_depth(query: &str, expect_depth: usize) { + let registry = Schema::::create_registry(); + let doc = parse_query(query).unwrap(); + let mut ctx = VisitorContext::new(®istry, &doc, None); + let mut depth = 0; + let mut depth_calculate = DepthCalculate::new(&mut depth); + visit(&mut depth_calculate, &mut ctx, &doc); + assert_eq!(depth, expect_depth); + } + + #[test] + fn depth() { + check_depth( + r#"{ + value #1 + }"#, + 1, + ); + + check_depth( + r#" + { + obj { #1 + a b #2 + } + }"#, + 2, + ); + + check_depth( + r#" + { + obj { # 1 + a b c { # 2 + a b c { # 3 + a b # 4 + } + } + } + }"#, + 4, + ); + + check_depth( + r#" + fragment A on MyObj { + a b ... A2 #2 + } + + fragment A2 on MyObj { + obj { + a #3 + } + } + + query { + obj { # 1 + ... A + } + }"#, + 3, + ); + + check_depth( + r#" + { + obj { # 1 + ... on MyObj { + a b #2 + ... on MyObj { + obj { + a #3 + } + } + } + } + }"#, + 3, + ); + } } diff --git a/tests/complexity_and_depth.rs b/tests/complexity_and_depth.rs deleted file mode 100644 index 7db786a1..00000000 --- a/tests/complexity_and_depth.rs +++ /dev/null @@ -1,139 +0,0 @@ -use async_graphql::*; - -#[async_std::test] -pub async fn test_complexity_and_depth() { - struct Query; - - struct MyObj; - - #[Object] - impl MyObj { - async fn a(&self) -> i32 { - 1 - } - - async fn b(&self) -> i32 { - 2 - } - - async fn c(&self) -> MyObj { - MyObj - } - } - - #[Object] - impl Query { - async fn value(&self) -> i32 { - 1 - } - - async fn obj(&self) -> MyObj { - MyObj - } - } - - let query = "{ a:value b:value c:value }"; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_complexity(2) - .finish(); - assert_eq!( - schema.execute(query).await.into_result().unwrap_err(), - vec![ServerError { - message: "Query is too complex.".to_owned(), - locations: Vec::new(), - path: Vec::new(), - extensions: None, - }] - ); - - let query = "{ a:value b:value }"; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_complexity(2) - .finish(); - assert_eq!( - schema.execute(query).await.data, - value!({ - "a": 1, - "b": 1, - }) - ); - - let query = "{ obj { a b } }"; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_complexity(2) - .finish(); - assert_eq!( - schema.execute(query).await.into_result().unwrap_err(), - vec![ServerError { - message: "Query is too complex.".to_owned(), - locations: Vec::new(), - path: Vec::new(), - extensions: None, - }] - ); - - let query = "{ obj { a } }"; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_complexity(2) - .finish(); - assert_eq!( - schema.execute(query).await.data, - value!({ - "obj": { "a": 1 } - }) - ); - - let query = r#" - { - obj { # 0 - a b c { # 1 - a b c { # 2 - a b # 3 - } - } - } - }"#; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_depth(2) - .finish(); - assert_eq!( - schema.execute(query).await.into_result().unwrap_err(), - vec![ServerError { - message: "Query is nested too deep.".to_owned(), - locations: Vec::new(), - path: Vec::new(), - extensions: None, - }] - ); - - let query = r#" - { - obj { # 0 - a b c { # 1 - a b c { # 2 - a b # 3 - } - } - } - }"#; - let schema = Schema::build(Query, EmptyMutation, EmptySubscription) - .limit_depth(3) - .finish(); - assert_eq!( - schema.execute(query).await.data, - value!({ - "obj": { - "a": 1, - "b": 2, - "c": { - "a": 1, - "b": 2, - "c": { - "a": 1, - "b": 2, - } - } - } - }) - ); -}