Improve depth and complex.

This commit is contained in:
Sunli 2020-12-18 14:59:37 +08:00
parent c6d8df0bdb
commit bf9e3ec578
25 changed files with 920 additions and 235 deletions

View File

@ -17,7 +17,7 @@ proc-macro = true
[dependencies]
async-graphql-parser = { path = "../parser", version = "=2.1.2" }
proc-macro2 = "1.0.6"
syn = { version = "1.0.44", features = ["full", "extra-traits", "visit-mut"] }
syn = { version = "1.0.44", features = ["full", "extra-traits", "visit-mut", "visit"] }
quote = "1.0.3"
Inflector = "0.11.4"
proc-macro-crate = "0.1.4"

View File

@ -143,6 +143,29 @@ pub struct Object {
pub visible: Option<Visible>,
}
pub enum ComplexityType {
Const(usize),
Fn(String),
}
impl FromMeta for ComplexityType {
fn from_value(value: &Lit) -> darling::Result<Self> {
match value {
Lit::Int(n) => {
let n = n.base10_parse::<i32>().unwrap();
if n < 0 {
return Err(darling::Error::custom(
"The complexity must be greater than or equal to 0.",
));
}
Ok(ComplexityType::Const(n as usize))
}
Lit::Str(s) => Ok(ComplexityType::Fn(s.value())),
_ => Err(darling::Error::unexpected_lit_type(value)),
}
}
}
#[derive(FromMeta, Default)]
#[darling(default)]
pub struct ObjectField {
@ -156,6 +179,7 @@ pub struct ObjectField {
pub requires: Option<String>,
pub guard: Option<Meta>,
pub visible: Option<Visible>,
pub complexity: Option<ComplexityType>,
}
#[derive(FromDeriveInput)]

View File

@ -278,6 +278,7 @@ pub fn generate(interface_args: &args::Interface) -> GeneratorResult<TokenStream
provides: #provides,
requires: #requires,
visible: #visible,
compute_complexity: ::std::option::Option::None,
});
});

View File

@ -1,15 +1,19 @@
use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::{Block, Error, FnArg, ImplItem, ItemImpl, Pat, ReturnType, Type, TypeReference};
use syn::visit::Visit;
use syn::{
Block, Error, Expr, ExprPath, FnArg, ImplItem, ItemImpl, Pat, ReturnType, Type, TypeReference,
};
use crate::args::{self, RenameRuleExt, RenameTarget};
use crate::args::{self, ComplexityType, RenameRuleExt, RenameTarget};
use crate::output_type::OutputType;
use crate::utils::{
generate_default, generate_guards, generate_validator, get_cfg_attrs, get_crate_name,
get_param_getter_ident, get_rustdoc, get_type_path_and_name, parse_graphql_attrs,
remove_graphql_attrs, visible_fn, GeneratorResult,
};
use std::collections::HashSet;
pub fn generate(
object_args: &args::Object,
@ -322,7 +326,7 @@ pub fn generate(
visible,
..
},
) in args
) in &args
{
let name = name.clone().unwrap_or_else(|| {
object_args
@ -386,6 +390,62 @@ pub fn generate(
let schema_ty = ty.value_type();
let visible = visible_fn(&method_args.visible);
let complexity = if let Some(complexity) = &method_args.complexity {
match complexity {
ComplexityType::Const(n) => {
quote! { ::std::option::Option::Some(#crate_name::registry::ComplexityType::Const(#n)) }
}
ComplexityType::Fn(s) => {
let (variables, expr) = parse_complexity_expr(s)?;
let mut parse_args = Vec::new();
for variable in variables {
if let Some((
ident,
ty,
args::Argument {
name,
default,
default_with,
..
},
)) = args
.iter()
.find(|(pat_ident, _, _)| pat_ident.ident == variable)
{
let default = generate_default(&default, &default_with)?;
let schema_default = default
.as_ref()
.map(|value| {
quote! {
::std::option::Option::Some(::std::string::ToString::to_string(
&<#ty as #crate_name::InputType>::to_value(&#value)
))
}
})
.unwrap_or_else(|| quote! {::std::option::Option::None});
let name = name.clone().unwrap_or_else(|| {
object_args.rename_args.rename(
ident.ident.unraw().to_string(),
RenameTarget::Argument,
)
});
parse_args.push(quote! {
let #ident: #ty = __ctx.param_value(__variables_definition, __field, #name, #schema_default)?;
});
}
}
quote! {
Some(#crate_name::registry::ComplexityType::Fn(|__ctx, __variables_definition, __field, child_complexity| {
#(#parse_args)*
Ok(#expr)
}))
}
}
}
} else {
quote! { ::std::option::Option::None }
};
schema_fields.push(quote! {
#(#cfg_attrs)*
fields.insert(::std::borrow::ToOwned::to_owned(#field_name), #crate_name::registry::MetaField {
@ -403,6 +463,7 @@ pub fn generate(
provides: #provides,
requires: #requires,
visible: #visible,
compute_complexity: #complexity,
});
});
@ -547,3 +608,25 @@ pub fn generate(
};
Ok(expanded.into())
}
#[derive(Default)]
struct VisitComplexityExpr {
variables: HashSet<String>,
}
impl<'a> Visit<'a> for VisitComplexityExpr {
fn visit_expr_path(&mut self, i: &'a ExprPath) {
if let Some(ident) = i.path.get_ident() {
if ident != "child_complexity" {
self.variables.insert(ident.to_string());
}
}
}
}
fn parse_complexity_expr(s: &str) -> GeneratorResult<(HashSet<String>, Expr)> {
let expr: Expr = syn::parse_str(s)?;
let mut visit = VisitComplexityExpr::default();
visit.visit_expr(&expr);
Ok((visit.variables, expr))
}

View File

@ -95,6 +95,7 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult<TokenStream
provides: #provides,
requires: #requires,
visible: #visible,
compute_complexity: ::std::option::Option::None,
});
});

View File

@ -256,6 +256,7 @@ pub fn generate(
requires: ::std::option::Option::None,
provides: ::std::option::Option::None,
visible: #visible,
compute_complexity: ::std::option::Option::None,
});
});

