From 0d76987f4058c9765ce60df67f53cfa6f93299a8 Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 18 Dec 2020 23:58:03 +0800 Subject: [PATCH] Add Analyzer extension. --- src/extensions/analyzer.rs | 107 +++++++++++++++++++++++++++++++++++++ src/extensions/mod.rs | 12 +++-- src/extensions/tracing.rs | 4 +- src/lib.rs | 2 +- src/schema.rs | 16 +++--- src/validation/mod.rs | 14 +++-- 6 files changed, 134 insertions(+), 21 deletions(-) create mode 100644 src/extensions/analyzer.rs diff --git a/src/extensions/analyzer.rs b/src/extensions/analyzer.rs new file mode 100644 index 00000000..200b6e8b --- /dev/null +++ b/src/extensions/analyzer.rs @@ -0,0 +1,107 @@ +use crate::extensions::{Extension, ExtensionContext, ExtensionFactory}; +use crate::{value, ValidationResult, Value}; + +/// Analyzer extension +/// +/// This extension will output the `analyzer` field containing `complexity` and `depth` in the response extension of each query. +pub struct Analyzer; + +impl ExtensionFactory for Analyzer { + fn create(&self) -> Box { + Box::new(AnalyzerExtension::default()) + } +} + +#[derive(Default)] +struct AnalyzerExtension { + complexity: usize, + depth: usize, +} + +impl Extension for AnalyzerExtension { + fn name(&self) -> Option<&'static str> { + Some("analyzer") + } + + fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, result: &ValidationResult) { + self.complexity = result.complexity; + self.depth = result.depth; + } + + fn result(&mut self, _ctx: &ExtensionContext<'_>) -> Option { + Some(value! ({ + "complexity": self.complexity, + "depth": self.depth, + })) + } +} + +#[cfg(test)] +mod tests { + use crate::*; + + struct Query; + + #[derive(Copy, Clone)] + struct MyObj; + + #[Object(internal)] + impl MyObj { + async fn value(&self) -> i32 { + 1 + } + + async fn obj(&self) -> MyObj { + MyObj + } + } + + #[Object(internal)] + impl Query { + async fn value(&self) -> i32 { + 1 + } + + async fn obj(&self) -> MyObj { + MyObj + } + + #[graphql(complexity = "count * child_complexity")] + async fn objs(&self, count: usize) -> Vec { + vec![MyObj; count as usize] + } + } + + #[async_std::test] + async fn analyzer() { + let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .extension(extensions::Analyzer) + .finish(); + + let extensions = schema + .execute( + r#"{ + value obj { + value obj { + value + } + } + objs(count: 10) { value } + }"#, + ) + .await + .into_result() + .unwrap() + .extensions + .unwrap(); + assert_eq!( + extensions, + value!({ + "analyzer": { + "complexity": 5 + 10, + "depth": 3, + } + }) + ); + } +} diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 9187c25e..54e6fcb1 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,5 +1,6 @@ //! Extensions for schema +mod analyzer; #[cfg(feature = "apollo_persisted_queries")] pub mod apollo_persisted_queries; #[cfg(feature = "apollo_tracing")] @@ -14,9 +15,10 @@ use std::collections::BTreeMap; use crate::context::{QueryPathNode, ResolveId}; use crate::parser::types::ExecutableDocument; -use crate::{Data, Request, Result, ServerError, ServerResult, Variables}; +use crate::{Data, Request, Result, ServerError, ServerResult, ValidationResult, Variables}; use crate::{Error, Name, Value}; +pub use self::analyzer::Analyzer; #[cfg(feature = "apollo_tracing")] pub use self::apollo_tracing::ApolloTracing; #[cfg(feature = "log")] @@ -121,7 +123,7 @@ pub trait Extension: Sync + Send + 'static { fn validation_start(&mut self, ctx: &ExtensionContext<'_>) {} /// Called at the end of the validation. - fn validation_end(&mut self, ctx: &ExtensionContext<'_>) {} + fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) {} /// Called at the begin of the execution. fn execution_start(&mut self, ctx: &ExtensionContext<'_>) {} @@ -236,9 +238,11 @@ impl Extensions { } } - pub fn validation_end(&mut self, ctx: &ExtensionContext<'_>) { + pub fn validation_end(&mut self, ctx: &ExtensionContext<'_>, result: &ValidationResult) { if let Some(e) = &mut self.0 { - e.get_mut().iter_mut().for_each(|e| e.validation_end(ctx)); + e.get_mut() + .iter_mut() + .for_each(|e| e.validation_end(ctx, result)); } } diff --git a/src/extensions/tracing.rs b/src/extensions/tracing.rs index 8d99e547..fbd6a026 100644 --- a/src/extensions/tracing.rs +++ b/src/extensions/tracing.rs @@ -4,7 +4,7 @@ use tracing::{span, Level, Span}; use crate::extensions::{Extension, ExtensionContext, ExtensionFactory, ResolveInfo}; use crate::parser::types::ExecutableDocument; -use crate::{ServerError, Variables}; +use crate::{ServerError, ValidationResult, Variables}; /// Tracing extension configuration for each request. #[derive(Default)] @@ -136,7 +136,7 @@ impl Extension for TracingExtension { } } - fn validation_end(&mut self, _ctx: &ExtensionContext<'_>) { + fn validation_end(&mut self, _ctx: &ExtensionContext<'_>, _result: &ValidationResult) { self.validation .take() .and_then(|span| span.with_subscriber(|(id, d)| d.exit(id))); diff --git a/src/lib.rs b/src/lib.rs index b55fb98d..d31556a6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -220,7 +220,7 @@ pub use request::{BatchRequest, Request}; pub use resolver_utils::{ContainerType, EnumType, ScalarType}; pub use response::{BatchResponse, Response}; pub use schema::{Schema, SchemaBuilder, SchemaEnv}; -pub use validation::{ValidationMode, VisitorContext}; +pub use validation::{ValidationMode, ValidationResult, VisitorContext}; pub use context::*; #[doc(no_inline)] diff --git a/src/schema.rs b/src/schema.rs index c82cb5d4..5461119b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -16,7 +16,7 @@ use crate::registry::{MetaDirective, MetaInputValue, Registry}; use crate::resolver_utils::{resolve_container, resolve_container_serial}; use crate::subscription::collect_subscription_streams; use crate::types::QueryRoot; -use crate::validation::{check_rules, CheckResult, ValidationMode}; +use crate::validation::{check_rules, ValidationMode}; use crate::{ BatchRequest, BatchResponse, CacheControl, ContextBase, ObjectType, QueryEnv, Request, Response, ServerError, SubscriptionType, Type, Value, ID, @@ -372,29 +372,25 @@ where // check rules extensions.validation_start(&ctx_extension); - let CheckResult { - cache_control, - complexity, - depth, - } = check_rules( + let validation_result = check_rules( &self.env.registry, &document, Some(&request.variables), self.validation_mode, ) .log_error(&ctx_extension, &extensions)?; - extensions.validation_end(&ctx_extension); + extensions.validation_end(&ctx_extension, &validation_result); // check limit if let Some(limit_complexity) = self.complexity { - if complexity > limit_complexity { + if validation_result.complexity > limit_complexity { return Err(vec![ServerError::new("Query is too complex.")]) .log_error(&ctx_extension, &extensions); } } if let Some(limit_depth) = self.depth { - if depth > limit_depth { + if validation_result.depth > limit_depth { return Err(vec![ServerError::new("Query is nested too deep.")]) .log_error(&ctx_extension, &extensions); } @@ -437,7 +433,7 @@ where uploads: request.uploads, ctx_data: Arc::new(data), }; - Ok((env, cache_control)) + Ok((env, validation_result.cache_control)) } async fn execute_once(&self, env: QueryEnv) -> Response { diff --git a/src/validation/mod.rs b/src/validation/mod.rs index eb2eb894..5b4e561c 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -11,13 +11,19 @@ mod visitors; use crate::parser::types::ExecutableDocument; use crate::registry::Registry; use crate::{CacheControl, ServerError, Variables}; -use visitor::{visit, VisitorNil}; pub use visitor::VisitorContext; +use visitor::{visit, VisitorNil}; -pub struct CheckResult { +/// Validation results. +pub struct ValidationResult { + /// Cache control pub cache_control: CacheControl, + + /// Query complexity pub complexity: usize, + + /// Query depth pub depth: usize, } @@ -36,7 +42,7 @@ pub fn check_rules( doc: &ExecutableDocument, variables: Option<&Variables>, mode: ValidationMode, -) -> Result> { +) -> Result> { let mut ctx = VisitorContext::new(registry, doc, variables); let mut cache_control = CacheControl::default(); let mut complexity = 0; @@ -90,7 +96,7 @@ pub fn check_rules( return Err(ctx.errors.into_iter().map(Into::into).collect()); } - Ok(CheckResult { + Ok(ValidationResult { cache_control, complexity, depth,