diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 095c22f3..1e7b013d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,5 +13,7 @@ jobs: - uses: actions/checkout@v1 - name: Build run: cargo build --all --verbose + - name: Check format + run: cargo fmt --all -- --check - name: Run tests run: cargo test --all --verbose diff --git a/async-graphql-derive/src/enum.rs b/async-graphql-derive/src/enum.rs index cc0145a5..8a958630 100644 --- a/async-graphql-derive/src/enum.rs +++ b/async-graphql-derive/src/enum.rs @@ -60,11 +60,11 @@ pub fn generate(enum_args: &args::Enum, input: &DeriveInput) -> Result Result Result }) .unwrap_or_else(|| quote! {None}); schema_args.push(quote! { - #crate_name::registry::InputValue { + args.insert(#name, #crate_name::registry::InputValue { name: #name, description: #desc, ty: <#ty as #crate_name::GQLType>::create_type_info(registry), default_value: #schema_default, - } + }); }); } @@ -174,13 +174,17 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result let schema_ty = ty.value_type(); schema_fields.push(quote! { - #crate_name::registry::Field { + fields.insert(#name, #crate_name::registry::Field { name: #name, description: #desc, - args: vec![#(#schema_args),*], + args: { + let mut args = std::collections::HashMap::new(); + #(#schema_args)* + args + }, ty: <#schema_ty as #crate_name::GQLType>::create_type_info(registry), deprecation: #deprecation, - } + }); }); let resolve_obj = match &ty { @@ -227,7 +231,11 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result #crate_name::registry::Type::Interface { name: #gql_typename, description: #desc, - fields: vec![#(#schema_fields),*], + fields: { + let mut fields = std::collections::HashMap::new(); + #(#schema_fields)* + fields + }, possible_types: vec![#(#possible_types),*], } }) @@ -241,7 +249,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result #(#resolvers)* - anyhow::bail!(#crate_name::QueryError::FieldNotFound { + #crate_name::anyhow::bail!(#crate_name::QueryError::FieldNotFound { field_name: field.name.clone(), object: #gql_typename.to_string(), } @@ -250,7 +258,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result async fn resolve_inline_fragment(&self, name: &str, ctx: &#crate_name::ContextSelectionSet<'_>, result: &mut #crate_name::serde_json::Map) -> #crate_name::Result<()> { #(#inline_fragment_resolvers)* - anyhow::bail!(#crate_name::QueryError::UnrecognizedInlineFragment { + #crate_name::anyhow::bail!(#crate_name::QueryError::UnrecognizedInlineFragment { object: #gql_typename.to_string(), name: name.to_string(), }); diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index 5643bbdc..a4d423fe 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -129,12 +129,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< .unwrap_or_else(|| quote! {None}); schema_args.push(quote! { - #crate_name::registry::InputValue { + args.insert(#name, #crate_name::registry::InputValue { name: #name, description: #desc, ty: <#ty as #crate_name::GQLType>::create_type_info(registry), default_value: #schema_default, - } + }); }); use_params.push(quote! { #ident }); @@ -153,13 +153,17 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< let schema_ty = ty.value_type(); schema_fields.push(quote! { - #crate_name::registry::Field { + fields.insert(#field_name, #crate_name::registry::Field { name: #field_name, description: #field_desc, - args: vec![#(#schema_args),*], + args: { + let mut args = std::collections::HashMap::new(); + #(#schema_args)* + args + }, ty: <#schema_ty as #crate_name::GQLType>::create_type_info(registry), deprecation: #field_deprecation, - } + }); }); let ctx_field = match arg_ctx { @@ -206,7 +210,11 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< registry.create_type::(|registry| #crate_name::registry::Type::Object { name: #gql_typename, description: #desc, - fields: vec![#(#schema_fields),*], + fields: { + let mut fields = std::collections::HashMap::new(); + #(#schema_fields)* + fields + }, }) } } @@ -218,7 +226,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< #(#resolvers)* - anyhow::bail!(#crate_name::QueryError::FieldNotFound { + #crate_name::anyhow::bail!(#crate_name::QueryError::FieldNotFound { field_name: field.name.clone(), object: #gql_typename.to_string(), } @@ -226,7 +234,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< } async fn resolve_inline_fragment(&self, name: &str, ctx: &#crate_name::ContextSelectionSet<'_>, result: &mut #crate_name::serde_json::Map) -> #crate_name::Result<()> { - anyhow::bail!(#crate_name::QueryError::UnrecognizedInlineFragment { + #crate_name::anyhow::bail!(#crate_name::QueryError::UnrecognizedInlineFragment { object: #gql_typename.to_string(), name: name.to_string(), }); diff --git a/src/base.rs b/src/base.rs index 026f4fe7..d30f4485 100644 --- a/src/base.rs +++ b/src/base.rs @@ -41,10 +41,17 @@ pub trait GQLInputObject: GQLInputValue {} pub trait GQLScalar: Sized + Send { fn type_name() -> &'static str; + fn description() -> Option<&'static str> { None } + fn parse(value: &Value) -> Option; + + fn is_valid(value: &Value) -> bool { + Self::parse(value).is_some() + } + fn to_json(&self) -> Result; } @@ -60,6 +67,7 @@ macro_rules! impl_scalar { registry.create_type::<$ty, _>(|_| crate::registry::Type::Scalar { name: <$ty as crate::GQLScalar>::type_name().to_string(), description: <$ty>::description(), + is_valid: |value| <$ty as crate::GQLScalar>::is_valid(value), }) } } @@ -73,6 +81,7 @@ macro_rules! impl_scalar { registry.create_type::<$ty, _>(|_| crate::registry::Type::Scalar { name: <$ty as crate::GQLScalar>::type_name().to_string(), description: <$ty>::description(), + is_valid: |value| <$ty as crate::GQLScalar>::is_valid(value), }) } } diff --git a/src/error.rs b/src/error.rs index 8ef64fc3..681b6e78 100644 --- a/src/error.rs +++ b/src/error.rs @@ -104,3 +104,23 @@ impl Display for PositionError { write!(f, "{}", self.inner) } } + +#[derive(Debug)] +pub struct RuleError { + pub locations: Vec, + pub message: String, +} + +#[derive(Debug, Error)] +pub struct RuleErrors { + pub errors: Vec, +} + +impl Display for RuleErrors { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + for error in &self.errors { + writeln!(f, "{}", error.message)?; + } + Ok(()) + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 24d0a400..bb14c697 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -4,8 +4,10 @@ mod playground_source; pub use graphiql_source::graphiql_source; pub use playground_source::playground_source; +use crate::error::{RuleError, RuleErrors}; use crate::{GQLObject, PositionError, Result, Schema, Variables}; -use serde::ser::SerializeMap; +use graphql_parser::Pos; +use serde::ser::{SerializeMap, SerializeSeq}; use serde::{Serialize, Serializer}; use std::ops::Deref; @@ -57,7 +59,7 @@ impl Serialize for GQLResponse { Err(err) => { let mut map = serializer.serialize_map(None)?; map.serialize_key("errors")?; - map.serialize_value(&[GQLError(err)])?; + map.serialize_value(&GQLError(err))?; map.end() } } @@ -79,25 +81,75 @@ impl<'a> Serialize for GQLError<'a> { where S: Serializer, { - let mut map = serializer.serialize_map(None)?; - - match self.0.downcast_ref::() { - Some(err) => { - map.serialize_key("message")?; - map.serialize_value(&err.to_string())?; - - map.serialize_key("locations")?; - map.serialize_value(&[serde_json::json! ({ - "line": err.position.line, - "column": err.position.column, - })])?; - } - None => { - map.serialize_key("message")?; - map.serialize_value(&self.0.to_string())?; + if let Some(err) = self.0.downcast_ref::() { + let mut seq = serializer.serialize_seq(Some(1))?; + seq.serialize_element(&GQLPositionError(err))?; + seq.end() + } else if let Some(err) = self.0.downcast_ref::() { + let mut seq = serializer.serialize_seq(Some(err.errors.len()))?; + for err in &err.errors { + seq.serialize_element(&GQLRuleError(err))?; } + seq.end() + } else { + let mut seq = serializer.serialize_seq(None)?; + seq.serialize_element(&serde_json::json!({ + "message": self.0.to_string(), + }))?; + seq.end() } + } +} +struct GQLErrorPos<'a>(&'a Pos); + +impl<'a> Serialize for GQLErrorPos<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("line", &self.0.line)?; + map.serialize_entry("column", &self.0.column)?; + map.end() + } +} + +struct GQLPositionError<'a>(&'a PositionError); + +impl<'a> Serialize for GQLPositionError<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(None)?; + map.serialize_entry("message", &self.0.inner.to_string())?; + map.serialize_entry( + "locations", + std::slice::from_ref(&GQLErrorPos(&self.0.position)), + )?; + map.end() + } +} + +struct GQLRuleError<'a>(&'a RuleError); + +impl<'a> Serialize for GQLRuleError<'a> { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + let mut map = serializer.serialize_map(None)?; + map.serialize_entry("message", &self.0.message)?; + map.serialize_entry( + "locations", + &self + .0 + .locations + .iter() + .map(|pos| GQLErrorPos(pos)) + .collect::>(), + )?; map.end() } } diff --git a/src/lib.rs b/src/lib.rs index 0adfe3f4..ac88809f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -61,6 +61,7 @@ mod resolver; mod scalars; mod schema; mod types; +mod validation; #[doc(hidden)] pub use anyhow; diff --git a/src/model/directive.rs b/src/model/directive.rs index aff8ed69..b5c4ee18 100644 --- a/src/model/directive.rs +++ b/src/model/directive.rs @@ -97,7 +97,7 @@ impl<'a> __Directive<'a> { async fn args(&self) -> Vec<__InputValue<'a>> { self.directive .args - .iter() + .values() .map(|input_value| __InputValue { registry: self.registry, input_value, diff --git a/src/model/field.rs b/src/model/field.rs index f9ec4f1d..9aa03565 100644 --- a/src/model/field.rs +++ b/src/model/field.rs @@ -26,7 +26,7 @@ impl<'a> __Field<'a> { async fn args(&self) -> Vec<__InputValue<'a>> { self.field .args - .iter() + .values() .map(|input_value| __InputValue { registry: self.registry, input_value, diff --git a/src/model/schema.rs b/src/model/schema.rs index 11c1ab3d..c4341204 100644 --- a/src/model/schema.rs +++ b/src/model/schema.rs @@ -4,8 +4,6 @@ use async_graphql_derive::Object; pub struct __Schema<'a> { pub registry: &'a registry::Registry, - pub query_type: &'a str, - pub mutation_type: Option<&'a str>, } #[Object( @@ -27,7 +25,10 @@ impl<'a> __Schema<'a> { desc = "The type that query operations will be rooted at." )] async fn query_type(&self) -> __Type<'a> { - __Type::new_simple(self.registry, &self.registry.types[self.query_type]) + __Type::new_simple( + self.registry, + &self.registry.types[&self.registry.query_type], + ) } #[field( @@ -35,7 +36,7 @@ impl<'a> __Schema<'a> { desc = "If this server supports mutation, the type that mutation operations will be rooted at." )] async fn mutation_type(&self) -> Option<__Type<'a>> { - if let Some(ty) = self.mutation_type { + if let Some(ty) = &self.registry.mutation_type { Some(__Type::new_simple(self.registry, &self.registry.types[ty])) } else { None @@ -54,7 +55,7 @@ impl<'a> __Schema<'a> { async fn directives(&self) -> Vec<__Directive<'a>> { self.registry .directives - .iter() + .values() .map(|directive| __Directive { registry: &self.registry, directive, diff --git a/src/model/type.rs b/src/model/type.rs index 0e1e3c88..09786934 100644 --- a/src/model/type.rs +++ b/src/model/type.rs @@ -1,6 +1,6 @@ use crate::model::{__EnumValue, __Field, __InputValue, __TypeKind}; use crate::registry; -use crate::registry::Type; +use crate::registry::{Type, TypeInfo}; use async_graphql_derive::Object; enum TypeDetail<'a> { @@ -23,21 +23,19 @@ impl<'a> __Type<'a> { } pub fn new(registry: &'a registry::Registry, type_name: &str) -> __Type<'a> { - if let Some(type_name) = parse_non_null(type_name) { - __Type { + match TypeInfo::create(type_name) { + TypeInfo::NonNull(ty) => __Type { registry, - detail: TypeDetail::NonNull(type_name.to_string()), - } - } else if let Some(type_name) = parse_list(type_name) { - __Type { + detail: TypeDetail::NonNull(ty.to_string()), + }, + TypeInfo::List(ty) => __Type { registry, - detail: TypeDetail::List(type_name.to_string()), - } - } else { - __Type { + detail: TypeDetail::List(ty.to_string()), + }, + TypeInfo::Type(ty) => __Type { registry, - detail: TypeDetail::Simple(®istry.types[type_name]), - } + detail: TypeDetail::Simple(®istry.types[ty]), + }, } } } @@ -70,14 +68,7 @@ impl<'a> __Type<'a> { #[field] async fn name(&self) -> Option { match &self.detail { - TypeDetail::Simple(ty) => match ty { - registry::Type::Scalar { name, .. } => Some(name.clone()), - registry::Type::Object { name, .. } => Some(name.to_string()), - registry::Type::Interface { name, .. } => Some(name.to_string()), - registry::Type::Union { name, .. } => Some(name.to_string()), - registry::Type::Enum { name, .. } => Some(name.to_string()), - registry::Type::InputObject { name, .. } => Some(name.to_string()), - }, + TypeDetail::Simple(ty) => Some(ty.name().to_string()), TypeDetail::NonNull(_) => None, TypeDetail::List(_) => None, } @@ -106,40 +97,22 @@ impl<'a> __Type<'a> { &self, #[arg(name = "includeDeprecated", default = "false")] include_deprecated: bool, ) -> Option>> { - if let TypeDetail::Simple(Type::Object { fields, .. }) = &self.detail { - Some( - fields - .iter() - .filter(|field| { - if include_deprecated { - true - } else { - field.deprecation.is_none() - } - }) - .map(|field| __Field { - registry: self.registry, - field, - }) - .collect(), - ) - } else if let TypeDetail::Simple(Type::Interface { fields, .. }) = &self.detail { - Some( - fields - .iter() - .filter(|field| { - if include_deprecated { - true - } else { - field.deprecation.is_none() - } - }) - .map(|field| __Field { - registry: self.registry, - field, - }) - .collect(), - ) + if let TypeDetail::Simple(ty) = &self.detail { + ty.fields().and_then(|fields| { + Some( + fields + .values() + .filter(|field| { + (include_deprecated || field.deprecation.is_none()) + && !field.name.starts_with("__") + }) + .map(|field| __Field { + registry: self.registry, + field, + }) + .collect(), + ) + }) } else { None } @@ -191,14 +164,8 @@ impl<'a> __Type<'a> { if let TypeDetail::Simple(Type::Enum { enum_values, .. }) = &self.detail { Some( enum_values - .iter() - .filter(|field| { - if include_deprecated { - true - } else { - field.deprecation.is_none() - } - }) + .values() + .filter(|field| include_deprecated || field.deprecation.is_none()) .map(|value| __EnumValue { registry: self.registry, value, @@ -238,19 +205,3 @@ impl<'a> __Type<'a> { } } } - -fn parse_non_null(type_name: &str) -> Option<&str> { - if type_name.ends_with("!") { - Some(&type_name[..type_name.len() - 1]) - } else { - None - } -} - -fn parse_list(type_name: &str) -> Option<&str> { - if type_name.starts_with("[") { - Some(&type_name[1..type_name.len() - 1]) - } else { - None - } -} diff --git a/src/registry.rs b/src/registry.rs index 8dead1db..c5b24a88 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,6 +1,40 @@ -use crate::{model, GQLType}; +use crate::{model, GQLType, Value}; use std::collections::{HashMap, HashSet}; +fn parse_non_null(type_name: &str) -> Option<&str> { + if type_name.ends_with("!") { + Some(&type_name[..type_name.len() - 1]) + } else { + None + } +} + +fn parse_list(type_name: &str) -> Option<&str> { + if type_name.starts_with("[") { + Some(&type_name[1..type_name.len() - 1]) + } else { + None + } +} + +pub enum TypeInfo<'a> { + List(&'a str), + NonNull(&'a str), + Type(&'a str), +} + +impl<'a> TypeInfo<'a> { + pub fn create(type_name: &str) -> TypeInfo { + if let Some(type_name) = parse_non_null(type_name) { + TypeInfo::NonNull(type_name) + } else if let Some(type_name) = parse_list(type_name) { + TypeInfo::List(type_name) + } else { + TypeInfo::Type(type_name) + } + } +} + pub struct InputValue { pub name: &'static str, pub description: Option<&'static str>, @@ -11,7 +45,7 @@ pub struct InputValue { pub struct Field { pub name: &'static str, pub description: Option<&'static str>, - pub args: Vec, + pub args: HashMap<&'static str, InputValue>, pub ty: String, pub deprecation: Option<&'static str>, } @@ -26,16 +60,17 @@ pub enum Type { Scalar { name: String, description: Option<&'static str>, + is_valid: fn(value: &Value) -> bool, }, Object { name: &'static str, description: Option<&'static str>, - fields: Vec, + fields: HashMap<&'static str, Field>, }, Interface { name: &'static str, description: Option<&'static str>, - fields: Vec, + fields: HashMap<&'static str, Field>, possible_types: Vec, }, Union { @@ -46,7 +81,7 @@ pub enum Type { Enum { name: &'static str, description: Option<&'static str>, - enum_values: Vec, + enum_values: HashMap<&'static str, EnumValue>, }, InputObject { name: &'static str, @@ -55,18 +90,44 @@ pub enum Type { }, } +impl Type { + pub fn field_by_name(&self, name: &str) -> Option<&Field> { + self.fields().and_then(|fields| fields.get(name)) + } + + pub fn fields(&self) -> Option<&HashMap<&'static str, Field>> { + match self { + Type::Object { fields, .. } => Some(&fields), + Type::Interface { fields, .. } => Some(&fields), + _ => None, + } + } + + pub fn name(&self) -> &str { + match self { + Type::Scalar { name, .. } => &name, + Type::Object { name, .. } => name, + Type::Interface { name, .. } => name, + Type::Union { name, .. } => name, + Type::Enum { name, .. } => name, + Type::InputObject { name, .. } => name, + } + } +} + pub struct Directive { pub name: &'static str, pub description: Option<&'static str>, pub locations: Vec, - pub args: Vec, + pub args: HashMap<&'static str, InputValue>, } -#[derive(Default)] pub struct Registry { pub types: HashMap, - pub directives: Vec, + pub directives: HashMap, pub implements: HashMap>, + pub query_type: String, + pub mutation_type: Option, } impl Registry { @@ -75,19 +136,33 @@ impl Registry { if !self.types.contains_key(name.as_ref()) { self.types.insert( name.to_string(), - Type::Scalar { - name: String::new(), + Type::Object { + name: "", description: None, + fields: Default::default(), }, ); - let ty = f(self); + let mut ty = f(self); + if let Type::Object { fields, .. } = &mut ty { + fields.insert( + "__typename", + Field { + name: "", + description: None, + args: Default::default(), + ty: "String!".to_string(), + deprecation: None, + }, + ); + } self.types.insert(name.to_string(), ty); } T::qualified_type_name() } pub fn add_directive(&mut self, directive: Directive) { - self.directives.push(directive); + self.directives + .insert(directive.name.to_string(), directive); } pub fn add_implements(&mut self, ty: &str, interface: &str) { @@ -102,4 +177,12 @@ impl Registry { interfaces }); } + + pub fn get_basic_type(&self, type_name: &str) -> Option<&Type> { + match TypeInfo::create(type_name) { + TypeInfo::Type(type_name) => self.types.get(type_name), + TypeInfo::List(type_name) => self.get_basic_type(type_name), + TypeInfo::NonNull(type_name) => self.get_basic_type(type_name), + } + } } diff --git a/src/scalars/string.rs b/src/scalars/string.rs index 80d02219..ba7faa02 100644 --- a/src/scalars/string.rs +++ b/src/scalars/string.rs @@ -21,6 +21,13 @@ impl GQLScalar for String { } } + fn is_valid(value: &Value) -> bool { + match value { + Value::String(_) => true, + _ => false, + } + } + fn to_json(&self) -> Result { Ok(self.clone().into()) } @@ -37,6 +44,10 @@ impl<'a> GQLType for &'a str { registry.create_type::(|_| registry::Type::Scalar { name: Self::type_name().to_string(), description: Some(STRING_DESC), + is_valid: |value| match value { + Value::String(_) => true, + _ => false, + }, }) } } diff --git a/src/scalars/uuid.rs b/src/scalars/uuid.rs index 908d2772..a55e01a9 100644 --- a/src/scalars/uuid.rs +++ b/src/scalars/uuid.rs @@ -18,4 +18,4 @@ impl GQLScalar for Uuid { } } -impl_scalar!(Uuid); \ No newline at end of file +impl_scalar!(Uuid); diff --git a/src/schema.rs b/src/schema.rs index 6eef2585..63c76c0b 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -2,6 +2,7 @@ use crate::context::Data; use crate::model::__DirectiveLocation; use crate::registry::{Directive, InputValue, Registry}; use crate::types::QueryRoot; +use crate::validation::check_rules; use crate::{ ContextBase, GQLObject, GQLOutputValue, GQLType, QueryError, QueryParseError, Result, Variables, }; @@ -19,7 +20,17 @@ pub struct Schema { impl Schema { pub fn new(query: Query, mutation: Mutation) -> Self { - let mut registry = Registry::default(); + let mut registry = Registry { + types: Default::default(), + directives: Default::default(), + implements: Default::default(), + query_type: Query::type_name().to_string(), + mutation_type: if Mutation::is_empty() { + None + } else { + Some(Mutation::type_name().to_string()) + }, + }; registry.add_directive(Directive { name: "include", @@ -29,12 +40,16 @@ impl Schema { __DirectiveLocation::FRAGMENT_SPREAD, __DirectiveLocation::INLINE_FRAGMENT ], - args: vec![InputValue{ - name: "if", - description: Some("Included when true."), - ty: "Boolean!".to_string(), - default_value: None - }] + args: { + let mut args = HashMap::new(); + args.insert("if", InputValue { + name: "if", + description: Some("Included when true."), + ty: "Boolean!".to_string(), + default_value: None + }); + args + } }); registry.add_directive(Directive { @@ -45,12 +60,16 @@ impl Schema { __DirectiveLocation::FRAGMENT_SPREAD, __DirectiveLocation::INLINE_FRAGMENT ], - args: vec![InputValue{ - name: "if", - description: Some("Skipped when true."), - ty: "Boolean!".to_string(), - default_value: None - }] + args: { + let mut args = HashMap::new(); + args.insert("if", InputValue { + name: "if", + description: Some("Skipped when true."), + ty: "Boolean!".to_string(), + default_value: None + }); + args + } }); // register scalars @@ -59,21 +78,13 @@ impl Schema { f32::create_type_info(&mut registry); String::create_type_info(&mut registry); - Query::create_type_info(&mut registry); + QueryRoot::::create_type_info(&mut registry); if !Mutation::is_empty() { Mutation::create_type_info(&mut registry); } Self { - query: QueryRoot { - inner: query, - query_type: Query::type_name().to_string(), - mutation_type: if !Mutation::is_empty() { - Some(Mutation::type_name().to_string()) - } else { - None - }, - }, + query: QueryRoot { inner: query }, mutation, registry, data: Default::default(), @@ -132,6 +143,8 @@ impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> { parse_query(self.query_source).map_err(|err| QueryParseError(err.to_string()))?; let mut fragments = HashMap::new(); + check_rules(self.registry, &document)?; + for definition in &document.definitions { if let Definition::Fragment(fragment) = definition { fragments.insert(fragment.name.clone(), fragment); diff --git a/src/types/empty_mutation.rs b/src/types/empty_mutation.rs index feb060ed..8c6fb4cd 100644 --- a/src/types/empty_mutation.rs +++ b/src/types/empty_mutation.rs @@ -14,7 +14,7 @@ impl GQLType for GQLEmptyMutation { registry.create_type::(|_| registry::Type::Object { name: "EmptyMutation", description: None, - fields: Vec::new(), + fields: Default::default(), }) } } diff --git a/src/types/query_root.rs b/src/types/query_root.rs index 3dc87f90..589a49ff 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -1,15 +1,15 @@ use crate::model::{__Schema, __Type}; +use crate::registry::Type; use crate::{ registry, Context, ContextSelectionSet, ErrorWithPosition, GQLObject, GQLOutputValue, GQLType, QueryError, Result, Value, }; use graphql_parser::query::Field; use std::borrow::Cow; +use std::collections::HashMap; pub struct QueryRoot { pub inner: T, - pub query_type: String, - pub mutation_type: Option, } impl GQLType for QueryRoot { @@ -18,7 +18,44 @@ impl GQLType for QueryRoot { } fn create_type_info(registry: &mut registry::Registry) -> String { - T::create_type_info(registry) + let schema_type = __Schema::create_type_info(registry); + let root = T::create_type_info(registry); + if let Some(Type::Object { fields, .. }) = registry.types.get_mut(T::type_name().as_ref()) { + fields.insert( + "__schema", + registry::Field { + name: "__schema", + description: Some("Access the current type schema of this server."), + args: Default::default(), + ty: schema_type, + deprecation: None, + }, + ); + + fields.insert( + "__type", + registry::Field { + name: "__type", + description: Some("Request the type information of a single type."), + args: { + let mut args = HashMap::new(); + args.insert( + "name", + registry::InputValue { + name: "name", + description: None, + ty: "String!".to_string(), + default_value: None, + }, + ); + args + }, + ty: "__Type".to_string(), + deprecation: None, + }, + ); + } + root } } @@ -30,8 +67,6 @@ impl GQLObject for QueryRoot { return GQLOutputValue::resolve( &__Schema { registry: &ctx.registry, - query_type: &self.query_type, - mutation_type: self.mutation_type.as_deref(), }, &ctx_obj, ) diff --git a/src/validation/context.rs b/src/validation/context.rs new file mode 100644 index 00000000..365bd5ff --- /dev/null +++ b/src/validation/context.rs @@ -0,0 +1,36 @@ +use crate::error::RuleError; +use crate::registry::{Registry, Type}; +use graphql_parser::Pos; + +pub struct ValidatorContext<'a> { + pub registry: &'a Registry, + pub errors: Vec, + type_stack: Vec<&'a Type>, +} + +impl<'a> ValidatorContext<'a> { + pub fn new(registry: &'a Registry) -> Self { + Self { + registry, + errors: Default::default(), + type_stack: Default::default(), + } + } + + pub fn report_error>(&mut self, locations: Vec, msg: T) { + self.errors.push(RuleError { + locations, + message: msg.into(), + }) + } + + pub fn with_type)>(&mut self, ty: &'a Type, mut f: F) { + self.type_stack.push(ty); + f(self); + self.type_stack.pop(); + } + + pub fn parent_type(&self) -> &'a Type { + self.type_stack.last().unwrap() + } +} diff --git a/src/validation/mod.rs b/src/validation/mod.rs new file mode 100644 index 00000000..d6b40676 --- /dev/null +++ b/src/validation/mod.rs @@ -0,0 +1,23 @@ +use crate::error::RuleErrors; +use crate::registry::Registry; +use crate::validation::context::ValidatorContext; +use crate::validation::visitor::{visit, VisitorNil}; +use crate::Result; +use graphql_parser::query::Document; + +mod context; +mod rules; +mod utils; +mod visitor; + +pub fn check_rules(registry: &Registry, doc: &Document) -> Result<()> { + let mut ctx = ValidatorContext::new(registry); + let mut visitor = VisitorNil.with(rules::ArgumentsOfCorrectType::default()); + + visit(&mut visitor, &mut ctx, doc); + if !ctx.errors.is_empty() { + Err(RuleErrors { errors: ctx.errors }.into()) + } else { + Ok(()) + } +} diff --git a/src/validation/rules/arguments_of_correct_type.rs b/src/validation/rules/arguments_of_correct_type.rs new file mode 100644 index 00000000..71f9d549 --- /dev/null +++ b/src/validation/rules/arguments_of_correct_type.rs @@ -0,0 +1,55 @@ +use crate::registry::InputValue; +use crate::validation::context::ValidatorContext; +use crate::validation::utils::is_valid_input_value; +use crate::validation::visitor::Visitor; +use graphql_parser::query::Field; +use graphql_parser::schema::{Directive, Value}; +use graphql_parser::Pos; +use std::collections::HashMap; + +#[derive(Default)] +pub struct ArgumentsOfCorrectType<'a> { + current_args: Option<(&'a HashMap<&'static str, InputValue>, Pos)>, +} + +impl<'a> Visitor<'a> for ArgumentsOfCorrectType<'a> { + fn enter_directive(&mut self, ctx: &mut ValidatorContext<'a>, directive: &'a Directive) { + self.current_args = ctx + .registry + .directives + .get(&directive.name) + .map(|d| (&d.args, directive.position)); + } + + fn exit_directive(&mut self, _ctx: &mut ValidatorContext<'a>, _directive: &'a Directive) { + self.current_args = None; + } + + fn enter_argument(&mut self, ctx: &mut ValidatorContext<'a>, name: &str, value: &'a Value) { + if let Some((arg, pos)) = self + .current_args + .and_then(|(args, pos)| args.get(name).map(|input| (input, pos))) + { + if !is_valid_input_value(ctx.registry, &arg.ty, value) { + ctx.report_error( + vec![pos], + format!( + "Invalid value for argument \"{}\", expected type \"{}\"", + arg.name, arg.ty, + ), + ); + } + } + } + + fn enter_field(&mut self, ctx: &mut ValidatorContext<'a>, field: &'a Field) { + self.current_args = ctx + .parent_type() + .field_by_name(&field.name) + .map(|f| (&f.args, field.position)); + } + + fn exit_field(&mut self, _ctx: &mut ValidatorContext<'a>, _field: &'a Field) { + self.current_args = None; + } +} diff --git a/src/validation/rules/mod.rs b/src/validation/rules/mod.rs new file mode 100644 index 00000000..7b587371 --- /dev/null +++ b/src/validation/rules/mod.rs @@ -0,0 +1,3 @@ +mod arguments_of_correct_type; + +pub use arguments_of_correct_type::ArgumentsOfCorrectType; diff --git a/src/validation/utils.rs b/src/validation/utils.rs new file mode 100644 index 00000000..c886561d --- /dev/null +++ b/src/validation/utils.rs @@ -0,0 +1,46 @@ +use crate::registry::{Registry, Type, TypeInfo}; +use crate::Value; + +pub fn is_valid_input_value(registry: &Registry, type_name: &str, value: &Value) -> bool { + match TypeInfo::create(type_name) { + TypeInfo::NonNull(type_name) => match value { + Value::Null => false, + _ => is_valid_input_value(registry, type_name, value), + }, + TypeInfo::List(type_name) => match value { + Value::List(elems) => elems + .iter() + .all(|elem| is_valid_input_value(registry, type_name, elem)), + _ => false, + }, + TypeInfo::Type(type_name) => { + if let Some(ty) = registry.types.get(type_name) { + match ty { + Type::Scalar { is_valid, .. } => is_valid(value), + Type::Enum { enum_values, .. } => match value { + Value::Enum(name) => enum_values.contains_key(name.as_str()), + _ => false, + }, + Type::InputObject { input_fields, .. } => match value { + Value::Object(values) => { + for field in input_fields { + if !is_valid_input_value( + registry, + &field.ty, + values.get(field.name).unwrap_or(&Value::Null), + ) { + return false; + } + } + true + } + _ => false, + }, + _ => false, + } + } else { + unreachable!() + } + } + } +} diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs new file mode 100644 index 00000000..667955be --- /dev/null +++ b/src/validation/visitor.rs @@ -0,0 +1,445 @@ +use crate::validation::context::ValidatorContext; +use graphql_parser::query::{ + Definition, Directive, Document, Field, FragmentDefinition, FragmentSpread, InlineFragment, + Name, OperationDefinition, Selection, SelectionSet, TypeCondition, Value, VariableDefinition, +}; + +pub trait Visitor<'a> { + fn enter_document(&mut self, _ctx: &mut ValidatorContext<'a>, _doc: &'a Document) {} + fn exit_document(&mut self, _ctx: &mut ValidatorContext<'a>, _doc: &'a Document) {} + + fn enter_operation_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _operation_definition: &'a OperationDefinition, + ) { + } + fn exit_operation_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _operation_definition: &'a OperationDefinition, + ) { + } + + fn enter_fragment_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _fragment_definition: &'a FragmentDefinition, + ) { + } + fn exit_fragment_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _fragment_definition: &'a FragmentDefinition, + ) { + } + + fn enter_variable_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _variable_definition: &'a VariableDefinition, + ) { + } + fn exit_variable_definition( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _variable_definition: &'a VariableDefinition, + ) { + } + + fn enter_directive(&mut self, _ctx: &mut ValidatorContext<'a>, _directive: &'a Directive) {} + fn exit_directive(&mut self, _ctx: &mut ValidatorContext<'a>, _directive: &'a Directive) {} + + fn enter_argument(&mut self, _ctx: &mut ValidatorContext<'a>, _name: &str, _value: &'a Value) {} + fn exit_argument(&mut self, _ctx: &mut ValidatorContext<'a>, _name: &str, _value: &'a Value) {} + + fn enter_selection(&mut self, _ctx: &mut ValidatorContext<'a>, _selection: &'a Selection) {} + fn exit_selection(&mut self, _ctx: &mut ValidatorContext<'a>, _selection: &'a Selection) {} + + fn enter_field(&mut self, _ctx: &mut ValidatorContext<'a>, _field: &'a Field) {} + fn exit_field(&mut self, _ctx: &mut ValidatorContext<'a>, _field: &'a Field) {} + + fn enter_fragment_spread( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _fragment_spread: &'a FragmentSpread, + ) { + } + fn exit_fragment_spread( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _fragment_spread: &'a FragmentSpread, + ) { + } + + fn enter_inline_fragment( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _inline_fragment: &'a InlineFragment, + ) { + } + fn exit_inline_fragment( + &mut self, + _ctx: &mut ValidatorContext<'a>, + _inline_fragment: &'a InlineFragment, + ) { + } +} + +pub struct VisitorNil; + +impl VisitorNil { + pub fn with(self, visitor: V) -> VisitorCons { + VisitorCons(visitor, self) + } +} + +pub struct VisitorCons(A, B); + +impl VisitorCons { + pub fn with(self, visitor: V) -> VisitorCons { + VisitorCons(visitor, self) + } +} + +impl<'a> Visitor<'a> for VisitorNil {} + +impl<'a, A, B> Visitor<'a> for VisitorCons +where + A: Visitor<'a> + 'a, + B: Visitor<'a> + 'a, +{ + fn enter_document(&mut self, ctx: &mut ValidatorContext<'a>, doc: &'a Document) { + self.0.enter_document(ctx, doc); + self.1.enter_document(ctx, doc); + } + + fn exit_document(&mut self, ctx: &mut ValidatorContext<'a>, doc: &'a Document) { + self.0.exit_document(ctx, doc); + self.1.exit_document(ctx, doc); + } + + fn enter_operation_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + operation_definition: &'a OperationDefinition, + ) { + self.0.enter_operation_definition(ctx, operation_definition); + self.1.enter_operation_definition(ctx, operation_definition); + } + + fn exit_operation_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + operation_definition: &'a OperationDefinition, + ) { + self.0.exit_operation_definition(ctx, operation_definition); + self.1.exit_operation_definition(ctx, operation_definition); + } + + fn enter_fragment_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + fragment_definition: &'a FragmentDefinition, + ) { + self.0.enter_fragment_definition(ctx, fragment_definition); + self.1.enter_fragment_definition(ctx, fragment_definition); + } + + fn exit_fragment_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + fragment_definition: &'a FragmentDefinition, + ) { + self.0.exit_fragment_definition(ctx, fragment_definition); + self.1.exit_fragment_definition(ctx, fragment_definition); + } + + fn enter_variable_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + variable_definition: &'a VariableDefinition, + ) { + self.0.enter_variable_definition(ctx, variable_definition); + self.1.enter_variable_definition(ctx, variable_definition); + } + + fn exit_variable_definition( + &mut self, + ctx: &mut ValidatorContext<'a>, + variable_definition: &'a VariableDefinition, + ) { + self.0.exit_variable_definition(ctx, variable_definition); + self.1.exit_variable_definition(ctx, variable_definition); + } + + fn enter_directive(&mut self, ctx: &mut ValidatorContext<'a>, directive: &'a Directive) { + self.0.enter_directive(ctx, directive); + self.1.enter_directive(ctx, directive); + } + + fn exit_directive(&mut self, ctx: &mut ValidatorContext<'a>, directive: &'a Directive) { + self.0.exit_directive(ctx, directive); + self.1.exit_directive(ctx, directive); + } + + fn enter_argument(&mut self, ctx: &mut ValidatorContext<'a>, name: &str, value: &'a Value) { + self.0.enter_argument(ctx, name, value); + self.1.enter_argument(ctx, name, value); + } + + fn exit_argument(&mut self, ctx: &mut ValidatorContext<'a>, name: &str, value: &'a Value) { + self.0.exit_argument(ctx, name, value); + self.1.exit_argument(ctx, name, value); + } + + fn enter_selection(&mut self, ctx: &mut ValidatorContext<'a>, selection: &'a Selection) { + self.0.enter_selection(ctx, selection); + self.1.enter_selection(ctx, selection); + } + + fn exit_selection(&mut self, ctx: &mut ValidatorContext<'a>, selection: &'a Selection) { + self.0.exit_selection(ctx, selection); + self.1.exit_selection(ctx, selection); + } + + fn enter_field(&mut self, ctx: &mut ValidatorContext<'a>, field: &'a Field) { + self.0.enter_field(ctx, field); + self.1.enter_field(ctx, field); + } + + fn exit_field(&mut self, ctx: &mut ValidatorContext<'a>, field: &'a Field) { + self.0.exit_field(ctx, field); + self.1.exit_field(ctx, field); + } + + fn enter_fragment_spread( + &mut self, + ctx: &mut ValidatorContext<'a>, + fragment_spread: &'a FragmentSpread, + ) { + self.0.enter_fragment_spread(ctx, fragment_spread); + self.1.enter_fragment_spread(ctx, fragment_spread); + } + + fn exit_fragment_spread( + &mut self, + ctx: &mut ValidatorContext<'a>, + fragment_spread: &'a FragmentSpread, + ) { + self.0.exit_fragment_spread(ctx, fragment_spread); + self.1.exit_fragment_spread(ctx, fragment_spread); + } + + fn enter_inline_fragment( + &mut self, + ctx: &mut ValidatorContext<'a>, + inline_fragment: &'a InlineFragment, + ) { + self.0.enter_inline_fragment(ctx, inline_fragment); + self.1.enter_inline_fragment(ctx, inline_fragment); + } + + fn exit_inline_fragment( + &mut self, + ctx: &mut ValidatorContext<'a>, + inline_fragment: &'a InlineFragment, + ) { + self.0.exit_inline_fragment(ctx, inline_fragment); + self.1.exit_inline_fragment(ctx, inline_fragment); + } +} + +pub fn visit<'a, V: Visitor<'a>>(v: &mut V, ctx: &mut ValidatorContext<'a>, doc: &'a Document) { + v.enter_document(ctx, doc); + visit_definitions(v, ctx, doc); + v.exit_document(ctx, doc); +} + +fn visit_definitions<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + doc: &'a Document, +) { + for d in &doc.definitions { + match d { + Definition::Operation(operation) => { + visit_operation_definition(v, ctx, operation); + } + Definition::Fragment(fragment) => { + let TypeCondition::On(name) = &fragment.type_condition; + if let Some(ty) = ctx.registry.types.get(name) { + ctx.with_type(ty, |ctx| visit_fragment_definition(v, ctx, fragment)); + } else { + ctx.report_error( + vec![fragment.position], + format!("Unknown type \"{}\".", name), + ); + } + } + } + } +} + +fn visit_operation_definition<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + operation: &'a OperationDefinition, +) { + v.enter_operation_definition(ctx, operation); + match operation { + OperationDefinition::SelectionSet(selection_set) => { + ctx.with_type(&ctx.registry.types[&ctx.registry.query_type], |ctx| { + visit_selection_set(v, ctx, selection_set) + }); + } + OperationDefinition::Query(query) => { + ctx.with_type(&ctx.registry.types[&ctx.registry.query_type], |ctx| { + visit_variable_definitions(v, ctx, &query.variable_definitions); + visit_directives(v, ctx, &query.directives); + visit_selection_set(v, ctx, &query.selection_set); + }); + } + OperationDefinition::Mutation(mutation) => { + if let Some(mutation_type) = &ctx.registry.mutation_type { + ctx.with_type(&ctx.registry.types[mutation_type], |ctx| { + visit_variable_definitions(v, ctx, &mutation.variable_definitions); + visit_directives(v, ctx, &mutation.directives); + visit_selection_set(v, ctx, &mutation.selection_set); + }); + } else { + ctx.report_error( + vec![mutation.position], + "Schema is not configured for mutations.", + ); + } + } + OperationDefinition::Subscription(subscription) => { + ctx.report_error(vec![subscription.position], "Not supported."); + // visit_variable_definitions(v, ctx, &subscription.variable_definitions); + // visit_directives(v, ctx, &subscription.directives); + // visit_selection_set(v, ctx, &subscription.selection_set); + } + } + v.exit_operation_definition(ctx, operation); +} + +fn visit_selection_set<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + selection_set: &'a SelectionSet, +) { + for selection in &selection_set.items { + visit_selection(v, ctx, selection); + } +} + +fn visit_selection<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + selection: &'a Selection, +) { + v.enter_selection(ctx, selection); + match selection { + Selection::Field(field) => visit_field(v, ctx, field), + Selection::FragmentSpread(fragment_spread) => { + visit_fragment_spread(v, ctx, fragment_spread) + } + Selection::InlineFragment(inline_fragment) => { + visit_inline_fragment(v, ctx, inline_fragment) + } + } + v.exit_selection(ctx, selection); +} + +fn visit_field<'a, V: Visitor<'a>>(v: &mut V, ctx: &mut ValidatorContext<'a>, field: &'a Field) { + v.enter_field(ctx, field); + visit_arguments(v, ctx, &field.arguments); + visit_directives(v, ctx, &field.directives); + + if let Some(schema_field) = ctx.parent_type().field_by_name(&field.name) { + ctx.with_type( + ctx.registry.get_basic_type(&schema_field.ty).unwrap(), + |ctx| { + visit_selection_set(v, ctx, &field.selection_set); + }, + ); + } else { + ctx.report_error( + vec![field.position], + format!( + "Cannot query field \"{}\" on type \"{}\".", + field.name, + ctx.parent_type().name() + ), + ); + } + + v.exit_field(ctx, field); +} + +fn visit_arguments<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + arguments: &'a Vec<(Name, Value)>, +) { + for (name, value) in arguments { + v.enter_argument(ctx, name, value); + v.exit_argument(ctx, name, value); + } +} + +fn visit_variable_definitions<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + variable_definitions: &'a Vec, +) { + for d in variable_definitions { + v.enter_variable_definition(ctx, d); + v.exit_variable_definition(ctx, d); + } +} + +fn visit_directives<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + directives: &'a Vec, +) { + for d in directives { + v.enter_directive(ctx, d); + visit_arguments(v, ctx, &d.arguments); + v.exit_directive(ctx, d); + } +} + +fn visit_fragment_definition<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + fragment: &'a FragmentDefinition, +) { + v.enter_fragment_definition(ctx, fragment); + visit_directives(v, ctx, &fragment.directives); + visit_selection_set(v, ctx, &fragment.selection_set); + v.exit_fragment_definition(ctx, fragment); +} + +fn visit_fragment_spread<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + fragment_spread: &'a FragmentSpread, +) { + v.enter_fragment_spread(ctx, fragment_spread); + visit_directives(v, ctx, &fragment_spread.directives); + v.exit_fragment_spread(ctx, fragment_spread); +} + +fn visit_inline_fragment<'a, V: Visitor<'a>>( + v: &mut V, + ctx: &mut ValidatorContext<'a>, + inline_fragment: &'a InlineFragment, +) { + v.enter_inline_fragment(ctx, inline_fragment); + visit_directives(v, ctx, &inline_fragment.directives); + visit_selection_set(v, ctx, &inline_fragment.selection_set); + v.exit_inline_fragment(ctx, inline_fragment); +}