View File

@ -23,6 +23,7 @@
- [Cursor connections](cursor_connections.md)
- [Error extensions](error_extensions.md)
- [Apollo Tracing](apollo_tracing.md)
- [Query complexity and depth](depth_and_complexity.md)
- [Hide content in introspection](visibility.md)
- [Integrations](integrations.md)
- [Warp](integrations_to_warp.md)

View File

@ -0,0 +1,117 @@
# Query complexity and depth
GraphQL provides a powerful way to query your data, but putting great
power in the hands of your API clients also exposes you to a risk of denial
of service attacks. You can mitigate that risk with `Async-graphql` by limiting the
complexity and depth of the queries you allow.
## Expensive Queries
Consider a schema that allows listing blog posts. Each blog post is also related to other posts.
```graphql
type Query {
posts(count: Int = 10): [Post!]!
}
type Post {
title: String!
text: String!
related(count: Int = 10): [Post!]!
}
```
Its not too hard to craft a query that will cause a very large response:
```graphql
{
posts(count: 100) {
related(count: 100) {
related(count: 100) {
related(count: 100) {
title
}
}
}
}
}
```
The size of the response increases exponentially with every other level of the `related` field. Fortunately, `Async-graphql` provides
a way to prevent such queries.
## Limiting Query depth
The depth is the number of nesting levels of the field, and the following is a query with a depth of `3`.
```graphql
{
a {
b {
c
}
}
}
```
You can limit the depth when creating `Schema`. If the query exceeds this limit, an error will occur and the
message `Query is nested too deep` will be returned.
```rust
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_depth(5) // Limit the maximum depth to 5
.finish();
```
## Limiting Query complexity
The complexity is the number of fields in the query. The default complexity of each field is `1`. Below is a
query with a complexity of `6`.
```graphql
{
a b c {
d {
e f
}
}
}
```
You can limit the complexity when creating the `Schema`. If the query exceeds this limit, an error will occur
and `Query is too complex` will be returned.
```rust
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(5) // Limit the maximum complexity to 5
.finish();
```
## Custom Complexity Calculation
There are two ways to customize the complexity for non-list type and list type fields.
In the following code, the complexity of the `value` field is `5`. The complexity of the `values` field is `count * child_complexity`,
`child_complexity` is a special variable that represents the complexity of the subquery, and `count` is the parameter of the field,
used to calculate the complexity of the `values` field.
```rust
struct Query;
#[Object]
impl Query {
#[graphql(complexity = 5)]
async fn value(&self) -> i32 {
todo!()
}
#[graphql(complexity = "count * child_complexity")]
async fn values(&self, count: u32) -> i32 {
todo!()
}
}
```
**Note: The complexity calculation is done in the validation phase and not in the execution phase, so you don't have to worry about the query
over-limit causing the execute only part of the query.**

View File

