From 01b2ce8e18e5cfcfcf2039d1e657fd611b7d3d62 Mon Sep 17 00:00:00 2001 From: sunli Date: Tue, 12 May 2020 16:27:06 +0800 Subject: [PATCH] Improve GraphQL query parser performance. #76 --- async-graphql-derive/src/args.rs | 26 +-- async-graphql-derive/src/interface.rs | 6 +- async-graphql-derive/src/object.rs | 4 +- async-graphql-derive/src/simple_object.rs | 4 +- async-graphql-derive/src/subscription.rs | 19 +- async-graphql-derive/src/union.rs | 4 +- async-graphql-derive/src/utils.rs | 4 +- async-graphql-parser/Cargo.toml | 1 + async-graphql-parser/src/ast.rs | 168 ++++++++++++---- async-graphql-parser/src/lib.rs | 2 +- async-graphql-parser/src/query_parser.rs | 179 ++++++++++++++---- async-graphql-parser/src/value.rs | 15 +- src/base.rs | 10 +- src/context.rs | 46 +++-- src/lib.rs | 3 +- src/mutation_resolver.rs | 14 +- src/query.rs | 106 +++-------- src/registry.rs | 2 +- src/resolver.rs | 14 +- src/scalars/id.rs | 2 +- src/scalars/string.rs | 2 +- src/schema.rs | 48 ++--- src/subscription/subscription_type.rs | 12 +- src/types/connection/connection_type.rs | 10 +- src/types/connection/cursor.rs | 2 +- src/types/connection/edge.rs | 4 +- src/types/empty_subscription.rs | 1 + src/types/enum.rs | 4 +- src/types/query_root.rs | 10 +- .../rules/arguments_of_correct_type.rs | 6 +- .../rules/fields_on_correct_type.rs | 4 +- src/validation/rules/known_argument_names.rs | 6 +- src/validation/rules/known_directives.rs | 2 +- src/validation/rules/known_type_names.rs | 4 +- .../rules/lone_anonymous_operation.rs | 2 +- .../rules/no_undefined_variables.rs | 8 +- src/validation/rules/no_unused_fragments.rs | 8 +- src/validation/rules/no_unused_variables.rs | 8 +- .../rules/overlapping_fields_can_be_merged.rs | 6 +- .../rules/possible_fragment_spreads.rs | 12 +- .../rules/provided_non_null_arguments.rs | 6 +- src/validation/rules/unique_argument_names.rs | 2 +- .../rules/unique_operation_names.rs | 2 +- src/validation/rules/unique_variable_names.rs | 2 +- .../rules/variables_in_allowed_position.rs | 8 +- src/validation/utils.rs | 17 +- src/validation/visitor.rs | 28 +-- 47 files changed, 494 insertions(+), 359 deletions(-) diff --git a/async-graphql-derive/src/args.rs b/async-graphql-derive/src/args.rs index 4d3f3449..3cac51a1 100644 --- a/async-graphql-derive/src/args.rs +++ b/async-graphql-derive/src/args.rs @@ -1,5 +1,5 @@ use crate::utils::{get_rustdoc, parse_guards, parse_validator}; -use async_graphql_parser::{parse_value, Value}; +use async_graphql_parser::{parse_value, ParsedValue}; use proc_macro2::TokenStream; use quote::quote; use syn::{Attribute, AttributeArgs, Error, Lit, Meta, MetaList, NestedMeta, Result, Type}; @@ -126,7 +126,7 @@ impl Object { pub struct Argument { pub name: Option, pub desc: Option, - pub default: Option, + pub default: Option, pub validator: TokenStream, } @@ -163,12 +163,6 @@ impl Argument { } else if nv.path.is_ident("default") { if let syn::Lit::Str(lit) = &nv.lit { match parse_value(&lit.value()) { - Ok(Value::Variable(_)) => { - return Err(Error::new_spanned( - &nv.lit, - "The default cannot be a variable", - )) - } Ok(value) => default = Some(value), Err(err) => { return Err(Error::new_spanned( @@ -436,7 +430,7 @@ impl EnumItem { pub struct InputField { pub name: Option, pub desc: Option, - pub default: Option, + pub default: Option, pub validator: TokenStream, } @@ -480,12 +474,6 @@ impl InputField { } else if nv.path.is_ident("default") { if let syn::Lit::Str(lit) = &nv.lit { match parse_value(&lit.value()) { - Ok(Value::Variable(_)) => { - return Err(Error::new_spanned( - &lit, - "The default cannot be a variable", - )) - } Ok(value) => default = Some(value), Err(err) => { return Err(Error::new_spanned( @@ -578,7 +566,7 @@ pub struct InterfaceFieldArgument { pub name: String, pub desc: Option, pub ty: Type, - pub default: Option, + pub default: Option, } impl InterfaceFieldArgument { @@ -624,12 +612,6 @@ impl InterfaceFieldArgument { } else if nv.path.is_ident("default") { if let syn::Lit::Str(lit) = &nv.lit { match parse_value(&lit.value()) { - Ok(Value::Variable(_)) => { - return Err(Error::new_spanned( - &nv.lit, - "The default cannot be a variable", - )) - } Ok(value) => default = Some(value), Err(err) => { return Err(Error::new_spanned( diff --git a/async-graphql-derive/src/interface.rs b/async-graphql-derive/src/interface.rs index 83b15137..2d26f882 100644 --- a/async-graphql-derive/src/interface.rs +++ b/async-graphql-derive/src/interface.rs @@ -242,7 +242,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result }; resolvers.push(quote! { - if ctx.name.as_str() == #name { + if ctx.name.node == #name { #(#get_params)* let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return #crate_name::OutputValueType::resolve(&#resolve_obj, &ctx_obj, ctx.position()).await; @@ -307,14 +307,14 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result async fn resolve_field(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::Result<#crate_name::serde_json::Value> { #(#resolvers)* Err(#crate_name::QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: #gql_typename.to_string(), }.into_error(ctx.position())) } fn collect_inline_fields<'a>( &'a self, - name: &#crate_name::Positioned, + name: &str, ctx: &#crate_name::ContextSelectionSet<'a>, futures: &mut Vec<#crate_name::BoxFieldFuture<'a>>, ) -> #crate_name::Result<()> { diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index 8d5d1963..0bc6f7d4 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -386,7 +386,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); resolvers.push(quote! { - if ctx.name.as_str() == #field_name { + if ctx.name.node == #field_name { use #crate_name::OutputValueType; #(#get_params)* #guard @@ -460,7 +460,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< async fn resolve_field(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::Result<#crate_name::serde_json::Value> { #(#resolvers)* Err(#crate_name::QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: #gql_typename.to_string(), }.into_error(ctx.position())) } diff --git a/async-graphql-derive/src/simple_object.rs b/async-graphql-derive/src/simple_object.rs index 381f5379..9beec804 100644 --- a/async-graphql-derive/src/simple_object.rs +++ b/async-graphql-derive/src/simple_object.rs @@ -113,7 +113,7 @@ pub fn generate(object_args: &args::Object, input: &mut DeriveInput) -> Result Result) -> #crate_name::Result<#crate_name::serde_json::Value> { #(#resolvers)* Err(#crate_name::QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: #gql_typename.to_string(), }.into_error(ctx.position())) } diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index 615ad7f6..e7700686 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -217,7 +217,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#schema_args)* args }, - ty: <#stream_ty as #crate_name::futures::stream::Stream>::Item::create_type_info(registry), + ty: <<#stream_ty as #crate_name::futures::stream::Stream>::Item as #crate_name::Type>::create_type_info(registry), deprecation: #field_deprecation, cache_control: Default::default(), external: false, @@ -236,13 +236,21 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); create_stream.push(quote! { - if ctx.name.as_str() == #field_name { + if ctx.name.node == #field_name { use #crate_name::futures::stream::{StreamExt, TryStreamExt}; #(#get_params)* #guard let field_name = std::sync::Arc::new(ctx.result_name().to_string()); - let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone()); + + // I think the code here is safe because the lifetime of selection_set is always less than the environment. + let field_selection_set = unsafe { + (&ctx.selection_set + as *const Positioned<#crate_name::parser::ast::SelectionSet>) + .as_ref() + .unwrap() + }; + let schema = schema.clone(); let pos = ctx.position(); let environment = environment.clone(); @@ -261,7 +269,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< parent: None, segment: #crate_name::QueryPathSegment::Name(&field_name), }), - &*field_selection_set, + field_selection_set, &resolve_id, ); #crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await @@ -325,6 +333,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #[allow(bare_trait_objects)] async fn create_field_stream( &self, + idx: usize, ctx: &#crate_name::Context<'_>, schema: &#crate_name::Schema, environment: std::sync::Arc<#crate_name::Environment>, @@ -338,7 +347,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#create_stream)* Err(#crate_name::QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: #gql_typename.to_string(), }.into_error(ctx.position())) } diff --git a/async-graphql-derive/src/union.rs b/async-graphql-derive/src/union.rs index 303f6948..ec4e2d34 100644 --- a/async-graphql-derive/src/union.rs +++ b/async-graphql-derive/src/union.rs @@ -132,14 +132,14 @@ pub fn generate(union_args: &args::Interface, input: &DeriveInput) -> Result) -> #crate_name::Result<#crate_name::serde_json::Value> { Err(#crate_name::QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: #gql_typename.to_string(), }.into_error(ctx.position())) } fn collect_inline_fields<'a>( &'a self, - name: &#crate_name::Positioned, + name: &str, ctx: &#crate_name::ContextSelectionSet<'a>, futures: &mut Vec<#crate_name::BoxFieldFuture<'a>>, ) -> #crate_name::Result<()> { diff --git a/async-graphql-derive/src/utils.rs b/async-graphql-derive/src/utils.rs index 3ed5efad..b866247c 100644 --- a/async-graphql-derive/src/utils.rs +++ b/async-graphql-derive/src/utils.rs @@ -22,7 +22,7 @@ pub fn build_value_repr(crate_name: &TokenStream, value: &Value) -> TokenStream quote! { #crate_name::Value::Float(#n) } } Value::String(s) => { - quote! { #crate_name::Value::String(#s.to_string()) } + quote! { #crate_name::Value::String(#s.to_string().into()) } } Value::Boolean(n) => { quote! { #crate_name::Value::Boolean(#n) } @@ -46,7 +46,7 @@ pub fn build_value_repr(crate_name: &TokenStream, value: &Value) -> TokenStream .map(|(n, v)| { let value = build_value_repr(crate_name, v); quote! { - obj.insert(#n.to_string(), #value); + obj.insert(#n.to_string().into(), #value); } }) .collect::>(); diff --git a/async-graphql-parser/Cargo.toml b/async-graphql-parser/Cargo.toml index 809f675e..796adb00 100644 --- a/async-graphql-parser/Cargo.toml +++ b/async-graphql-parser/Cargo.toml @@ -17,3 +17,4 @@ pest = "2.1.3" pest_derive = "2.1.0" thiserror = "1.0.11" serde_json = "1.0.48" +arrayvec = "0.5.1" diff --git a/async-graphql-parser/src/ast.rs b/async-graphql-parser/src/ast.rs index 094bde7b..eca35443 100644 --- a/async-graphql-parser/src/ast.rs +++ b/async-graphql-parser/src/ast.rs @@ -1,10 +1,11 @@ use crate::pos::Positioned; use crate::value::Value; +use std::collections::HashMap; use std::fmt; -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] pub enum Type { - Named(String), + Named(&'static str), List(Box), NonNull(Box), } @@ -19,46 +20,147 @@ impl fmt::Display for Type { } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Directive { - pub name: Positioned, - pub arguments: Vec<(Positioned, Positioned)>, + pub name: Positioned<&'static str>, + pub arguments: Vec<(Positioned<&'static str>, Positioned)>, } impl Directive { pub fn get_argument(&self, name: &str) -> Option<&Positioned> { self.arguments .iter() - .find(|item| item.0.as_str() == name) + .find(|item| item.0.node == name) .map(|item| &item.1) } } -#[derive(Clone, Debug)] -pub struct Document { - pub definitions: Vec>, +pub type FragmentsMap = HashMap<&'static str, Positioned>; + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum OperationType { + Query, + Mutation, + Subscription, } -#[derive(Clone, Debug)] +#[derive(Debug)] +pub struct CurrentOperation { + pub ty: OperationType, + pub variable_definitions: Vec>, + pub selection_set: Positioned, +} + +#[derive(Debug)] +pub struct Document { + pub(crate) source: String, + pub(crate) definitions: Vec>, + pub(crate) fragments: FragmentsMap, + pub(crate) current_operation: Option, +} + +impl Document { + #[inline] + pub fn definitions(&self) -> &[Positioned] { + &self.definitions + } + + #[inline] + pub fn fragments(&self) -> &FragmentsMap { + &self.fragments + } + + #[inline] + pub fn current_operation(&self) -> &CurrentOperation { + self.current_operation + .as_ref() + .expect("Must first call retain_operation") + } + + pub fn retain_operation(&mut self, operation_name: Option<&str>) -> bool { + let mut fragments = HashMap::new(); + + for definition in self.definitions.drain(..) { + match definition.node { + Definition::Operation(operation_definition) if self.current_operation.is_none() => { + match operation_definition.node { + OperationDefinition::SelectionSet(s) => { + self.current_operation = Some(CurrentOperation { + ty: OperationType::Query, + variable_definitions: Vec::new(), + selection_set: s, + }); + } + OperationDefinition::Query(query) + if query.name.is_none() + || operation_name.is_none() + || query.name.as_ref().map(|name| name.node) + == operation_name.as_deref() => + { + self.current_operation = Some(CurrentOperation { + ty: OperationType::Query, + variable_definitions: query.node.variable_definitions, + selection_set: query.node.selection_set, + }); + } + OperationDefinition::Mutation(mutation) + if mutation.name.is_none() + || operation_name.is_none() + || mutation.name.as_ref().map(|name| name.node) + == operation_name.as_deref() => + { + self.current_operation = Some(CurrentOperation { + ty: OperationType::Mutation, + variable_definitions: mutation.node.variable_definitions, + selection_set: mutation.node.selection_set, + }); + } + OperationDefinition::Subscription(subscription) + if subscription.name.is_none() + || operation_name.is_none() + || subscription.name.as_ref().map(|name| name.node) + == operation_name.as_deref() => + { + self.current_operation = Some(CurrentOperation { + ty: OperationType::Subscription, + variable_definitions: subscription.node.variable_definitions, + selection_set: subscription.node.selection_set, + }); + } + _ => {} + } + } + Definition::Operation(_) => {} + Definition::Fragment(fragment) => { + fragments.insert(fragment.name.node, fragment); + } + } + } + self.fragments = fragments; + self.current_operation.is_some() + } +} + +#[derive(Debug)] pub enum Definition { Operation(Positioned), Fragment(Positioned), } -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum TypeCondition { - On(Positioned), + On(Positioned<&'static str>), } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct FragmentDefinition { - pub name: Positioned, + pub name: Positioned<&'static str>, pub type_condition: Positioned, pub directives: Vec>, pub selection_set: Positioned, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum OperationDefinition { SelectionSet(Positioned), Query(Positioned), @@ -66,54 +168,54 @@ pub enum OperationDefinition { Subscription(Positioned), } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Query { - pub name: Option>, + pub name: Option>, pub variable_definitions: Vec>, pub directives: Vec>, pub selection_set: Positioned, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Mutation { - pub name: Option>, + pub name: Option>, pub variable_definitions: Vec>, pub directives: Vec>, pub selection_set: Positioned, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Subscription { - pub name: Option>, + pub name: Option>, pub variable_definitions: Vec>, pub directives: Vec>, pub selection_set: Positioned, } -#[derive(Clone, Debug, Default)] +#[derive(Debug, Default)] pub struct SelectionSet { pub items: Vec>, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct VariableDefinition { - pub name: Positioned, + pub name: Positioned<&'static str>, pub var_type: Positioned, pub default_value: Option>, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub enum Selection { Field(Positioned), FragmentSpread(Positioned), InlineFragment(Positioned), } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct Field { - pub alias: Option>, - pub name: Positioned, - pub arguments: Vec<(Positioned, Positioned)>, + pub alias: Option>, + pub name: Positioned<&'static str>, + pub arguments: Vec<(Positioned<&'static str>, Positioned)>, pub directives: Vec>, pub selection_set: Positioned, } @@ -122,18 +224,18 @@ impl Field { pub fn get_argument(&self, name: &str) -> Option<&Positioned> { self.arguments .iter() - .find(|item| item.0.as_str() == name) + .find(|item| item.0.node == name) .map(|item| &item.1) } } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct FragmentSpread { - pub fragment_name: Positioned, + pub fragment_name: Positioned<&'static str>, pub directives: Vec>, } -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct InlineFragment { pub type_condition: Option>, pub directives: Vec>, diff --git a/async-graphql-parser/src/lib.rs b/async-graphql-parser/src/lib.rs index b9ecec42..aca14fd7 100644 --- a/async-graphql-parser/src/lib.rs +++ b/async-graphql-parser/src/lib.rs @@ -9,5 +9,5 @@ mod query_parser; mod value; pub use pos::{Pos, Positioned}; -pub use query_parser::{parse_query, parse_value, Error, Result}; +pub use query_parser::{parse_query, parse_value, Error, ParsedValue, Result}; pub use value::{UploadValue, Value}; diff --git a/async-graphql-parser/src/query_parser.rs b/async-graphql-parser/src/query_parser.rs index 67f42895..5200389d 100644 --- a/async-graphql-parser/src/query_parser.rs +++ b/async-graphql-parser/src/query_parser.rs @@ -2,12 +2,15 @@ use crate::ast::*; use crate::pos::Positioned; use crate::value::Value; use crate::Pos; +use arrayvec::ArrayVec; use pest::error::LineColLocation; use pest::iterators::Pair; use pest::Parser; +use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt; use std::iter::Peekable; +use std::ops::Deref; use std::str::Chars; #[derive(Parser)] @@ -95,12 +98,11 @@ impl<'a> PositionCalculator<'a> { } /// Parse a GraphQL query. -pub fn parse_query>(input: T) -> Result { - let document_pair: Pair = QueryParser::parse(Rule::document, input.as_ref())? - .next() - .unwrap(); +pub fn parse_query>(input: T) -> Result { + let source = input.into(); + let document_pair: Pair = QueryParser::parse(Rule::document, &source)?.next().unwrap(); let mut definitions = Vec::new(); - let mut pc = PositionCalculator::new(input.as_ref()); + let mut pc = PositionCalculator::new(&source); for pair in document_pair.into_inner() { match pair.as_rule() { @@ -117,16 +119,36 @@ pub fn parse_query>(input: T) -> Result { _ => unreachable!(), } } - Ok(Document { definitions }) + + Ok(Document { + source, + definitions, + fragments: Default::default(), + current_operation: None, + }) +} + +pub struct ParsedValue { + #[allow(dead_code)] + source: String, + value: Value, +} + +impl Deref for ParsedValue { + type Target = Value; + + fn deref(&self) -> &Self::Target { + &self.value + } } /// Parse a graphql value -pub fn parse_value>(input: T) -> Result { - let value_pair: Pair = QueryParser::parse(Rule::value, input.as_ref())? - .next() - .unwrap(); - let mut pc = PositionCalculator::new(input.as_ref()); - parse_value2(value_pair, &mut pc) +pub fn parse_value>(input: T) -> Result { + let source = input.into(); + let value_pair: Pair = QueryParser::parse(Rule::value, &source)?.next().unwrap(); + let mut pc = PositionCalculator::new(&source); + let value = parse_value2(value_pair, &mut pc)?; + Ok(ParsedValue { source, value }) } fn parse_named_operation_definition( @@ -157,7 +179,10 @@ fn parse_named_operation_definition( }; } Rule::name => { - name = Some(Positioned::new(pair.as_str().to_string(), pc.step(&pair))); + name = Some(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )); } Rule::variable_definitions => { variable_definitions = Some(parse_variable_definitions(pair, pc)?); @@ -221,7 +246,7 @@ fn parse_type(pair: Pair, pc: &mut PositionCalculator) -> Result { match pair.as_rule() { Rule::nonnull_type => Ok(Type::NonNull(Box::new(parse_type(pair, pc)?))), Rule::list_type => Ok(Type::List(Box::new(parse_type(pair, pc)?))), - Rule::name => Ok(Type::Named(pair.as_str().to_string())), + Rule::name => Ok(Type::Named(to_static_str(pair.as_str()))), Rule::type_ => parse_type(pair, pc), _ => unreachable!(), } @@ -284,7 +309,10 @@ fn parse_directive(pair: Pair, pc: &mut PositionCalculator) -> Result { let pos = pc.step(&pair); - name = Some(Positioned::new(pair.as_str().to_string(), pos)) + name = Some(Positioned::new( + to_static_str(to_static_str(pair.as_str())), + pos, + )) } Rule::arguments => arguments = Some(parse_arguments(pair, pc)?), _ => unreachable!(), @@ -313,10 +341,16 @@ fn parse_directives( Ok(directives) } -fn parse_variable(pair: Pair, pc: &mut PositionCalculator) -> Result> { +fn parse_variable( + pair: Pair, + pc: &mut PositionCalculator, +) -> Result> { for pair in pair.into_inner() { if let Rule::name = pair.as_rule() { - return Ok(Positioned::new(pair.as_str().to_string(), pc.step(&pair))); + return Ok(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )); } } unreachable!() @@ -333,14 +367,14 @@ fn parse_value2(pair: Pair, pc: &mut PositionCalculator) -> Result Rule::string => Value::String({ let start_pos = pair.as_span().start_pos().line_col(); unquote_string( - pair.as_str(), + to_static_str(pair.as_str()), Pos { line: start_pos.0, column: start_pos.1, }, )? }), - Rule::name => Value::Enum(pair.as_str().to_string()), + Rule::name => Value::Enum(to_static_str(pair.as_str())), Rule::boolean => Value::Boolean(match pair.as_str() { "true" => true, "false" => false, @@ -351,12 +385,15 @@ fn parse_value2(pair: Pair, pc: &mut PositionCalculator) -> Result }) } -fn parse_object_pair(pair: Pair, pc: &mut PositionCalculator) -> Result<(String, Value)> { +fn parse_object_pair( + pair: Pair, + pc: &mut PositionCalculator, +) -> Result<(Cow<'static, str>, Value)> { let mut name = None; let mut value = None; for pair in pair.into_inner() { match pair.as_rule() { - Rule::name => name = Some(pair.as_str().to_string()), + Rule::name => name = Some(Cow::Borrowed(to_static_str(pair.as_str()))), Rule::value => value = Some(parse_value2(pair, pc)?), _ => unreachable!(), } @@ -393,12 +430,17 @@ fn parse_array_value(pair: Pair, pc: &mut PositionCalculator) -> Result, pc: &mut PositionCalculator, -) -> Result<(Positioned, Positioned)> { +) -> Result<(Positioned<&'static str>, Positioned)> { let mut name = None; let mut value = None; for pair in pair.into_inner() { match pair.as_rule() { - Rule::name => name = Some(Positioned::new(pair.as_str().to_string(), pc.step(&pair))), + Rule::name => { + name = Some(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )) + } Rule::value => { value = { let pos = pc.step(&pair); @@ -414,7 +456,7 @@ fn parse_pair( fn parse_arguments( pair: Pair, pc: &mut PositionCalculator, -) -> Result, Positioned)>> { +) -> Result, Positioned)>> { let mut arguments = Vec::new(); for pair in pair.into_inner() { match pair.as_rule() { @@ -425,10 +467,13 @@ fn parse_arguments( Ok(arguments) } -fn parse_alias(pair: Pair, pc: &mut PositionCalculator) -> Result> { +fn parse_alias(pair: Pair, pc: &mut PositionCalculator) -> Result> { for pair in pair.into_inner() { if let Rule::name = pair.as_rule() { - return Ok(Positioned::new(pair.as_str().to_string(), pc.step(&pair))); + return Ok(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )); } } unreachable!() @@ -445,7 +490,12 @@ fn parse_field(pair: Pair, pc: &mut PositionCalculator) -> Result alias = Some(parse_alias(pair, pc)?), - Rule::name => name = Some(Positioned::new(pair.as_str().to_string(), pc.step(&pair))), + Rule::name => { + name = Some(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )) + } Rule::arguments => arguments = Some(parse_arguments(pair, pc)?), Rule::directives => directives = Some(parse_directives(pair, pc)?), Rule::selection_set => selection_set = Some(parse_selection_set(pair, pc)?), @@ -474,7 +524,12 @@ fn parse_fragment_spread( let mut directives = None; for pair in pair.into_inner() { match pair.as_rule() { - Rule::name => name = Some(Positioned::new(pair.as_str().to_string(), pc.step(&pair))), + Rule::name => { + name = Some(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )) + } Rule::directives => directives = Some(parse_directives(pair, pc)?), _ => unreachable!(), } @@ -496,7 +551,10 @@ fn parse_type_condition( if let Rule::name = pair.as_rule() { let pos = pc.step(&pair); return Ok(Positioned::new( - TypeCondition::On(Positioned::new(pair.as_str().to_string(), pc.step(&pair))), + TypeCondition::On(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )), pos, )); } @@ -565,7 +623,12 @@ fn parse_fragment_definition( for pair in pair.into_inner() { match pair.as_rule() { - Rule::name => name = Some(Positioned::new(pair.as_str().to_string(), pc.step(&pair))), + Rule::name => { + name = Some(Positioned::new( + to_static_str(pair.as_str()), + pc.step(&pair), + )) + } Rule::type_condition => type_condition = Some(parse_type_condition(pair, pc)?), Rule::directives => directives = Some(parse_directives(pair, pc)?), Rule::selection_set => selection_set = Some(parse_selection_set(pair, pc)?), @@ -584,11 +647,23 @@ fn parse_fragment_definition( )) } -fn unquote_string(s: &str, pos: Pos) -> Result { - let mut res = String::with_capacity(s.len()); +#[inline] +fn to_static_str(s: &str) -> &'static str { + unsafe { (s as *const str).as_ref().unwrap() } +} + +fn unquote_string(s: &'static str, pos: Pos) -> Result> { debug_assert!(s.starts_with('"') && s.ends_with('"')); - let mut chars = s[1..s.len() - 1].chars(); - let mut temp_code_point = String::with_capacity(4); + let s = &s[1..s.len() - 1]; + + if !s.contains('\\') { + return Ok(Cow::Borrowed(to_static_str(s))); + } + + let mut chars = s.chars(); + let mut res = String::with_capacity(s.len()); + let mut temp_code_point = ArrayVec::<[u8; 4]>::new(); + while let Some(c) = chars.next() { match c { '\\' => { @@ -603,13 +678,28 @@ fn unquote_string(s: &str, pos: Pos) -> Result { temp_code_point.clear(); for _ in 0..4 { match chars.next() { - Some(inner_c) => temp_code_point.push(inner_c), + Some(inner_c) if inner_c.is_digit(16) => { + temp_code_point.push(inner_c as u8) + } + Some(inner_c) => { + return Err(Error { + pos, + message: format!( + "{} is not a valid unicode code point", + inner_c + ), + }); + } None => { return Err(Error { pos, message: format!( - "\\u must have 4 characters after it, only found '{}'", - temp_code_point + "{} must have 4 characters after it", + unsafe { + std::str::from_utf8_unchecked( + temp_code_point.as_slice(), + ) + } ), }); } @@ -617,14 +707,23 @@ fn unquote_string(s: &str, pos: Pos) -> Result { } // convert our hex string into a u32, then convert that into a char - match u32::from_str_radix(&temp_code_point, 16).map(std::char::from_u32) { + match u32::from_str_radix( + unsafe { std::str::from_utf8_unchecked(temp_code_point.as_slice()) }, + 16, + ) + .map(std::char::from_u32) + { Ok(Some(unicode_char)) => res.push(unicode_char), _ => { return Err(Error { pos, message: format!( "{} is not a valid unicode code point", - temp_code_point + unsafe { + std::str::from_utf8_unchecked( + temp_code_point.as_slice(), + ) + } ), }); } @@ -642,7 +741,7 @@ fn unquote_string(s: &str, pos: Pos) -> Result { } } - Ok(res) + Ok(Cow::Owned(res)) } #[cfg(test)] diff --git a/async-graphql-parser/src/value.rs b/async-graphql-parser/src/value.rs index 5991671e..2d40804b 100644 --- a/async-graphql-parser/src/value.rs +++ b/async-graphql-parser/src/value.rs @@ -1,3 +1,4 @@ +use std::borrow::Cow; use std::collections::BTreeMap; use std::fmt; use std::fmt::Formatter; @@ -30,14 +31,14 @@ impl Clone for UploadValue { #[allow(missing_docs)] pub enum Value { Null, - Variable(String), + Variable(&'static str), Int(i64), Float(f64), - String(String), + String(Cow<'static, str>), Boolean(bool), - Enum(String), + Enum(&'static str), List(Vec), - Object(BTreeMap), + Object(BTreeMap, Value>), Upload(UploadValue), } @@ -160,7 +161,7 @@ impl From for serde_json::Value { .into(), Value::Object(obj) => serde_json::Value::Object( obj.into_iter() - .map(|(name, value)| (name, value.into())) + .map(|(name, value)| (name.to_string(), value.into())) .collect(), ), Value::Upload(_) => serde_json::Value::Null, @@ -175,11 +176,11 @@ impl From for Value { serde_json::Value::Bool(n) => Value::Boolean(n), serde_json::Value::Number(n) if n.is_f64() => Value::Float(n.as_f64().unwrap()), serde_json::Value::Number(n) => Value::Int(n.as_i64().unwrap()), - serde_json::Value::String(s) => Value::String(s), + serde_json::Value::String(s) => Value::String(Cow::Owned(s)), serde_json::Value::Array(ls) => Value::List(ls.into_iter().map(Into::into).collect()), serde_json::Value::Object(obj) => Value::Object( obj.into_iter() - .map(|(name, value)| (name, value.into())) + .map(|(name, value)| (Cow::Owned(name), value.into())) .collect(), ), } diff --git a/src/base.rs b/src/base.rs index 29da53f2..82ac02eb 100644 --- a/src/base.rs +++ b/src/base.rs @@ -1,8 +1,8 @@ use crate::parser::Pos; use crate::registry::Registry; use crate::{ - registry, Context, ContextSelectionSet, FieldResult, InputValueResult, Positioned, QueryError, - Result, Value, ID, + registry, Context, ContextSelectionSet, FieldResult, InputValueResult, QueryError, Result, + Value, ID, }; use std::borrow::Cow; use std::future::Future; @@ -81,19 +81,19 @@ pub trait ObjectType: OutputValueType { /// Collect the fields with the `name` inline object fn collect_inline_fields<'a>( &'a self, - name: &Positioned, + name: &str, ctx: &ContextSelectionSet<'a>, futures: &mut Vec>, ) -> Result<()> where Self: Send + Sync + Sized, { - if name.as_str() == Self::type_name().as_ref() + if name == Self::type_name().as_ref() || ctx .registry .implements .get(Self::type_name().as_ref()) - .map(|ty| ty.contains(name.as_str())) + .map(|ty| ty.contains(name)) .unwrap_or_default() { crate::collect_fields(ctx, self, futures) diff --git a/src/context.rs b/src/context.rs index 03fed6ed..33b2d6ff 100644 --- a/src/context.rs +++ b/src/context.rs @@ -1,12 +1,14 @@ use crate::extensions::BoxExtension; -use crate::parser::ast::{Directive, Field, FragmentDefinition, SelectionSet, VariableDefinition}; +use crate::parser::ast::{Directive, Field, SelectionSet}; use crate::registry::Registry; use crate::{InputValueType, QueryError, Result, Schema, Type}; use crate::{Pos, Positioned, Value}; +use async_graphql_parser::ast::Document; use async_graphql_parser::UploadValue; use fnv::FnvHashMap; use std::any::{Any, TypeId}; -use std::collections::{BTreeMap, HashMap}; +use std::borrow::Cow; +use std::collections::BTreeMap; use std::fs::File; use std::ops::{Deref, DerefMut}; use std::sync::atomic::AtomicUsize; @@ -23,7 +25,7 @@ impl Default for Variables { } impl Deref for Variables { - type Target = BTreeMap; + type Target = BTreeMap, Value>; fn deref(&self) -> &Self::Target { if let Value::Object(obj) = &self.0 { @@ -242,11 +244,10 @@ pub struct ContextBase<'a, T> { pub(crate) extensions: &'a [BoxExtension], pub(crate) item: T, pub(crate) variables: &'a Variables, - pub(crate) variable_definitions: &'a [Positioned], pub(crate) registry: &'a Registry, pub(crate) data: &'a Data, pub(crate) ctx_data: Option<&'a Data>, - pub(crate) fragments: &'a HashMap, + pub(crate) document: &'a Document, } impl<'a, T> Deref for ContextBase<'a, T> { @@ -260,8 +261,7 @@ impl<'a, T> Deref for ContextBase<'a, T> { #[doc(hidden)] pub struct Environment { pub variables: Variables, - pub variable_definitions: Vec>, - pub fragments: HashMap, + pub document: Box, pub ctx_data: Arc, } @@ -281,11 +281,10 @@ impl Environment { extensions: &[], item, variables: &self.variables, - variable_definitions: &self.variable_definitions, registry: &schema.0.registry, data: &schema.0.data, ctx_data: Some(&self.ctx_data), - fragments: &self.fragments, + document: &self.document, } } } @@ -314,8 +313,8 @@ impl<'a, T> ContextBase<'a, T> { field .alias .as_ref() - .map(|alias| alias.as_str()) - .unwrap_or_else(|| field.name.as_str()), + .map(|alias| alias.node) + .unwrap_or_else(|| field.name.node), ), }), extensions: self.extensions, @@ -323,11 +322,10 @@ impl<'a, T> ContextBase<'a, T> { resolve_id: self.get_child_resolve_id(), inc_resolve_id: self.inc_resolve_id, variables: self.variables, - variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, ctx_data: self.ctx_data, - fragments: self.fragments, + document: self.document, } } @@ -343,11 +341,10 @@ impl<'a, T> ContextBase<'a, T> { resolve_id: self.resolve_id, inc_resolve_id: &self.inc_resolve_id, variables: self.variables, - variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, ctx_data: self.ctx_data, - fragments: self.fragments, + document: self.document, } } @@ -367,11 +364,13 @@ impl<'a, T> ContextBase<'a, T> { fn var_value(&self, name: &str, pos: Pos) -> Result { let def = self + .document + .current_operation() .variable_definitions .iter() - .find(|def| def.name.as_str() == name); + .find(|def| def.name.node == name); if let Some(def) = def { - if let Some(var_value) = self.variables.get(def.name.as_str()) { + if let Some(var_value) = self.variables.get(def.name.node) { return Ok(var_value.clone()); } else if let Some(default) = &def.default_value { return Ok(default.clone_inner()); @@ -409,7 +408,7 @@ impl<'a, T> ContextBase<'a, T> { #[doc(hidden)] pub fn is_skip(&self, directives: &[Positioned]) -> Result { for directive in directives { - if directive.name.as_str() == "skip" { + if directive.name.node == "skip" { if let Some(value) = directive.get_argument("if") { match InputValueType::parse( self.resolve_input_value(value.clone_inner(), value.position())?, @@ -428,7 +427,7 @@ impl<'a, T> ContextBase<'a, T> { } .into_error(directive.position())); } - } else if directive.name.as_str() == "include" { + } else if directive.name.node == "include" { if let Some(value) = directive.get_argument("if") { match InputValueType::parse( self.resolve_input_value(value.clone_inner(), value.position())?, @@ -449,7 +448,7 @@ impl<'a, T> ContextBase<'a, T> { } } else { return Err(QueryError::UnknownDirective { - name: directive.name.clone_inner(), + name: directive.name.to_string(), } .into_error(directive.position())); } @@ -472,11 +471,10 @@ impl<'a> ContextBase<'a, &'a Positioned> { resolve_id: self.get_child_resolve_id(), inc_resolve_id: self.inc_resolve_id, variables: self.variables, - variable_definitions: self.variable_definitions, registry: self.registry, data: self.data, ctx_data: self.ctx_data, - fragments: self.fragments, + document: self.document, } } } @@ -515,7 +513,7 @@ impl<'a> ContextBase<'a, &'a Positioned> { self.item .alias .as_ref() - .map(|alias| alias.as_str()) - .unwrap_or_else(|| self.item.name.as_str()) + .map(|alias| alias.node) + .unwrap_or_else(|| self.item.name.node) } } diff --git a/src/lib.rs b/src/lib.rs index 707e8a73..bad63153 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -98,7 +98,8 @@ pub mod extensions; pub mod guard; pub mod validators; -use async_graphql_parser as parser; +#[doc(hidden)] +pub use async_graphql_parser as parser; #[doc(hidden)] pub use anyhow; diff --git a/src/mutation_resolver.rs b/src/mutation_resolver.rs index 314e7ebc..5c771a41 100644 --- a/src/mutation_resolver.rs +++ b/src/mutation_resolver.rs @@ -39,7 +39,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( continue; } - if field.name.as_str() == "__typename" { + if field.name.node == "__typename" { values.insert( "__typename".to_string(), root.introspection_type_name().to_string().into(), @@ -59,7 +59,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( .registry .types .get(T::type_name().as_ref()) - .and_then(|ty| ty.field_by_name(field.name.as_str())) + .and_then(|ty| ty.field_by_name(field.name.node)) .map(|field| &field.ty) { Some(ty) => &ty, @@ -68,7 +68,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( pos: field.position(), path: None, err: QueryError::FieldNotFound { - field_name: field.name.clone_inner(), + field_name: field.name.to_string(), object: T::type_name().to_string(), }, }); @@ -97,8 +97,10 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( continue; } - if let Some(fragment) = - ctx.fragments.get(fragment_spread.fragment_name.as_str()) + if let Some(fragment) = ctx + .document + .fragments() + .get(fragment_spread.fragment_name.node) { do_resolve( &ctx.with_selection_set(&fragment.selection_set), @@ -111,7 +113,7 @@ fn do_resolve<'a, T: ObjectType + Send + Sync>( pos: fragment_spread.position(), path: None, err: QueryError::UnknownFragment { - name: fragment_spread.fragment_name.clone_inner(), + name: fragment_spread.fragment_name.to_string(), }, }); } diff --git a/src/query.rs b/src/query.rs index 8d4c8e91..346c8ff5 100644 --- a/src/query.rs +++ b/src/query.rs @@ -1,19 +1,15 @@ use crate::context::{Data, ResolveId}; use crate::error::ParseRequestError; use crate::mutation_resolver::do_mutation_resolve; -use crate::parser::ast::{ - Definition, Document, OperationDefinition, SelectionSet, VariableDefinition, -}; use crate::parser::parse_query; use crate::registry::CacheControl; use crate::validation::{check_rules, CheckResult}; use crate::{ - do_resolve, ContextBase, Error, ObjectType, Pos, Positioned, QueryError, Result, Schema, - Variables, + do_resolve, ContextBase, Error, ObjectType, Pos, QueryError, Result, Schema, Variables, }; +use async_graphql_parser::ast::OperationType; use itertools::Itertools; use std::any::Any; -use std::collections::HashMap; use std::fs::File; use std::path::PathBuf; use std::sync::atomic::AtomicUsize; @@ -138,7 +134,7 @@ impl QueryBuilder { extensions .iter() .for_each(|e| e.parse_start(&self.query_source)); - let document = parse_query(&self.query_source).map_err(Into::::into)?; + let mut document = parse_query(&self.query_source).map_err(Into::::into)?; extensions.iter().for_each(|e| e.parse_end()); // check rules @@ -165,20 +161,22 @@ impl QueryBuilder { // execute let inc_resolve_id = AtomicUsize::default(); - let mut fragments = HashMap::new(); - let (selection_set, variable_definitions, is_query) = - current_operation(&document, self.operation_name.as_deref()).ok_or_else(|| { - Error::Query { + if !document.retain_operation(self.operation_name.as_deref()) { + return if let Some(operation_name) = self.operation_name { + Err(Error::Query { + pos: Pos::default(), + path: None, + err: QueryError::UnknownOperationNamed { + name: operation_name, + }, + }) + } else { + Err(Error::Query { pos: Pos::default(), path: None, err: QueryError::MissingOperation, - } - })?; - - for definition in &document.definitions { - if let Definition::Fragment(fragment) = &definition.node { - fragments.insert(fragment.name.clone_inner(), fragment.clone_inner()); - } + }) + }; } let ctx = ContextBase { @@ -186,21 +184,28 @@ impl QueryBuilder { resolve_id: ResolveId::root(), inc_resolve_id: &inc_resolve_id, extensions: &extensions, - item: selection_set, + item: &document.current_operation().selection_set, variables: &self.variables, - variable_definitions, registry: &schema.0.registry, data: &schema.0.data, ctx_data: self.ctx_data.as_ref(), - fragments: &fragments, + document: &document, }; extensions.iter().for_each(|e| e.execution_start()); - let data = if is_query { - do_resolve(&ctx, &schema.0.query).await? - } else { - do_mutation_resolve(&ctx, &schema.0.mutation).await? + + let data = match document.current_operation().ty { + OperationType::Query => do_resolve(&ctx, &schema.0.query).await?, + OperationType::Mutation => do_mutation_resolve(&ctx, &schema.0.mutation).await?, + OperationType::Subscription => { + return Err(Error::Query { + pos: Pos::default(), + path: None, + err: QueryError::NotSupported, + }) + } }; + extensions.iter().for_each(|e| e.execution_end()); let res = QueryResponse { @@ -226,54 +231,3 @@ impl QueryBuilder { Ok(res) } } - -#[allow(clippy::type_complexity)] -fn current_operation<'a>( - document: &'a Document, - operation_name: Option<&str>, -) -> Option<( - &'a Positioned, - &'a [Positioned], - bool, -)> { - for definition in &document.definitions { - match &definition.node { - Definition::Operation(operation_definition) => match &operation_definition.node { - OperationDefinition::SelectionSet(s) => { - return Some((s, &[], true)); - } - OperationDefinition::Query(query) - if query.name.is_none() - || operation_name.is_none() - || query.name.as_ref().map(|name| name.as_str()) - == operation_name.as_deref() => - { - return Some((&query.selection_set, &query.variable_definitions, true)); - } - OperationDefinition::Mutation(mutation) - if mutation.name.is_none() - || operation_name.is_none() - || mutation.name.as_ref().map(|name| name.as_str()) - == operation_name.as_deref() => - { - return Some(( - &mutation.selection_set, - &mutation.variable_definitions, - false, - )); - } - OperationDefinition::Subscription(subscription) - if subscription.name.is_none() - || operation_name.is_none() - || subscription.name.as_ref().map(|name| name.as_str()) - == operation_name.as_deref() => - { - return None; - } - _ => {} - }, - Definition::Fragment(_) => {} - } - } - None -} diff --git a/src/registry.rs b/src/registry.rs index d3dd75ab..c1f029f1 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -404,7 +404,7 @@ impl Registry { match query_type { ParsedType::NonNull(ty) => self.concrete_type_by_parsed_type(ty), ParsedType::List(ty) => self.concrete_type_by_parsed_type(ty), - ParsedType::Named(name) => self.types.get(name.as_str()), + ParsedType::Named(name) => self.types.get(*name), } } diff --git a/src/resolver.rs b/src/resolver.rs index 9fb46800..c3def644 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -40,7 +40,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( continue; } - if field.name.as_str() == "__typename" { + if field.name.node == "__typename" { // Get the typename let ctx_field = ctx.with_field(field); let field_name = ctx_field.result_name().to_string(); @@ -68,7 +68,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( .registry .types .get(T::type_name().as_ref()) - .and_then(|ty| ty.field_by_name(field.name.as_str())) + .and_then(|ty| ty.field_by_name(field.name.node)) .map(|field| &field.ty) { Some(ty) => &ty, @@ -77,7 +77,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( pos: field.position(), path: None, err: QueryError::FieldNotFound { - field_name: field.name.clone_inner(), + field_name: field.name.to_string(), object: T::type_name().to_string(), }, }); @@ -112,7 +112,11 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( continue; } - if let Some(fragment) = ctx.fragments.get(fragment_spread.fragment_name.as_str()) { + if let Some(fragment) = ctx + .document + .fragments() + .get(fragment_spread.fragment_name.node) + { collect_fields( &ctx.with_selection_set(&fragment.selection_set), root, @@ -123,7 +127,7 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( pos: fragment_spread.position(), path: None, err: QueryError::UnknownFragment { - name: fragment_spread.fragment_name.clone_inner(), + name: fragment_spread.fragment_name.to_string(), }, }); } diff --git a/src/scalars/id.rs b/src/scalars/id.rs index 911b2e5b..2053b96b 100644 --- a/src/scalars/id.rs +++ b/src/scalars/id.rs @@ -80,7 +80,7 @@ impl ScalarType for ID { fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(ID(n.to_string())), - Value::String(s) => Ok(ID(s)), + Value::String(s) => Ok(ID(s.into_owned())), _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/string.rs b/src/scalars/string.rs index 57895870..d4161a70 100644 --- a/src/scalars/string.rs +++ b/src/scalars/string.rs @@ -20,7 +20,7 @@ impl ScalarType for String { fn parse(value: Value) -> InputValueResult { match value { - Value::String(s) => Ok(s), + Value::String(s) => Ok(s.into_owned()), _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/schema.rs b/src/schema.rs index 016e1b78..6fbf3eac 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,7 +1,6 @@ use crate::context::Data; use crate::extensions::{BoxExtension, Extension}; use crate::model::__DirectiveLocation; -use crate::parser::ast::{Definition, OperationDefinition}; use crate::parser::parse_query; use crate::query::QueryBuilder; use crate::registry::{Directive, InputValue, Registry}; @@ -242,49 +241,32 @@ where variables: Variables, ctx_data: Option>, ) -> Result> + Send> { - let document = parse_query(source).map_err(Into::::into)?; + let mut document = parse_query(source).map_err(Into::::into)?; check_rules(&self.0.registry, &document, self.0.validation_mode)?; - let mut fragments = HashMap::new(); - let mut subscription = None; - - for definition in document.definitions { - match definition.node { - Definition::Operation(operation) => { - if let OperationDefinition::Subscription(s) = operation.node { - if subscription.is_none() - && (s.name.as_ref().map(|v| v.as_str()) == operation_name - || operation_name.is_none()) - { - subscription = Some(s); - } - } - } - Definition::Fragment(fragment) => { - fragments.insert(fragment.name.clone_inner(), fragment.into_inner()); - } - } - } - - let subscription = subscription - .ok_or(if let Some(name) = operation_name { - QueryError::UnknownOperationNamed { + if !document.retain_operation(operation_name) { + return if let Some(name) = operation_name { + Err(QueryError::UnknownOperationNamed { name: name.to_string(), } - .into_error(Pos::default()) + .into_error(Pos::default())) } else { - QueryError::MissingOperation.into_error(Pos::default()) - })? - .into_inner(); + Err(QueryError::MissingOperation.into_error(Pos::default())) + }; + } let resolve_id = AtomicUsize::default(); let environment = Arc::new(Environment { variables, - variable_definitions: subscription.variable_definitions, - fragments, + document: Box::new(document), ctx_data: ctx_data.unwrap_or_default(), }); - let ctx = environment.create_context(self, None, &subscription.selection_set, &resolve_id); + let ctx = environment.create_context( + self, + None, + &environment.document.current_operation().selection_set, + &resolve_id, + ); let mut streams = Vec::new(); create_subscription_stream(self, environment.clone(), &ctx, &mut streams).await?; Ok(futures::stream::select_all(streams)) diff --git a/src/subscription/subscription_type.rs b/src/subscription/subscription_type.rs index d2833463..e4e21415 100644 --- a/src/subscription/subscription_type.rs +++ b/src/subscription/subscription_type.rs @@ -17,6 +17,7 @@ pub trait SubscriptionType: Type { #[doc(hidden)] async fn create_field_stream( &self, + idx: usize, ctx: &Context<'_>, schema: &Schema, environment: Arc, @@ -41,7 +42,7 @@ where Subscription: SubscriptionType + Send + Sync + 'static + Sized, { Box::pin(async move { - for selection in &ctx.items { + for (idx, selection) in ctx.items.iter().enumerate() { match &selection.node { Selection::Field(field) => { if ctx.is_skip(&field.directives)? { @@ -52,6 +53,7 @@ where .0 .subscription .create_field_stream( + idx, &ctx.with_field(field), schema, environment.clone(), @@ -64,8 +66,10 @@ where continue; } - if let Some(fragment) = - ctx.fragments.get(fragment_spread.fragment_name.as_str()) + if let Some(fragment) = ctx + .document + .fragments() + .get(fragment_spread.fragment_name.node) { create_subscription_stream( schema, @@ -84,7 +88,7 @@ where if let Some(TypeCondition::On(name)) = inline_fragment.type_condition.as_ref().map(|v| &v.node) { - if name.as_str() == Subscription::type_name() { + if name.node == Subscription::type_name() { create_subscription_stream( schema, environment.clone(), diff --git a/src/types/connection/connection_type.rs b/src/types/connection/connection_type.rs index e6b2c7d2..d83d8314 100644 --- a/src/types/connection/connection_type.rs +++ b/src/types/connection/connection_type.rs @@ -177,18 +177,18 @@ impl ObjectType for Connection { async fn resolve_field(&self, ctx: &Context<'_>) -> Result { - if ctx.name.as_str() == "pageInfo" { + if ctx.name.node == "pageInfo" { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve(self.page_info().await, &ctx_obj, ctx.position()) .await; - } else if ctx.name.as_str() == "edges" { + } else if ctx.name.node == "edges" { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve(&self.edges().await, &ctx_obj, ctx.position()).await; - } else if ctx.name.as_str() == "totalCount" { + } else if ctx.name.node == "totalCount" { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve(&self.total_count().await, &ctx_obj, ctx.position()) .await; - } else if ctx.name.as_str() == T::type_name().to_plural().to_camel_case() { + } else if ctx.name.node == T::type_name().to_plural().to_camel_case() { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); let items = self.nodes.iter().map(|(_, _, item)| item).collect_vec(); return OutputValueType::resolve(&items, &ctx_obj, ctx.position()).await; @@ -198,7 +198,7 @@ impl ObjectType pos: ctx.position(), path: None, err: QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: Connection::::type_name().to_string(), }, }) diff --git a/src/types/connection/cursor.rs b/src/types/connection/cursor.rs index abd4351f..5dca1023 100644 --- a/src/types/connection/cursor.rs +++ b/src/types/connection/cursor.rs @@ -46,7 +46,7 @@ impl ScalarType for Cursor { fn parse(value: Value) -> InputValueResult { match value { - Value::String(s) => Ok(Cursor(s)), + Value::String(s) => Ok(Cursor(s.into_owned())), _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/types/connection/edge.rs b/src/types/connection/edge.rs index 25aff0d1..228cd18d 100644 --- a/src/types/connection/edge.rs +++ b/src/types/connection/edge.rs @@ -103,10 +103,10 @@ where E: ObjectType + Sync + Send + 'a, { async fn resolve_field(&self, ctx: &Context<'_>) -> Result { - if ctx.name.as_str() == "node" { + if ctx.name.node == "node" { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve(self.node().await, &ctx_obj, ctx.position()).await; - } else if ctx.name.as_str() == "cursor" { + } else if ctx.name.node == "cursor" { return Ok(self.cursor().await.into()); } diff --git a/src/types/empty_subscription.rs b/src/types/empty_subscription.rs index 629a09aa..49539c7f 100644 --- a/src/types/empty_subscription.rs +++ b/src/types/empty_subscription.rs @@ -38,6 +38,7 @@ impl SubscriptionType for EmptySubscription { async fn create_field_stream( &self, + _idx: usize, _ctx: &Context<'_>, _schema: &Schema, _environment: Arc, diff --git a/src/types/enum.rs b/src/types/enum.rs index 3a4ec915..8a383d50 100644 --- a/src/types/enum.rs +++ b/src/types/enum.rs @@ -12,9 +12,9 @@ pub trait EnumType: Type + Sized + Eq + Send + Copy + Sized + 'static { fn items() -> &'static [EnumItem]; fn parse_enum(value: Value) -> InputValueResult { - let value = match value { + let value = match &value { Value::Enum(s) => s, - Value::String(s) => s, + Value::String(s) => s.as_ref(), _ => return Err(InputValueError::ExpectedType(value)), }; diff --git a/src/types/query_root.rs b/src/types/query_root.rs index 55d157a2..64446792 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -80,13 +80,13 @@ impl Type for QueryRoot { #[async_trait::async_trait] impl ObjectType for QueryRoot { async fn resolve_field(&self, ctx: &Context<'_>) -> Result { - if ctx.name.as_str() == "__schema" { + if ctx.name.node == "__schema" { if self.disable_introspection { return Err(Error::Query { pos: ctx.position(), path: Some(ctx.path_node.as_ref().unwrap().to_json()), err: QueryError::FieldNotFound { - field_name: ctx.name.clone_inner(), + field_name: ctx.name.to_string(), object: Self::type_name().to_string(), }, }); @@ -101,7 +101,7 @@ impl ObjectType for QueryRoot { ctx.position(), ) .await; - } else if ctx.name.as_str() == "__type" { + } else if ctx.name.node == "__type" { let type_name: String = ctx.param_value("name", || Value::Null)?; let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve( @@ -113,14 +113,14 @@ impl ObjectType for QueryRoot { ctx.position(), ) .await; - } else if ctx.name.as_str() == "_entities" { + } else if ctx.name.node == "_entities" { let representations: Vec = ctx.param_value("representations", || Value::Null)?; let mut res = Vec::new(); for item in representations { res.push(self.inner.find_entity(ctx, &item.0).await?); } return Ok(res.into()); - } else if ctx.name.as_str() == "_service" { + } else if ctx.name.node == "_service" { let ctx_obj = ctx.with_selection_set(&ctx.selection_set); return OutputValueType::resolve( &Service { diff --git a/src/validation/rules/arguments_of_correct_type.rs b/src/validation/rules/arguments_of_correct_type.rs index 7ae62e9e..d0527076 100644 --- a/src/validation/rules/arguments_of_correct_type.rs +++ b/src/validation/rules/arguments_of_correct_type.rs @@ -20,7 +20,7 @@ impl<'a> Visitor<'a> for ArgumentsOfCorrectType<'a> { self.current_args = ctx .registry .directives - .get(directive.name.as_str()) + .get(directive.name.node) .map(|d| &d.args); } @@ -35,12 +35,12 @@ impl<'a> Visitor<'a> for ArgumentsOfCorrectType<'a> { fn enter_argument( &mut self, ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, value: &'a Positioned, ) { if let Some(arg) = self .current_args - .and_then(|args| args.get(name.as_str()).map(|input| input)) + .and_then(|args| args.get(name.node).map(|input| input)) { if let Some(validator) = &arg.validator { if let Some(reason) = validator.is_valid(value) { diff --git a/src/validation/rules/fields_on_correct_type.rs b/src/validation/rules/fields_on_correct_type.rs index fd12e382..71196d46 100644 --- a/src/validation/rules/fields_on_correct_type.rs +++ b/src/validation/rules/fields_on_correct_type.rs @@ -12,14 +12,14 @@ impl<'a> Visitor<'a> for FieldsOnCorrectType { if let Some(registry::Type::Union { .. }) | Some(registry::Type::Interface { .. }) = ctx.parent_type() { - if field.name.as_str() == "__typename" { + if field.name.node == "__typename" { return; } } if parent_type .fields() - .and_then(|fields| fields.get(field.name.as_str())) + .and_then(|fields| fields.get(field.name.node)) .is_none() { ctx.report_error( diff --git a/src/validation/rules/known_argument_names.rs b/src/validation/rules/known_argument_names.rs index 8dd02e11..f3962bcc 100644 --- a/src/validation/rules/known_argument_names.rs +++ b/src/validation/rules/known_argument_names.rs @@ -41,7 +41,7 @@ impl<'a> Visitor<'a> for KnownArgumentNames<'a> { self.current_args = ctx .registry .directives - .get(directive.name.as_str()) + .get(directive.name.node) .map(|d| (&d.args, ArgsType::Directive(&directive.name))); } @@ -56,11 +56,11 @@ impl<'a> Visitor<'a> for KnownArgumentNames<'a> { fn enter_argument( &mut self, ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, _value: &'a Positioned, ) { if let Some((args, arg_type)) = &self.current_args { - if !args.contains_key(name.as_str()) { + if !args.contains_key(name.node) { match arg_type { ArgsType::Field { field_name, diff --git a/src/validation/rules/known_directives.rs b/src/validation/rules/known_directives.rs index d44aee44..92854e79 100644 --- a/src/validation/rules/known_directives.rs +++ b/src/validation/rules/known_directives.rs @@ -55,7 +55,7 @@ impl<'a> Visitor<'a> for KnownDirectives { ctx: &mut VisitorContext<'a>, directive: &'a Positioned, ) { - if let Some(schema_directive) = ctx.registry.directives.get(directive.name.as_str()) { + if let Some(schema_directive) = ctx.registry.directives.get(directive.name.node) { if let Some(current_location) = self.location_stack.last() { if !schema_directive.locations.contains(current_location) { ctx.report_error( diff --git a/src/validation/rules/known_type_names.rs b/src/validation/rules/known_type_names.rs index 3918e7b5..7cd78bb0 100644 --- a/src/validation/rules/known_type_names.rs +++ b/src/validation/rules/known_type_names.rs @@ -13,7 +13,7 @@ impl<'a> Visitor<'a> for KnownTypeNames { fragment_definition: &'a Positioned, ) { let TypeCondition::On(name) = &fragment_definition.type_condition.node; - validate_type(ctx, name.as_str(), fragment_definition.position()); + validate_type(ctx, name.node, fragment_definition.position()); } fn enter_variable_definition( @@ -36,7 +36,7 @@ impl<'a> Visitor<'a> for KnownTypeNames { if let Some(TypeCondition::On(name)) = inline_fragment.type_condition.as_ref().map(|c| &c.node) { - validate_type(ctx, name.as_str(), inline_fragment.position()); + validate_type(ctx, name.node, inline_fragment.position()); } } } diff --git a/src/validation/rules/lone_anonymous_operation.rs b/src/validation/rules/lone_anonymous_operation.rs index ed7a7ff0..6d35fe74 100644 --- a/src/validation/rules/lone_anonymous_operation.rs +++ b/src/validation/rules/lone_anonymous_operation.rs @@ -10,7 +10,7 @@ pub struct LoneAnonymousOperation { impl<'a> Visitor<'a> for LoneAnonymousOperation { fn enter_document(&mut self, _ctx: &mut VisitorContext<'a>, doc: &'a Document) { self.operation_count = Some( - doc.definitions + doc.definitions() .iter() .filter(|d| match &d.node { Definition::Operation(_) => true, diff --git a/src/validation/rules/no_undefined_variables.rs b/src/validation/rules/no_undefined_variables.rs index 97f11038..d931f5fe 100644 --- a/src/validation/rules/no_undefined_variables.rs +++ b/src/validation/rules/no_undefined_variables.rs @@ -88,7 +88,7 @@ impl<'a> Visitor<'a> for NoUndefinedVariables<'a> { _ctx: &mut VisitorContext<'a>, fragment_definition: &'a Positioned, ) { - self.current_scope = Some(Scope::Fragment(fragment_definition.name.as_str())); + self.current_scope = Some(Scope::Fragment(fragment_definition.name.node)); } fn enter_variable_definition( @@ -98,7 +98,7 @@ impl<'a> Visitor<'a> for NoUndefinedVariables<'a> { ) { if let Some(Scope::Operation(ref name)) = self.current_scope { if let Some(&mut (_, ref mut vars)) = self.defined_variables.get_mut(name) { - vars.insert(variable_definition.name.as_str()); + vars.insert(variable_definition.name.node); } } } @@ -106,7 +106,7 @@ impl<'a> Visitor<'a> for NoUndefinedVariables<'a> { fn enter_argument( &mut self, _ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, value: &'a Positioned, ) { if let Some(ref scope) = self.current_scope { @@ -130,7 +130,7 @@ impl<'a> Visitor<'a> for NoUndefinedVariables<'a> { self.spreads .entry(scope.clone()) .or_insert_with(Vec::new) - .push(fragment_spread.fragment_name.as_str()); + .push(fragment_spread.fragment_name.node); } } } diff --git a/src/validation/rules/no_unused_fragments.rs b/src/validation/rules/no_unused_fragments.rs index 9690723e..c89399ae 100644 --- a/src/validation/rules/no_unused_fragments.rs +++ b/src/validation/rules/no_unused_fragments.rs @@ -35,7 +35,7 @@ impl<'a> Visitor<'a> for NoUnusedFragments<'a> { fn exit_document(&mut self, ctx: &mut VisitorContext<'a>, doc: &'a Document) { let mut reachable = HashSet::new(); - for def in &doc.definitions { + for def in doc.definitions() { if let Definition::Operation(operation_definition) = &def.node { let (name, _) = operation_name(operation_definition); self.find_reachable_fragments(&Scope::Operation(name), &mut reachable); @@ -66,9 +66,9 @@ impl<'a> Visitor<'a> for NoUnusedFragments<'a> { _ctx: &mut VisitorContext<'a>, fragment_definition: &'a Positioned, ) { - self.current_scope = Some(Scope::Fragment(fragment_definition.name.as_str())); + self.current_scope = Some(Scope::Fragment(fragment_definition.name.node)); self.defined_fragments.insert(( - fragment_definition.name.as_str(), + fragment_definition.name.node, fragment_definition.position(), )); } @@ -82,7 +82,7 @@ impl<'a> Visitor<'a> for NoUnusedFragments<'a> { self.spreads .entry(scope.clone()) .or_insert_with(Vec::new) - .push(fragment_spread.fragment_name.as_str()); + .push(fragment_spread.fragment_name.node); } } } diff --git a/src/validation/rules/no_unused_variables.rs b/src/validation/rules/no_unused_variables.rs index a2f22619..85ff3aef 100644 --- a/src/validation/rules/no_unused_variables.rs +++ b/src/validation/rules/no_unused_variables.rs @@ -87,7 +87,7 @@ impl<'a> Visitor<'a> for NoUnusedVariables<'a> { _ctx: &mut VisitorContext<'a>, fragment_definition: &'a Positioned, ) { - self.current_scope = Some(Scope::Fragment(fragment_definition.name.as_str())); + self.current_scope = Some(Scope::Fragment(fragment_definition.name.node)); } fn enter_variable_definition( @@ -98,7 +98,7 @@ impl<'a> Visitor<'a> for NoUnusedVariables<'a> { if let Some(Scope::Operation(ref name)) = self.current_scope { if let Some(vars) = self.defined_variables.get_mut(name) { vars.insert(( - variable_definition.name.as_str(), + variable_definition.name.node, variable_definition.position(), )); } @@ -108,7 +108,7 @@ impl<'a> Visitor<'a> for NoUnusedVariables<'a> { fn enter_argument( &mut self, _ctx: &mut VisitorContext<'a>, - _name: &'a Positioned, + _name: &'a Positioned<&str>, value: &'a Positioned, ) { if let Some(ref scope) = self.current_scope { @@ -128,7 +128,7 @@ impl<'a> Visitor<'a> for NoUnusedVariables<'a> { self.spreads .entry(scope.clone()) .or_insert_with(Vec::new) - .push(fragment_spread.fragment_name.as_str()); + .push(fragment_spread.fragment_name.node); } } } diff --git a/src/validation/rules/overlapping_fields_can_be_merged.rs b/src/validation/rules/overlapping_fields_can_be_merged.rs index 73a5fa7a..20602264 100644 --- a/src/validation/rules/overlapping_fields_can_be_merged.rs +++ b/src/validation/rules/overlapping_fields_can_be_merged.rs @@ -33,8 +33,8 @@ impl<'a, 'ctx> FindConflicts<'a, 'ctx> { let output_name = field .alias .as_ref() - .map(|name| name.as_str()) - .unwrap_or_else(|| field.name.as_str()); + .map(|name| name.node) + .unwrap_or_else(|| field.name.node); self.add_output(output_name, field); } Selection::InlineFragment(inline_fragment) => { @@ -66,7 +66,7 @@ impl<'a, 'ctx> FindConflicts<'a, 'ctx> { } for (name, value) in &prev_field.arguments { - match field.get_argument(name.as_str()) + match field.get_argument(name.node) { Some(other_value) if value == other_value => {} _=> self.ctx.report_error( diff --git a/src/validation/rules/possible_fragment_spreads.rs b/src/validation/rules/possible_fragment_spreads.rs index 80ebd418..4e6aaf27 100644 --- a/src/validation/rules/possible_fragment_spreads.rs +++ b/src/validation/rules/possible_fragment_spreads.rs @@ -10,11 +10,10 @@ pub struct PossibleFragmentSpreads<'a> { impl<'a> Visitor<'a> for PossibleFragmentSpreads<'a> { fn enter_document(&mut self, _ctx: &mut VisitorContext<'a>, doc: &'a Document) { - for d in &doc.definitions { + for d in doc.definitions() { if let Definition::Fragment(fragment) = &d.node { let TypeCondition::On(type_name) = &fragment.type_condition.node; - self.fragment_types - .insert(fragment.name.as_str(), type_name); + self.fragment_types.insert(fragment.name.node, type_name); } } } @@ -24,10 +23,7 @@ impl<'a> Visitor<'a> for PossibleFragmentSpreads<'a> { ctx: &mut VisitorContext<'a>, fragment_spread: &'a Positioned, ) { - if let Some(fragment_type) = self - .fragment_types - .get(fragment_spread.fragment_name.as_str()) - { + if let Some(fragment_type) = self.fragment_types.get(fragment_spread.fragment_name.node) { if let Some(current_type) = ctx.current_type() { if let Some(on_type) = ctx.registry.types.get(*fragment_type) { if !current_type.type_overlap(on_type) { @@ -53,7 +49,7 @@ impl<'a> Visitor<'a> for PossibleFragmentSpreads<'a> { if let Some(TypeCondition::On(fragment_type)) = &inline_fragment.type_condition.as_ref().map(|c| &c.node) { - if let Some(on_type) = ctx.registry.types.get(fragment_type.as_str()) { + if let Some(on_type) = ctx.registry.types.get(fragment_type.node) { if !parent_type.type_overlap(&on_type) { ctx.report_error( vec![inline_fragment.position()], diff --git a/src/validation/rules/provided_non_null_arguments.rs b/src/validation/rules/provided_non_null_arguments.rs index c27585dc..800c7619 100644 --- a/src/validation/rules/provided_non_null_arguments.rs +++ b/src/validation/rules/provided_non_null_arguments.rs @@ -12,14 +12,14 @@ impl<'a> Visitor<'a> for ProvidedNonNullArguments { ctx: &mut VisitorContext<'a>, directive: &'a Positioned, ) { - if let Some(schema_directive) = ctx.registry.directives.get(directive.name.as_str()) { + if let Some(schema_directive) = ctx.registry.directives.get(directive.name.node) { for arg in schema_directive.args.values() { if TypeName::create(&arg.ty).is_non_null() && arg.default_value.is_none() && directive .arguments .iter() - .find(|(name, _)| name.as_str() == arg.name) + .find(|(name, _)| name.node == arg.name) .is_none() { ctx.report_error(vec![directive.position()], @@ -41,7 +41,7 @@ impl<'a> Visitor<'a> for ProvidedNonNullArguments { && field .arguments .iter() - .find(|(name, _)| name.as_str() == arg.name) + .find(|(name, _)| name.node == arg.name) .is_none() { ctx.report_error(vec![field.position()], diff --git a/src/validation/rules/unique_argument_names.rs b/src/validation/rules/unique_argument_names.rs index 930f544f..637ee896 100644 --- a/src/validation/rules/unique_argument_names.rs +++ b/src/validation/rules/unique_argument_names.rs @@ -20,7 +20,7 @@ impl<'a> Visitor<'a> for UniqueArgumentNames<'a> { fn enter_argument( &mut self, ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, _value: &'a Positioned, ) { if !self.names.insert(name) { diff --git a/src/validation/rules/unique_operation_names.rs b/src/validation/rules/unique_operation_names.rs index 78dd2116..63f214cb 100644 --- a/src/validation/rules/unique_operation_names.rs +++ b/src/validation/rules/unique_operation_names.rs @@ -31,7 +31,7 @@ impl<'a> Visitor<'a> for UniqueOperationNames<'a> { }; if let Some(name) = name { - if !self.names.insert(name.as_str()) { + if !self.names.insert(name.node) { ctx.report_error( vec![name.position()], format!("There can only be one operation named \"{}\"", name), diff --git a/src/validation/rules/unique_variable_names.rs b/src/validation/rules/unique_variable_names.rs index 24523652..3fbd6e8f 100644 --- a/src/validation/rules/unique_variable_names.rs +++ b/src/validation/rules/unique_variable_names.rs @@ -22,7 +22,7 @@ impl<'a> Visitor<'a> for UniqueVariableNames<'a> { ctx: &mut VisitorContext<'a>, variable_definition: &'a Positioned, ) { - if !self.names.insert(variable_definition.name.as_str()) { + if !self.names.insert(variable_definition.name.node) { ctx.report_error( vec![variable_definition.position()], format!( diff --git a/src/validation/rules/variables_in_allowed_position.rs b/src/validation/rules/variables_in_allowed_position.rs index 89a46423..18bdf896 100644 --- a/src/validation/rules/variables_in_allowed_position.rs +++ b/src/validation/rules/variables_in_allowed_position.rs @@ -31,7 +31,7 @@ impl<'a> VariableInAllowedPosition<'a> { if let Some(usages) = self.variable_usages.get(from) { for (var_name, usage_pos, var_type) in usages { - if let Some(def) = var_defs.iter().find(|def| def.name.as_str() == *var_name) { + if let Some(def) = var_defs.iter().find(|def| def.name.node == *var_name) { let expected_type = match (&def.default_value, &def.var_type.node) { (Some(_), Type::List(_)) => def.var_type.to_string() + "!", (Some(_), Type::Named(_)) => def.var_type.to_string() + "!", @@ -80,7 +80,7 @@ impl<'a> Visitor<'a> for VariableInAllowedPosition<'a> { _ctx: &mut VisitorContext<'a>, fragment_definition: &'a Positioned, ) { - self.current_scope = Some(Scope::Fragment(fragment_definition.name.as_str())); + self.current_scope = Some(Scope::Fragment(fragment_definition.name.node)); } fn enter_variable_definition( @@ -105,7 +105,7 @@ impl<'a> Visitor<'a> for VariableInAllowedPosition<'a> { self.spreads .entry(scope.clone()) .or_insert_with(HashSet::new) - .insert(fragment_spread.fragment_name.as_str()); + .insert(fragment_spread.fragment_name.node); } } @@ -122,7 +122,7 @@ impl<'a> Visitor<'a> for VariableInAllowedPosition<'a> { self.variable_usages .entry(scope.clone()) .or_insert_with(Vec::new) - .push((name.as_str(), pos, *expected_type)); + .push((name, pos, *expected_type)); } } } diff --git a/src/validation/utils.rs b/src/validation/utils.rs index db02dd00..c8925e9c 100644 --- a/src/validation/utils.rs +++ b/src/validation/utils.rs @@ -22,7 +22,7 @@ pub fn referenced_variables(value: &Value) -> Vec<&str> { fn referenced_variables_to_vec<'a>(value: &'a Value, vars: &mut Vec<&'a str>) { match value { Value::Variable(name) => { - vars.push(name.as_str()); + vars.push(name); } Value::List(values) => values .iter() @@ -38,14 +38,13 @@ pub fn operation_name(operation_definition: &OperationDefinition) -> (Option<&st match operation_definition { OperationDefinition::SelectionSet(selection_set) => (None, selection_set.position()), OperationDefinition::Query(query) => { - (query.name.as_ref().map(|n| n.as_str()), query.position()) + (query.name.as_ref().map(|n| n.node), query.position()) + } + OperationDefinition::Mutation(mutation) => { + (mutation.name.as_ref().map(|n| n.node), mutation.position()) } - OperationDefinition::Mutation(mutation) => ( - mutation.name.as_ref().map(|n| n.as_str()), - mutation.position(), - ), OperationDefinition::Subscription(subscription) => ( - subscription.name.as_ref().map(|n| n.as_str()), + subscription.name.as_ref().map(|n| n.node), subscription.position(), ), } @@ -107,7 +106,7 @@ pub fn is_valid_input_value( } registry::Type::Enum { enum_values, .. } => match value { Value::Enum(name) => { - if !enum_values.contains_key(name.as_str()) { + if !enum_values.contains_key(name) { Some(valid_error( &path_node, format!( @@ -129,7 +128,7 @@ pub fn is_valid_input_value( Value::Object(values) => { let mut input_names = values .keys() - .map(|name| name.as_str()) + .map(|name| name.as_ref()) .collect::>(); for field in input_fields.values() { diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs index 0d8717ee..5ba1f71d 100644 --- a/src/validation/visitor.rs +++ b/src/validation/visitor.rs @@ -23,10 +23,10 @@ impl<'a> VisitorContext<'a> { type_stack: Default::default(), input_type: Default::default(), fragments: doc - .definitions + .definitions() .iter() .filter_map(|d| match &d.node { - Definition::Fragment(fragment) => Some((fragment.name.as_str(), fragment)), + Definition::Fragment(fragment) => Some((fragment.name.node, fragment)), _ => None, }) .collect(), @@ -147,14 +147,14 @@ pub trait Visitor<'a> { fn enter_argument( &mut self, _ctx: &mut VisitorContext<'a>, - _name: &'a Positioned, + _name: &'a Positioned<&str>, _value: &'a Positioned, ) { } fn exit_argument( &mut self, _ctx: &mut VisitorContext<'a>, - _name: &'a Positioned, + _name: &'a Positioned<&str>, _value: &'a Positioned, ) { } @@ -340,7 +340,7 @@ where fn enter_argument( &mut self, ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, value: &'a Positioned, ) { self.0.enter_argument(ctx, name, value); @@ -350,7 +350,7 @@ where fn exit_argument( &mut self, ctx: &mut VisitorContext<'a>, - name: &'a Positioned, + name: &'a Positioned<&str>, value: &'a Positioned, ) { self.0.exit_argument(ctx, name, value); @@ -451,14 +451,14 @@ fn visit_definitions<'a, V: Visitor<'a>>( ctx: &mut VisitorContext<'a>, doc: &'a Document, ) { - for d in &doc.definitions { + for d in doc.definitions() { match &d.node { Definition::Operation(operation) => { visit_operation_definition(v, ctx, operation); } Definition::Fragment(fragment) => { let TypeCondition::On(name) = &fragment.type_condition.node; - ctx.with_type(ctx.registry.types.get(name.as_str()), |ctx| { + ctx.with_type(ctx.registry.types.get(name.node), |ctx| { visit_fragment_definition(v, ctx, fragment) }); } @@ -539,7 +539,7 @@ fn visit_selection<'a, V: Visitor<'a>>( v.enter_selection(ctx, selection); match &selection.node { Selection::Field(field) => { - if field.name.as_str() != "__typename" { + if field.name.node != "__typename" { ctx.with_type( ctx.current_type() .and_then(|ty| ty.field_by_name(&field.name)) @@ -559,7 +559,7 @@ fn visit_selection<'a, V: Visitor<'a>>( if let Some(TypeCondition::On(name)) = &inline_fragment.type_condition.as_ref().map(|c| &c.node) { - ctx.with_type(ctx.registry.types.get(name.as_str()), |ctx| { + ctx.with_type(ctx.registry.types.get(name.node), |ctx| { visit_inline_fragment(v, ctx, inline_fragment) }); } @@ -580,7 +580,7 @@ fn visit_field<'a, V: Visitor<'a>>( let expected_ty = ctx .parent_type() .and_then(|ty| ty.field_by_name(&field.name)) - .and_then(|schema_field| schema_field.args.get(name.as_str())) + .and_then(|schema_field| schema_field.args.get(name.node)) .map(|input_ty| TypeName::create(&input_ty.ty)); ctx.with_input_type(expected_ty, |ctx| { visit_input_value(v, ctx, field.position(), expected_ty, value) @@ -624,7 +624,7 @@ fn visit_input_value<'a, V: Visitor<'a>>( { if let Type::InputObject { input_fields, .. } = ty { for (item_key, item_value) in values { - if let Some(input_value) = input_fields.get(item_key) { + if let Some(input_value) = input_fields.get(item_key.as_ref()) { visit_input_value( v, ctx, @@ -664,12 +664,12 @@ fn visit_directives<'a, V: Visitor<'a>>( for d in directives { v.enter_directive(ctx, d); - let schema_directive = ctx.registry.directives.get(d.name.as_str()); + let schema_directive = ctx.registry.directives.get(d.name.node); for (name, value) in &d.arguments { v.enter_argument(ctx, name, value); let expected_ty = schema_directive - .and_then(|schema_directive| schema_directive.args.get(name.as_str())) + .and_then(|schema_directive| schema_directive.args.get(name.node)) .map(|input_ty| TypeName::create(&input_ty.ty)); ctx.with_input_type(expected_ty, |ctx| { visit_input_value(v, ctx, d.position(), expected_ty, value)