use std::{ any::Any, collections::{HashMap, HashSet}, ops::Deref, sync::Arc, }; use futures_util::stream::{self, Stream, StreamExt}; use indexmap::map::IndexMap; use crate::{ context::{Data, QueryEnvInner}, custom_directive::CustomDirectiveFactory, extensions::{ExtensionFactory, Extensions}, model::__DirectiveLocation, parser::{ parse_query, types::{ Directive, DocumentOperations, ExecutableDocument, OperationType, Selection, SelectionSet, }, Positioned, }, registry::{MetaDirective, MetaInputValue, Registry, SDLExportOptions}, resolver_utils::{resolve_container, resolve_container_serial}, subscription::collect_subscription_streams, types::QueryRoot, validation::{check_rules, ValidationMode}, BatchRequest, BatchResponse, CacheControl, ContextBase, EmptyMutation, EmptySubscription, InputType, ObjectType, OutputType, QueryEnv, Request, Response, ServerError, ServerResult, SubscriptionType, Variables, ID, }; /// Introspection mode #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum IntrospectionMode { /// Introspection only IntrospectionOnly, /// Enables introspection Enabled, /// Disables introspection Disabled, } impl Default for IntrospectionMode { fn default() -> Self { IntrospectionMode::Enabled } } /// Schema builder pub struct SchemaBuilder { validation_mode: ValidationMode, query: QueryRoot, mutation: Mutation, subscription: Subscription, registry: Registry, data: Data, complexity: Option, depth: Option, recursive_depth: usize, extensions: Vec>, custom_directives: HashMap<&'static str, Box>, } impl SchemaBuilder { /// Manually register a input type in the schema. /// /// You can use this function to register schema types that are not directly /// referenced. #[must_use] pub fn register_input_type(mut self) -> Self { T::create_type_info(&mut self.registry); self } /// Manually register a output type in the schema. /// /// You can use this function to register schema types that are not directly /// referenced. #[must_use] pub fn register_output_type(mut self) -> Self { T::create_type_info(&mut self.registry); self } /// Disable introspection queries. #[must_use] pub fn disable_introspection(mut self) -> Self { self.registry.introspection_mode = IntrospectionMode::Disabled; self } /// Only process introspection queries, everything else is processed as an /// error. #[must_use] pub fn introspection_only(mut self) -> Self { self.registry.introspection_mode = IntrospectionMode::IntrospectionOnly; self } /// Set the maximum complexity a query can have. By default, there is no /// limit. #[must_use] pub fn limit_complexity(mut self, complexity: usize) -> Self { self.complexity = Some(complexity); self } /// Set the maximum depth a query can have. By default, there is no limit. #[must_use] pub fn limit_depth(mut self, depth: usize) -> Self { self.depth = Some(depth); self } /// Set the maximum recursive depth a query can have. (default: 32) /// /// If the value is too large, stack overflow may occur, usually `32` is /// enough. #[must_use] pub fn limit_recursive_depth(mut self, depth: usize) -> Self { self.recursive_depth = depth; self } /// Add an extension to the schema. /// /// # Examples /// /// ```rust /// use async_graphql::*; /// /// struct Query; /// /// #[Object] /// impl Query { /// async fn value(&self) -> i32 { /// 100 /// } /// } /// /// let schema = Schema::build(Query, EmptyMutation, EmptySubscription) /// .extension(extensions::Logger) /// .finish(); /// ``` #[must_use] pub fn extension(mut self, extension: impl ExtensionFactory) -> Self { self.extensions.push(Box::new(extension)); self } /// Add a global data that can be accessed in the `Schema`. You access it /// with `Context::data`. #[must_use] pub fn data(mut self, data: D) -> Self { self.data.insert(data); self } /// Set the validation mode, default is `ValidationMode::Strict`. #[must_use] pub fn validation_mode(mut self, validation_mode: ValidationMode) -> Self { self.validation_mode = validation_mode; self } /// Enable federation, which is automatically enabled if the Query has least /// one entity definition. #[must_use] pub fn enable_federation(mut self) -> Self { self.registry.enable_federation = true; self } /// Make the Federation SDL include subscriptions. /// /// Note: Not included by default, in order to be compatible with Apollo /// Server. #[must_use] pub fn enable_subscription_in_federation(mut self) -> Self { self.registry.federation_subscription = true; self } /// Override the name of the specified input type. #[must_use] pub fn override_input_type_description(mut self, desc: &'static str) -> Self { self.registry.set_description(&*T::type_name(), desc); self } /// Override the name of the specified output type. #[must_use] pub fn override_output_type_description(mut self, desc: &'static str) -> Self { self.registry.set_description(&*T::type_name(), desc); self } /// Register a custom directive. /// /// # Panics /// /// Panics if the directive with the same name is already registered. #[must_use] pub fn directive(mut self, directive: T) -> Self { let name = directive.name(); let instance = Box::new(directive); instance.register(&mut self.registry); if name == "skip" || name == "include" || self.custom_directives.insert(name, instance).is_some() { panic!("Directive `{}` already exists", name); } self } /// Build schema. pub fn finish(mut self) -> Schema { // federation if self.registry.enable_federation || self.registry.has_entities() { self.registry.create_federation_types(); } Schema(Arc::new(SchemaInner { validation_mode: self.validation_mode, query: self.query, mutation: self.mutation, subscription: self.subscription, complexity: self.complexity, depth: self.depth, recursive_depth: self.recursive_depth, extensions: self.extensions, env: SchemaEnv(Arc::new(SchemaEnvInner { registry: self.registry, data: self.data, custom_directives: self.custom_directives, })), })) } } #[doc(hidden)] pub struct SchemaEnvInner { pub registry: Registry, pub data: Data, pub custom_directives: HashMap<&'static str, Box>, } #[doc(hidden)] #[derive(Clone)] pub struct SchemaEnv(Arc); impl Deref for SchemaEnv { type Target = SchemaEnvInner; fn deref(&self) -> &Self::Target { &self.0 } } #[doc(hidden)] pub struct SchemaInner { pub(crate) validation_mode: ValidationMode, pub(crate) query: QueryRoot, pub(crate) mutation: Mutation, pub(crate) subscription: Subscription, pub(crate) complexity: Option, pub(crate) depth: Option, pub(crate) recursive_depth: usize, pub(crate) extensions: Vec>, pub(crate) env: SchemaEnv, } /// GraphQL schema. /// /// Cloning a schema is cheap, so it can be easily shared. pub struct Schema(Arc>); impl Clone for Schema { fn clone(&self) -> Self { Schema(self.0.clone()) } } impl Default for Schema where Query: Default + ObjectType + 'static, Mutation: Default + ObjectType + 'static, Subscription: Default + SubscriptionType + 'static, { fn default() -> Self { Schema::new( Query::default(), Mutation::default(), Subscription::default(), ) } } impl Deref for Schema { type Target = SchemaInner; fn deref(&self) -> &Self::Target { &self.0 } } impl Schema where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, { /// Create a schema builder /// /// The root object for the query and Mutation needs to be specified. /// If there is no mutation, you can use `EmptyMutation`. /// If there is no subscription, you can use `EmptySubscription`. pub fn build( query: Query, mutation: Mutation, subscription: Subscription, ) -> SchemaBuilder { Self::build_with_ignore_name_conflicts(query, mutation, subscription, [] as [&str; 0]) } /// Create a schema builder and specifies a list to ignore type conflict /// detection. /// /// NOTE: It is not recommended to use it unless you know what it does. #[must_use] pub fn build_with_ignore_name_conflicts( query: Query, mutation: Mutation, subscription: Subscription, ignore_name_conflicts: I, ) -> SchemaBuilder where I: IntoIterator, T: Into, { SchemaBuilder { validation_mode: ValidationMode::Strict, query: QueryRoot { inner: query }, mutation, subscription, registry: Self::create_registry( ignore_name_conflicts.into_iter().map(Into::into).collect(), ), data: Default::default(), complexity: None, depth: None, recursive_depth: 32, extensions: Default::default(), custom_directives: Default::default(), } } pub(crate) fn create_registry(ignore_name_conflicts: HashSet) -> Registry { let mut registry = Registry { types: Default::default(), directives: Default::default(), implements: Default::default(), query_type: Query::type_name().to_string(), mutation_type: if Mutation::is_empty() { None } else { Some(Mutation::type_name().to_string()) }, subscription_type: if Subscription::is_empty() { None } else { Some(Subscription::type_name().to_string()) }, introspection_mode: IntrospectionMode::Enabled, enable_federation: false, federation_subscription: false, ignore_name_conflicts, }; registry.add_directive(MetaDirective { name: "include", description: Some("Directs the executor to include this field or fragment only when the `if` argument is true."), locations: vec![ __DirectiveLocation::FIELD, __DirectiveLocation::FRAGMENT_SPREAD, __DirectiveLocation::INLINE_FRAGMENT ], args: { let mut args = IndexMap::new(); args.insert("if".to_string(), MetaInputValue { name: "if", description: Some("Included when true."), ty: "Boolean!".to_string(), default_value: None, visible: None, inaccessible: false, tags: Default::default(), is_secret: false, }); args }, is_repeatable: false, visible: None, }); registry.add_directive(MetaDirective { name: "skip", description: Some("Directs the executor to skip this field or fragment when the `if` argument is true."), locations: vec![ __DirectiveLocation::FIELD, __DirectiveLocation::FRAGMENT_SPREAD, __DirectiveLocation::INLINE_FRAGMENT ], args: { let mut args = IndexMap::new(); args.insert("if".to_string(), MetaInputValue { name: "if", description: Some("Skipped when true."), ty: "Boolean!".to_string(), default_value: None, visible: None, inaccessible: false, tags: Default::default(), is_secret: false, }); args }, is_repeatable: false, visible: None, }); // register scalars ::create_type_info(&mut registry); ::create_type_info(&mut registry); ::create_type_info(&mut registry); ::create_type_info(&mut registry); ::create_type_info(&mut registry); QueryRoot::::create_type_info(&mut registry); if !Mutation::is_empty() { Mutation::create_type_info(&mut registry); } if !Subscription::is_empty() { Subscription::create_type_info(&mut registry); } registry.remove_unused_types(); registry } /// Create a schema pub fn new( query: Query, mutation: Mutation, subscription: Subscription, ) -> Schema { Self::build(query, mutation, subscription).finish() } #[inline] #[allow(unused)] pub(crate) fn registry(&self) -> &Registry { &self.env.registry } /// Returns SDL(Schema Definition Language) of this schema. pub fn sdl(&self) -> String { self.0.env.registry.export_sdl(Default::default()) } /// Returns SDL(Schema Definition Language) of this schema with options. pub fn sdl_with_options(&self, options: SDLExportOptions) -> String { self.0.env.registry.export_sdl(options) } /// Get all names in this schema /// /// Maybe you want to serialize a custom binary protocol. In order to /// minimize message size, a dictionary is usually used to compress type /// names, field names, directive names, and parameter names. This function /// gets all the names, so you can create this dictionary. pub fn names(&self) -> Vec { self.0.env.registry.names() } fn create_extensions(&self, session_data: Arc) -> Extensions { Extensions::new( self.extensions.iter().map(|f| f.create()), self.env.clone(), session_data, ) } async fn prepare_request( &self, mut extensions: Extensions, request: Request, session_data: Arc, ) -> Result<(QueryEnv, CacheControl), Vec> { let mut request = request; let query_data = Arc::new(std::mem::take(&mut request.data)); extensions.attach_query_data(query_data.clone()); let mut request = extensions.prepare_request(request).await?; let mut document = { let query = &request.query; let parsed_doc = request.parsed_query.take(); let recursive_depth = self.recursive_depth; let fut_parse = async move { let doc = match parsed_doc { Some(parsed_doc) => parsed_doc, None => parse_query(query)?, }; check_recursive_depth(&doc, recursive_depth)?; Ok(doc) }; futures_util::pin_mut!(fut_parse); extensions .parse_query(query, &request.variables, &mut fut_parse) .await? }; // check rules let validation_result = { let validation_fut = async { check_rules( &self.env.registry, &document, Some(&request.variables), self.validation_mode, ) }; futures_util::pin_mut!(validation_fut); extensions.validation(&mut validation_fut).await? }; // check limit if let Some(limit_complexity) = self.complexity { if validation_result.complexity > limit_complexity { return Err(vec![ServerError::new("Query is too complex.", None)]); } } if let Some(limit_depth) = self.depth { if validation_result.depth > limit_depth { return Err(vec![ServerError::new("Query is nested too deep.", None)]); } } let operation = if let Some(operation_name) = &request.operation_name { match document.operations { DocumentOperations::Single(_) => None, DocumentOperations::Multiple(mut operations) => operations .remove(operation_name.as_str()) .map(|operation| (Some(operation_name.clone()), operation)), } .ok_or_else(|| { ServerError::new( format!(r#"Unknown operation named "{}""#, operation_name), None, ) }) } else { match document.operations { DocumentOperations::Single(operation) => Ok((None, operation)), DocumentOperations::Multiple(map) if map.len() == 1 => { let (operation_name, operation) = map.into_iter().next().unwrap(); Ok((Some(operation_name.to_string()), operation)) } DocumentOperations::Multiple(_) => Err(ServerError::new( "Operation name required in request.", None, )), } }; let (operation_name, mut operation) = operation.map_err(|err| vec![err])?; // remove skipped fields for fragment in document.fragments.values_mut() { remove_skipped_selection(&mut fragment.node.selection_set.node, &request.variables); } remove_skipped_selection(&mut operation.node.selection_set.node, &request.variables); let env = QueryEnvInner { extensions, variables: request.variables, operation_name, operation, fragments: document.fragments, uploads: request.uploads, session_data, ctx_data: query_data, extension_data: Arc::new(request.data), http_headers: Default::default(), introspection_mode: request.introspection_mode, errors: Default::default(), }; Ok((QueryEnv::new(env), validation_result.cache_control)) } async fn execute_once(&self, env: QueryEnv) -> Response { // execute let ctx = ContextBase { path_node: None, is_for_introspection: false, item: &env.operation.node.selection_set, schema_env: &self.env, query_env: &env, }; let res = match &env.operation.node.ty { OperationType::Query => resolve_container(&ctx, &self.query).await, OperationType::Mutation => { if self.env.registry.introspection_mode == IntrospectionMode::IntrospectionOnly || env.introspection_mode == IntrospectionMode::IntrospectionOnly { resolve_container_serial(&ctx, &EmptyMutation).await } else { resolve_container_serial(&ctx, &self.mutation).await } } OperationType::Subscription => Err(ServerError::new( "Subscriptions are not supported on this transport.", None, )), }; let mut resp = match res { Ok(value) => Response::new(value), Err(err) => Response::from_errors(vec![err]), } .http_headers(std::mem::take(&mut *env.http_headers.lock().unwrap())); resp.errors .extend(std::mem::take(&mut *env.errors.lock().unwrap())); resp } /// Execute a GraphQL query. pub async fn execute(&self, request: impl Into) -> Response { let request = request.into(); let extensions = self.create_extensions(Default::default()); let request_fut = { let extensions = extensions.clone(); async move { match self .prepare_request(extensions, request, Default::default()) .await { Ok((env, cache_control)) => { let fut = async { self.execute_once(env.clone()) .await .cache_control(cache_control) }; futures_util::pin_mut!(fut); env.extensions .execute(env.operation_name.as_deref(), &mut fut) .await } Err(errors) => Response::from_errors(errors), } } }; futures_util::pin_mut!(request_fut); extensions.request(&mut request_fut).await } /// Execute a GraphQL batch query. pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse { match batch_request { BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await), BatchRequest::Batch(requests) => BatchResponse::Batch( futures_util::stream::iter(requests.into_iter()) .then(|request| self.execute(request)) .collect() .await, ), } } /// Execute a GraphQL subscription with session data. #[doc(hidden)] pub fn execute_stream_with_session_data( &self, request: impl Into + Send, session_data: Arc, ) -> impl Stream + Send + Unpin { let schema = self.clone(); let request = request.into(); let extensions = self.create_extensions(session_data.clone()); let stream = futures_util::stream::StreamExt::boxed({ let extensions = extensions.clone(); async_stream::stream! { let (env, cache_control) = match schema.prepare_request(extensions, request, session_data).await { Ok(res) => res, Err(errors) => { yield Response::from_errors(errors); return; } }; if env.operation.node.ty != OperationType::Subscription { yield schema.execute_once(env).await.cache_control(cache_control); return; } let ctx = env.create_context( &schema.env, None, &env.operation.node.selection_set, ); let mut streams = Vec::new(); let collect_result = if schema.env.registry.introspection_mode == IntrospectionMode::IntrospectionOnly || env.introspection_mode == IntrospectionMode::IntrospectionOnly { collect_subscription_streams(&ctx, &EmptySubscription, &mut streams) } else { collect_subscription_streams(&ctx, &schema.subscription, &mut streams) }; if let Err(err) = collect_result { yield Response::from_errors(vec![err]); } let mut stream = stream::select_all(streams); while let Some(resp) = stream.next().await { yield resp; } } }); extensions.subscribe(stream) } /// Execute a GraphQL subscription. pub fn execute_stream( &self, request: impl Into, ) -> impl Stream + Send + Unpin { self.execute_stream_with_session_data(request.into(), Default::default()) } } fn remove_skipped_selection(selection_set: &mut SelectionSet, variables: &Variables) { fn is_skipped(directives: &[Positioned], variables: &Variables) -> bool { for directive in directives { let include = match &*directive.node.name.node { "skip" => false, "include" => true, _ => continue, }; if let Some(condition_input) = directive.node.get_argument("if") { let value = condition_input .node .clone() .into_const_with(|name| variables.get(&name).cloned().ok_or(())) .unwrap_or_default(); let value: bool = InputType::parse(Some(value)).unwrap_or_default(); if include != value { return true; } } } false } selection_set .items .retain(|selection| !is_skipped(selection.node.directives(), variables)); for selection in &mut selection_set.items { selection.node.directives_mut().retain(|directive| { directive.node.name.node != "skip" && directive.node.name.node != "include" }); } for selection in &mut selection_set.items { match &mut selection.node { Selection::Field(field) => { remove_skipped_selection(&mut field.node.selection_set.node, variables); } Selection::FragmentSpread(_) => {} Selection::InlineFragment(inline_fragment) => { remove_skipped_selection(&mut inline_fragment.node.selection_set.node, variables); } } } } fn check_recursive_depth(doc: &ExecutableDocument, max_depth: usize) -> ServerResult<()> { fn check_selection_set( doc: &ExecutableDocument, selection_set: &Positioned, current_depth: usize, max_depth: usize, ) -> ServerResult<()> { if current_depth > max_depth { return Err(ServerError::new( format!( "The recursion depth of the query cannot be greater than `{}`", max_depth ), Some(selection_set.pos), )); } for selection in &selection_set.node.items { match &selection.node { Selection::Field(field) => { if !field.node.selection_set.node.items.is_empty() { check_selection_set( doc, &field.node.selection_set, current_depth + 1, max_depth, )?; } } Selection::FragmentSpread(fragment_spread) => { if let Some(fragment) = doc.fragments.get(&fragment_spread.node.fragment_name.node) { check_selection_set( doc, &fragment.node.selection_set, current_depth + 1, max_depth, )?; } } Selection::InlineFragment(inline_fragment) => { check_selection_set( doc, &inline_fragment.node.selection_set, current_depth + 1, max_depth, )?; } } } Ok(()) } for (_, operation) in doc.operations.iter() { check_selection_set(doc, &operation.node.selection_set, 0, max_depth)?; } Ok(()) }