@ -23,6 +23,7 @@
- [游标连接](cursor_connections.md)
- [错误扩展](error_extensions.md)
- [Apollo Tracing支持](apollo_tracing.md)
- [查询的深度和复杂度](depth_and_complexity.md)
- [在内省中隐藏内容](visibility.md)
- [集成到WebServer](integrations.md)
- [Warp](integrations_to_warp.md)

View File

@ -0,0 +1,106 @@
# 查询的深度和复杂度
GraphQL提供了非常灵活的查询方法但在客户端上滥用复杂的查询可能造成风险限制查询语句的深度和复杂度可以减轻这种风险。
## 昂贵的查询
考虑一种允许列出博客文章的架构。每个博客帖子也与其他帖子相关。
```graphql
type Query {
posts(count: Int = 10): [Post!]!
}
type Post {
title: String!
text: String!
related(count: Int = 10): [Post!]!
}
```
创建一个会引起很大响应的查询不是很困难:
```graphql
{
posts(count: 100) {
related(count: 100) {
related(count: 100) {
related(count: 100) {
title
}
}
}
}
}
```
响应的大小随`related`字段的每个其他级别呈指数增长。幸运的是,`Async-graphql`提供了一种防止此类查询的方法。
## 限制查询的深度
查询的深度是字段嵌套的层数,下面是一个深度为`3`的查询。
```graphql
{
a {
b {
c
}
}
}
```
在创建`Schema`的时候可以限制深度,如果查询语句超过这个限制,则会出错并且返回`Query is nested too deep.`消息。
```rust
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_depth(5) // 限制最大深度为5
.finish();
```
## 限制查询的复杂度
复杂度是查询语句中字段的数量,每个字段的复杂度默认为`1`,下面是一个复杂度为`6`的查询。
```graphql
{
a b c {
d {
e f
}
}
}
```
在创建`Schema`的时候可以限制复杂度,如果查询语句超过这个限制,则会出错并且返回`Query is too complex.`。
```rust
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(5) // 限制最大深度为5
.finish();
```
## 自定义字段的复杂度
针对非列表类型和列表类型的字段,有两种自定义复杂度的方法。
下面的代码中,`value`字段的复杂度为`5`。而`values`字段的复杂度为`count * child_complexity``child_complexity`是一个特殊的变量,表示子
查询的复杂度,`count`是字段的参数,这个公式用于计算`values`字段的复杂度。
```rust
struct Query;
#[Object]
impl Query {
#[graphql(complexity = 5)]
async fn value(&self) -> i32 {
todo!()
}
#[graphql(complexity = "count * child_complexity")]
async fn values(&self, count: u32) -> i32 {
todo!()
}
}
```
**注意: 计算复杂度是在验证阶段完成而不是在执行阶段,所以你不用担心超限的查询语句会导致查询只执行一部分。**

View File

@ -47,19 +47,17 @@ async fn quickstart() -> Result<()> {
assert_eq!(string, json!({"data": {"add": 30}}).to_string());
// Note: This test fails due to a bug in tide or reqwest. I will open it again when the bug is fixed.
//
// let resp = client
// .get(listen_addr)
// .query(&[("query", "{ add(a: 10, b: 20) }")])
// .send()
// .await?;
//
// assert_eq!(resp.status(), StatusCode::OK);
// let string = resp.text().await?;
// println!("via get {}", string);
//
// assert_eq!(string, json!({"data": {"add": 30}}).to_string());
let resp = client
.get(listen_addr)
.query(&[("query", "{ add(a: 10, b: 20) }")])
.send()
.await?;
assert_eq!(resp.status(), StatusCode::OK);
let string = resp.text().await?;
println!("via get {}", string);
assert_eq!(string, json!({"data": {"add": 30}}).to_string());
Ok(())
}

View File

@ -19,5 +19,5 @@ pub fn client() -> Client {
}
pub async fn wait_server_ready() {
async_std::task::sleep(Duration::from_millis(300)).await;
async_std::task::sleep(Duration::from_millis(500)).await;
}

View File

@ -220,7 +220,7 @@ pub use request::{BatchRequest, Request};
pub use resolver_utils::{ContainerType, EnumType, ScalarType};
pub use response::{BatchResponse, Response};
pub use schema::{Schema, SchemaBuilder, SchemaEnv};
pub use validation::ValidationMode;
pub use validation::{ValidationMode, VisitorContext};
pub use context::*;
#[doc(no_inline)]

View File

@ -7,9 +7,11 @@ use std::sync::Arc;
use indexmap::map::IndexMap;
use indexmap::set::IndexSet;
use crate::parser::types::{BaseType as ParsedBaseType, Type as ParsedType};
use crate::parser::types::{
BaseType as ParsedBaseType, Field, Type as ParsedType, VariableDefinition,
};
use crate::validators::InputValueValidator;
use crate::{model, Any, Context, Type, Value};
use crate::{model, Any, Context, Positioned, ServerResult, Type, Value, VisitorContext};
pub use cache_control::CacheControl;
@ -83,6 +85,14 @@ impl<'a> MetaTypeName<'a> {
_ => false,
}
}
pub fn is_list(&self) -> bool {
match self {
MetaTypeName::List(_) => true,
MetaTypeName::NonNull(ty) => MetaTypeName::create(ty).is_list(),
MetaTypeName::Named(name) => name.ends_with(']'),
}
}
}
#[derive(Clone)]
@ -95,6 +105,19 @@ pub struct MetaInputValue {
pub visible: Option<MetaVisibleFn>,
}
type ComputeComplexityFn = fn(
&VisitorContext<'_>,
&[Positioned<VariableDefinition>],
&Field,
usize,
) -> ServerResult<usize>;
#[derive(Clone)]
pub enum ComplexityType {
Const(usize),
Fn(ComputeComplexityFn),
}
#[derive(Clone)]
pub struct MetaField {
pub name: String,
@ -107,6 +130,7 @@ pub struct MetaField {
pub requires: Option<&'static str>,
pub provides: Option<&'static str>,
pub visible: Option<MetaVisibleFn>,
pub compute_complexity: Option<ComplexityType>,
}
#[derive(Clone)]
@ -449,6 +473,7 @@ impl Registry {
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
fields
@ -477,6 +502,7 @@ impl Registry {
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
@ -507,6 +533,7 @@ impl Registry {
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
}

View File

@ -232,7 +232,7 @@ where
}
}
fn create_registry() -> Registry {
pub(crate) fn create_registry() -> Registry {
let mut registry = Registry {
types: Default::default(),
directives: Default::default(),

View File

@ -160,6 +160,7 @@ where
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
@ -178,6 +179,7 @@ where
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);

View File

@ -79,6 +79,7 @@ where
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
@ -95,6 +96,7 @@ where
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);

View File

@ -147,7 +147,7 @@ impl ScalarType for u8 {
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_i64())
matches!(value, Value::Number(n) if n.is_u64())
}
fn to_value(&self) -> Value {
@ -178,7 +178,7 @@ impl ScalarType for u16 {
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_i64())
matches!(value, Value::Number(n) if n.is_u64())
}
fn to_value(&self) -> Value {
@ -209,7 +209,7 @@ impl ScalarType for u32 {
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_i64())
matches!(value, Value::Number(n) if n.is_u64())
}
fn to_value(&self) -> Value {
@ -240,10 +240,72 @@ impl ScalarType for u64 {
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_i64())
matches!(value, Value::Number(n) if n.is_u64())
}
fn to_value(&self) -> Value {
Value::Number(Number::from(*self as u64))
}
}
/// The `Int` scalar type represents non-fractional whole numeric values.
#[Scalar(internal, name = "Int")]
impl ScalarType for usize {
fn parse(value: Value) -> InputValueResult<Self> {
match value {
Value::Number(n) => {
let n = n
.as_u64()
.ok_or_else(|| InputValueError::from("Invalid number"))?;
if n > Self::MAX as u64 {
return Err(InputValueError::from(format!(
"Only integers from {} to {} are accepted.",
0,
Self::MAX
)));
}
Ok(n as Self)
}
_ => Err(InputValueError::expected_type(value)),
}
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_u64())
}
fn to_value(&self) -> Value {
Value::Number(Number::from(*self as u64))
}
}
/// The `Int` scalar type represents non-fractional whole numeric values.
#[Scalar(internal, name = "Int")]
impl ScalarType for isize {
fn parse(value: Value) -> InputValueResult<Self> {
match value {
Value::Number(n) => {
let n = n
.as_i64()
.ok_or_else(|| InputValueError::from("Invalid number"))?;
if n < Self::MIN as i64 || n > Self::MAX as i64 {
return Err(InputValueError::from(format!(
"Only integers from {} to {} are accepted.",
Self::MIN,
Self::MAX
)));
}
Ok(n as Self)
}
_ => Err(InputValueError::expected_type(value)),
}
}
fn is_valid(value: &Value) -> bool {
matches!(value, Value::Number(n) if n.is_i64())
}
fn to_value(&self) -> Value {
Value::Number(Number::from(*self as i64))
}
}

View File

@ -46,6 +46,7 @@ impl<T: Type> Type for QueryRoot<T> {
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
@ -76,6 +77,7 @@ impl<T: Type> Type for QueryRoot<T> {
requires: None,
provides: None,
visible: None,
compute_complexity: None,
},
);
}

View File

@ -11,7 +11,9 @@ mod visitors;
use crate::parser::types::ExecutableDocument;
use crate::registry::Registry;
use crate::{CacheControl, ServerError, Variables};
use visitor::{visit, VisitorContext, VisitorNil};
use visitor::{visit, VisitorNil};
pub use visitor::VisitorContext;
pub struct CheckResult {
pub cache_control: CacheControl,
@ -67,9 +69,7 @@ pub fn check_rules(
.with(visitors::CacheControlCalculate {
cache_control: &mut cache_control,
})
.with(visitors::ComplexityCalculate {
complexity: &mut complexity,
})
.with(visitors::ComplexityCalculate::new(&mut complexity))
.with(visitors::DepthCalculate::new(&mut depth));
visit(&mut visitor, &mut ctx, doc);
}
@ -80,9 +80,7 @@ pub fn check_rules(
.with(visitors::CacheControlCalculate {
cache_control: &mut cache_control,
})
.with(visitors::ComplexityCalculate {
complexity: &mut complexity,
})
.with(visitors::ComplexityCalculate::new(&mut complexity))
.with(visitors::DepthCalculate::new(&mut depth));
visit(&mut visitor, &mut ctx, doc);
}
@ -95,6 +93,6 @@ pub fn check_rules(
Ok(CheckResult {
cache_control,
complexity,
depth: depth as usize,
depth,
})
}

View File

@ -1,15 +1,17 @@
use std::collections::HashMap;
use std::fmt::{self, Display, Formatter};
use async_graphql_value::Value;
use crate::parser::types::{
Directive, ExecutableDocument, Field, FragmentDefinition, FragmentSpread, InlineFragment,
OperationDefinition, OperationType, Selection, SelectionSet, TypeCondition, VariableDefinition,
};
use crate::registry::{self, MetaType, MetaTypeName};
use crate::{Name, Pos, Positioned, ServerError, Variables};
use async_graphql_value::Value;
use crate::{InputType, Name, Pos, Positioned, ServerError, ServerResult, Variables};
pub(crate) struct VisitorContext<'a> {
#[doc(hidden)]
pub struct VisitorContext<'a> {
pub(crate) registry: &'a registry::Registry,
pub(crate) variables: Option<&'a Variables>,
pub(crate) errors: Vec<RuleError>,
@ -87,9 +89,68 @@ impl<'a> VisitorContext<'a> {
pub(crate) fn fragment(&self, name: &str) -> Option<&'a Positioned<FragmentDefinition>> {
self.fragments.get(name)
}
#[doc(hidden)]
pub fn param_value<T: InputType>(
&self,
variable_definitions: &[Positioned<VariableDefinition>],
field: &Field,
name: &str,
default: Option<fn() -> T>,
) -> ServerResult<T> {
let value = field.get_argument(name).cloned();
if value.is_none() {
if let Some(default) = default {
return Ok(default());
}
}
let (pos, value) = match value {
Some(value) => {
let pos = value.pos;
(
pos,
Some(value.node.into_const_with(|name| {
variable_definitions
.iter()
.find(|def| def.node.name.node == name)
.and_then(|def| {
if let Some(variables) = self.variables {
variables
.0
.get(&def.node.name.node)
.or_else(|| def.node.default_value())
} else {
None
}
})
.cloned()
.ok_or_else(|| {
ServerError::new(format!("Variable {} is not defined.", name))
.at(pos)
})
})?),
)
}
None => (Pos::default(), None),
};
T::parse(value).map_err(|e| e.into_server_error().at(pos))
}
}
#[derive(Copy, Clone, Eq, PartialEq)]
pub(crate) enum VisitMode {
Normal,
Inline,
}
pub(crate) trait Visitor<'a> {
fn mode(&self) -> VisitMode {
VisitMode::Normal
}
fn enter_document(&mut self, _ctx: &mut VisitorContext<'a>, _doc: &'a ExecutableDocument) {}
fn exit_document(&mut self, _ctx: &mut VisitorContext<'a>, _doc: &'a ExecutableDocument) {}
@ -687,10 +748,12 @@ fn visit_fragment_definition<'a, V: Visitor<'a>>(
name: &'a Name,
fragment: &'a Positioned<FragmentDefinition>,
) {
v.enter_fragment_definition(ctx, name, fragment);
visit_directives(v, ctx, &fragment.node.directives);
visit_selection_set(v, ctx, &fragment.node.selection_set);
v.exit_fragment_definition(ctx, name, fragment);
if v.mode() == VisitMode::Normal {
v.enter_fragment_definition(ctx, name, fragment);
visit_directives(v, ctx, &fragment.node.directives);
visit_selection_set(v, ctx, &fragment.node.selection_set);
v.exit_fragment_definition(ctx, name, fragment);
}
}
fn visit_fragment_spread<'a, V: Visitor<'a>>(
@ -700,6 +763,14 @@ fn visit_fragment_spread<'a, V: Visitor<'a>>(
) {
v.enter_fragment_spread(ctx, fragment_spread);
visit_directives(v, ctx, &fragment_spread.node.directives);
if v.mode() == VisitMode::Inline {
if let Some(fragment) = ctx
.fragments
.get(fragment_spread.node.fragment_name.node.as_str())
{
visit_selection_set(v, ctx, &fragment.node.selection_set);
}
}
v.exit_fragment_spread(ctx, fragment_spread);
}

View File

@ -1,6 +1,6 @@
use crate::parser::types::{Field, SelectionSet};
use crate::registry::MetaType;
use crate::validation::visitor::{Visitor, VisitorContext};
use crate::validation::visitor::{VisitMode, Visitor, VisitorContext};
use crate::{CacheControl, Positioned};
pub struct CacheControlCalculate<'a> {
@ -8,6 +8,10 @@ pub struct CacheControlCalculate<'a> {
}
impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> {
fn mode(&self) -> VisitMode {
VisitMode::Inline
}
fn enter_selection_set(
&mut self,
ctx: &mut VisitorContext<'_>,

View File

@ -1,13 +1,254 @@
use crate::parser::types::Field;
use crate::validation::visitor::{Visitor, VisitorContext};
use crate::registry::{ComplexityType, MetaType, MetaTypeName};
use crate::validation::visitor::{VisitMode, Visitor, VisitorContext};
use crate::Positioned;
use async_graphql_parser::types::{ExecutableDocument, OperationDefinition, VariableDefinition};
use async_graphql_value::Name;
pub struct ComplexityCalculate<'a> {
pub struct ComplexityCalculate<'ctx, 'a> {
pub complexity: &'a mut usize,
pub complexity_stack: Vec<usize>,
pub variable_definition: Option<&'ctx [Positioned<VariableDefinition>]>,
}
impl<'ctx, 'a> Visitor<'ctx> for ComplexityCalculate<'a> {
fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned<Field>) {
*self.complexity += 1;
impl<'ctx, 'a> ComplexityCalculate<'ctx, 'a> {
pub fn new(complexity: &'a mut usize) -> Self {
Self {
complexity,
complexity_stack: Default::default(),
variable_definition: None,
}
}
}
impl<'ctx, 'a> Visitor<'ctx> for ComplexityCalculate<'ctx, 'a> {
fn mode(&self) -> VisitMode {
VisitMode::Inline
}
fn enter_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
self.complexity_stack.push(0);
}
fn exit_document(&mut self, _ctx: &mut VisitorContext<'ctx>, _doc: &'ctx ExecutableDocument) {
*self.complexity = self.complexity_stack.pop().unwrap();
}
fn enter_operation_definition(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_name: Option<&'ctx Name>,
operation_definition: &'ctx Positioned<OperationDefinition>,
) {
self.variable_definition = Some(&operation_definition.node.variable_definitions);
}
fn enter_field(&mut self, _ctx: &mut VisitorContext<'_>, _field: &Positioned<Field>) {
self.complexity_stack.push(0);
}
fn exit_field(&mut self, ctx: &mut VisitorContext<'ctx>, field: &'ctx Positioned<Field>) {
let children_complex = self.complexity_stack.pop().unwrap();
if let Some(ty) = ctx.parent_type() {
if let MetaType::Object { fields, .. } = ty {
if let Some(meta_field) = fields.get(MetaTypeName::concrete_typename(
field.node.name.node.as_str(),
)) {
if let Some(compute_complexity) = &meta_field.compute_complexity {
match compute_complexity {
ComplexityType::Const(n) => {
*self.complexity_stack.last_mut().unwrap() += n;
}
ComplexityType::Fn(f) => {
if MetaTypeName::create(&meta_field.ty).is_list() {
match f(
ctx,
self.variable_definition.unwrap(),
&field.node,
children_complex,
) {
Ok(n) => {
*self.complexity_stack.last_mut().unwrap() += n;
}
Err(err) => {
ctx.report_error(vec![field.pos], err.to_string())
}
}
}
}
}
return;
}
}
}
}
*self.complexity_stack.last_mut().unwrap() += 1 + children_complex;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::parse_query;
use crate::validation::{visit, VisitorContext};
use crate::{EmptyMutation, EmptySubscription, Object, Schema};
struct Query;
#[derive(Copy, Clone)]
struct MyObj;
#[Object(internal)]
impl MyObj {
async fn a(&self) -> i32 {
1
}
async fn b(&self) -> i32 {
2
}
async fn c(&self) -> MyObj {
MyObj
}
}
#[Object(internal)]
impl Query {
async fn value(&self) -> i32 {
1
}
async fn obj(&self) -> MyObj {
MyObj
}
#[graphql(complexity = "count * child_complexity")]
async fn objs(&self, count: usize) -> Vec<MyObj> {
vec![MyObj; count as usize]
}
#[graphql(complexity = 3)]
async fn d(&self) -> MyObj {
MyObj
}
}
fn check_complex(query: &str, expect_complex: usize) {
let registry = Schema::<Query, EmptyMutation, EmptySubscription>::create_registry();
let doc = parse_query(query).unwrap();
let mut ctx = VisitorContext::new(&registry, &doc, None);
let mut complex = 0;
let mut complex_calculate = ComplexityCalculate::new(&mut complex);
visit(&mut complex_calculate, &mut ctx, &doc);
assert_eq!(complex, expect_complex);
}
#[test]
fn complex() {
check_complex(
r#"
{
value #1
}"#,
1,
);
check_complex(
r#"
{
value #1
d #3
}"#,
4,
);
check_complex(
r#"
{
value obj { #2
a b #2
}
}"#,
4,
);
check_complex(
r#"
{
value obj { #2
a b obj { #3
a b obj { #3
a #1
}
}
}
}"#,
9,
);
check_complex(
r#"
fragment A on MyObj {
a b ... A2 #2
}
fragment A2 on MyObj {
obj { # 1
a # 1
}
}
query {
obj { # 1
... A
}
}"#,
5,
);
check_complex(
r#"
{
obj { # 1
... on MyObj {
a b #2
... on MyObj {
obj { #1
a #1
}
}
}
}
}"#,
5,
);
check_complex(
r#"
{
objs(count: 10) {
a b
}
}"#,
20,
);
check_complex(
r#"
fragment A on MyObj {
a b
}
query {
objs(count: 10) {
... A
}
}"#,
20,
);
}
}

View File

@ -1,69 +1,151 @@
use crate::parser::types::{FragmentSpread, InlineFragment, SelectionSet};
use crate::validation::visitor::{Visitor, VisitorContext};
use crate::validation::visitor::{VisitMode, Visitor, VisitorContext};
use crate::Positioned;
use async_graphql_parser::types::Field;
pub struct DepthCalculate<'a> {
max_depth: &'a mut i32,
current_depth: i32,
max_depth: &'a mut usize,
current_depth: usize,
}
impl<'a> DepthCalculate<'a> {
pub fn new(max_depth: &'a mut i32) -> Self {
*max_depth = -1;
pub fn new(max_depth: &'a mut usize) -> Self {
Self {
max_depth,
current_depth: -1,
current_depth: 0,
}
}
}
impl<'ctx, 'a> Visitor<'ctx> for DepthCalculate<'a> {
fn enter_selection_set(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_selection_set: &'ctx Positioned<SelectionSet>,
) {
fn mode(&self) -> VisitMode {
VisitMode::Inline
}
fn enter_field(&mut self, _ctx: &mut VisitorContext<'ctx>, _field: &'ctx Positioned<Field>) {
self.current_depth += 1;
*self.max_depth = (*self.max_depth).max(self.current_depth);
}
fn exit_selection_set(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_selection_set: &'ctx Positioned<SelectionSet>,
) {
fn exit_field(&mut self, _ctx: &mut VisitorContext<'ctx>, _field: &'ctx Positioned<Field>) {
self.current_depth -= 1;
}
fn enter_fragment_spread(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_fragment_spread: &'ctx Positioned<FragmentSpread>,
) {
self.current_depth -= 1;
}
fn exit_fragment_spread(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_fragment_spread: &'ctx Positioned<FragmentSpread>,
) {
self.current_depth += 1;
}
fn enter_inline_fragment(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_inline_fragment: &'ctx Positioned<InlineFragment>,
) {
self.current_depth -= 1;
}
fn exit_inline_fragment(
&mut self,
_ctx: &mut VisitorContext<'ctx>,
_inline_fragment: &'ctx Positioned<InlineFragment>,
) {
self.current_depth += 1;
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::parse_query;
use crate::validation::{visit, VisitorContext};
use crate::{EmptyMutation, EmptySubscription, Object, Schema};
struct Query;
struct MyObj;
#[Object(internal)]
impl MyObj {
async fn a(&self) -> i32 {
1
}
async fn b(&self) -> i32 {
2
}
async fn c(&self) -> MyObj {
MyObj
}
}
#[Object(internal)]
impl Query {
async fn value(&self) -> i32 {
1
}
async fn obj(&self) -> MyObj {
MyObj
}
}
fn check_depth(query: &str, expect_depth: usize) {
let registry = Schema::<Query, EmptyMutation, EmptySubscription>::create_registry();
let doc = parse_query(query).unwrap();
let mut ctx = VisitorContext::new(&registry, &doc, None);
let mut depth = 0;
let mut depth_calculate = DepthCalculate::new(&mut depth);
visit(&mut depth_calculate, &mut ctx, &doc);
assert_eq!(depth, expect_depth);
}
#[test]
fn depth() {
check_depth(
r#"{
value #1
}"#,
1,
);
check_depth(
r#"
{
obj { #1
a b #2
}
}"#,
2,
);
check_depth(
r#"
{
obj { # 1
a b c { # 2
a b c { # 3
a b # 4
}
}
}
}"#,
4,
);
check_depth(
r#"
fragment A on MyObj {
a b ... A2 #2
}
fragment A2 on MyObj {
obj {
a #3
}
}
query {
obj { # 1
... A
}
}"#,
3,
);
check_depth(
r#"
{
obj { # 1
... on MyObj {
a b #2
... on MyObj {
obj {
a #3
}
}
}
}
}"#,
3,
);
}
}

View File

@ -1,139 +0,0 @@
use async_graphql::*;
#[async_std::test]
pub async fn test_complexity_and_depth() {
struct Query;
struct MyObj;
#[Object]
impl MyObj {
async fn a(&self) -> i32 {
1
}
async fn b(&self) -> i32 {
2
}
async fn c(&self) -> MyObj {
MyObj
}
}
#[Object]
impl Query {
async fn value(&self) -> i32 {
1
}
async fn obj(&self) -> MyObj {
MyObj
}
}
let query = "{ a:value b:value c:value }";
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(2)
.finish();
assert_eq!(
schema.execute(query).await.into_result().unwrap_err(),
vec![ServerError {
message: "Query is too complex.".to_owned(),
locations: Vec::new(),
path: Vec::new(),
extensions: None,
}]
);
let query = "{ a:value b:value }";
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(2)
.finish();
assert_eq!(
schema.execute(query).await.data,
value!({
"a": 1,
"b": 1,
})
);
let query = "{ obj { a b } }";
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(2)
.finish();
assert_eq!(
schema.execute(query).await.into_result().unwrap_err(),
vec![ServerError {
message: "Query is too complex.".to_owned(),
locations: Vec::new(),
path: Vec::new(),
extensions: None,
}]
);
let query = "{ obj { a } }";
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_complexity(2)
.finish();
assert_eq!(
schema.execute(query).await.data,
value!({
"obj": { "a": 1 }
})
);
let query = r#"
{
obj { # 0
a b c { # 1
a b c { # 2
a b # 3
}
}
}
}"#;
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_depth(2)
.finish();
assert_eq!(
schema.execute(query).await.into_result().unwrap_err(),
vec![ServerError {
message: "Query is nested too deep.".to_owned(),
locations: Vec::new(),
path: Vec::new(),
extensions: None,
}]
);
let query = r#"
{
obj { # 0
a b c { # 1
a b c { # 2
a b # 3
}
}
}
}"#;
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.limit_depth(3)
.finish();
assert_eq!(
schema.execute(query).await.data,
value!({
"obj": {
"a": 1,
"b": 2,
"c": {
"a": 1,
"b": 2,
"c": {
"a": 1,
"b": 2,
}
}
}
})
);
}