diff --git a/async-graphql-derive/src/args.rs b/async-graphql-derive/src/args.rs index 8bf1cdb5..27c0893b 100644 --- a/async-graphql-derive/src/args.rs +++ b/async-graphql-derive/src/args.rs @@ -1,6 +1,6 @@ use crate::utils::parse_value; use graphql_parser::query::Value; -use syn::{Attribute, AttributeArgs, Error, Meta, NestedMeta, Result}; +use syn::{Attribute, AttributeArgs, Error, Meta, MetaList, NestedMeta, Result, Type}; #[derive(Debug)] pub struct Object { @@ -110,7 +110,7 @@ impl Argument { } else { return Err(Error::new_spanned( &nv.lit, - "Attribute 'deprecation' should be a string.", + "Attribute 'default' should be a string.", )); } } @@ -451,3 +451,257 @@ impl InputObject { }) } } + +#[derive(Debug)] +pub struct InterfaceFieldArgument { + pub name: String, + pub desc: Option, + pub ty: Type, + pub default: Option, +} + +impl InterfaceFieldArgument { + pub fn parse(ls: &MetaList) -> Result { + let mut name = None; + let mut desc = None; + let mut ty = None; + let mut default = None; + + for meta in &ls.nested { + match meta { + NestedMeta::Meta(Meta::NameValue(nv)) => { + if nv.path.is_ident("name") { + if let syn::Lit::Str(lit) = &nv.lit { + name = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'name' should be a string.", + )); + } + } else if nv.path.is_ident("desc") { + if let syn::Lit::Str(lit) = &nv.lit { + desc = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'desc' should be a string.", + )); + } + } else if nv.path.is_ident("type") { + if let syn::Lit::Str(lit) = &nv.lit { + if let Ok(ty2) = syn::parse_str::(&lit.value()) { + ty = Some(ty2); + } else { + return Err(Error::new_spanned(&lit, "Expect type")); + } + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'type' should be a string.", + )); + } + } else if nv.path.is_ident("default") { + if let syn::Lit::Str(lit) = &nv.lit { + match parse_value(&lit.value()) { + Ok(Value::Variable(_)) => { + return Err(Error::new_spanned( + &nv.lit, + "The default cannot be a variable", + )) + } + Ok(value) => default = Some(value), + Err(err) => { + return Err(Error::new_spanned( + &nv.lit, + format!("Invalid value: {}", err), + )); + } + } + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'default' should be a string.", + )); + } + } + } + _ => {} + } + } + + if name.is_none() { + return Err(Error::new_spanned(ls, "Missing name")); + } + + if ty.is_none() { + return Err(Error::new_spanned(ls, "Missing type")); + } + + Ok(Self { + name: name.unwrap(), + desc, + ty: ty.unwrap(), + default, + }) + } +} + +#[derive(Debug)] +pub struct InterfaceField { + pub name: String, + pub method: Option, + pub desc: Option, + pub ty: Type, + pub args: Vec, + pub deprecation: Option, +} + +impl InterfaceField { + pub fn parse(ls: &MetaList) -> Result { + let mut name = None; + let mut method = None; + let mut desc = None; + let mut ty = None; + let mut args = Vec::new(); + let mut deprecation = None; + + for meta in &ls.nested { + match meta { + NestedMeta::Meta(Meta::NameValue(nv)) => { + if nv.path.is_ident("name") { + if let syn::Lit::Str(lit) = &nv.lit { + name = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'name' should be a string.", + )); + } + } else if nv.path.is_ident("desc") { + if let syn::Lit::Str(lit) = &nv.lit { + desc = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'desc' should be a string.", + )); + } + } else if nv.path.is_ident("method") { + if let syn::Lit::Str(lit) = &nv.lit { + method = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'method' should be a string.", + )); + } + } else if nv.path.is_ident("type") { + if let syn::Lit::Str(lit) = &nv.lit { + if let Ok(ty2) = syn::parse_str::(&lit.value()) { + ty = Some(ty2); + } else { + return Err(Error::new_spanned(&lit, "Expect type")); + } + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'type' should be a string.", + )); + } + } else if nv.path.is_ident("deprecation") { + if let syn::Lit::Str(lit) = &nv.lit { + deprecation = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'deprecation' should be a string.", + )); + } + } + } + NestedMeta::Meta(Meta::List(ls)) if ls.path.is_ident("arg") => { + args.push(InterfaceFieldArgument::parse(ls)?); + } + _ => {} + } + } + + if name.is_none() { + return Err(Error::new_spanned(ls, "Missing name")); + } + + if ty.is_none() { + return Err(Error::new_spanned(ls, "Missing type")); + } + + Ok(Self { + name: name.unwrap(), + method, + desc, + ty: ty.unwrap(), + args, + deprecation, + }) + } +} + +#[derive(Debug)] +pub struct Interface { + pub internal: bool, + pub name: Option, + pub desc: Option, + pub fields: Vec, +} + +impl Interface { + pub fn parse(args: AttributeArgs) -> Result { + let mut internal = false; + let mut name = None; + let mut desc = None; + let mut fields = Vec::new(); + + for arg in args { + match arg { + NestedMeta::Meta(Meta::Path(p)) if p.is_ident("internal") => { + internal = true; + } + NestedMeta::Meta(Meta::Path(p)) if p.is_ident("internal") => { + internal = true; + } + NestedMeta::Meta(Meta::NameValue(nv)) => { + if nv.path.is_ident("name") { + if let syn::Lit::Str(lit) = nv.lit { + name = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'name' should be a string.", + )); + } + } else if nv.path.is_ident("desc") { + if let syn::Lit::Str(lit) = nv.lit { + desc = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'desc' should be a string.", + )); + } + } + } + NestedMeta::Meta(Meta::List(ls)) if ls.path.is_ident("field") => { + fields.push(InterfaceField::parse(&ls)?); + } + _ => {} + } + } + + Ok(Self { + internal, + name, + desc, + fields, + }) + } +} diff --git a/async-graphql-derive/src/enum.rs b/async-graphql-derive/src/enum.rs index fbcdb1e8..cc0145a5 100644 --- a/async-graphql-derive/src/enum.rs +++ b/async-graphql-derive/src/enum.rs @@ -105,8 +105,8 @@ pub fn generate(enum_args: &args::Enum, input: &DeriveInput) -> Result) -> #crate_name::Result { - #crate_name::GQLEnum::resolve_enum(self) + async fn resolve(value: &Self, _: &#crate_name::ContextSelectionSet<'_>) -> #crate_name::Result { + #crate_name::GQLEnum::resolve_enum(value) } } }; diff --git a/async-graphql-derive/src/interface.rs b/async-graphql-derive/src/interface.rs new file mode 100644 index 00000000..1e6fc460 --- /dev/null +++ b/async-graphql-derive/src/interface.rs @@ -0,0 +1,234 @@ +use crate::args; +use crate::args::{InterfaceField, InterfaceFieldArgument}; +use crate::output_type::OutputType; +use crate::utils::{build_value_repr, get_crate_name}; +use proc_macro::TokenStream; +use proc_macro2::{Ident, Span}; +use quote::quote; +use syn::{Data, DeriveInput, Error, Fields, Result, Type}; + +// todo: Context params + +pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result { + let crate_name = get_crate_name(interface_args.internal); + let ident = &input.ident; + let generics = &input.generics; + let attrs = &input.attrs; + let vis = &input.vis; + let s = match &input.data { + Data::Struct(s) => s, + _ => return Err(Error::new_spanned(input, "It should be a struct.")), + }; + let fields = match &s.fields { + Fields::Unnamed(fields) => fields, + _ => return Err(Error::new_spanned(input, "All fields must be unnamed.")), + }; + let mut enum_names = Vec::new(); + let mut enum_items = Vec::new(); + let mut type_into_impls = Vec::new(); + let gql_typename = interface_args + .name + .clone() + .unwrap_or_else(|| ident.to_string()); + let desc = interface_args + .desc + .as_ref() + .map(|s| quote! {Some(#s)}) + .unwrap_or_else(|| quote! {None}); + let mut registry_types = Vec::new(); + let mut possible_types = Vec::new(); + + for field in &fields.unnamed { + if let Type::Path(p) = &field.ty { + let enum_name = &p.path.segments.last().unwrap().ident; + enum_names.push(enum_name); + enum_items.push(quote! { #enum_name(#p) }); + type_into_impls.push(quote! { + impl #generics From<#p> for #ident #generics { + fn from(obj: #p) -> Self { + #ident::#enum_name(obj) + } + } + }); + registry_types.push(quote! { + <#p as async_graphql::GQLType>::create_type_info(registry); + registry.add_implements(&<#p as GQLType>::type_name(), #gql_typename); + }); + possible_types.push(quote! { + <#p as async_graphql::GQLType>::type_name().to_string() + }); + } else { + return Err(Error::new_spanned(field, "Invalid type")); + } + } + + let mut methods = Vec::new(); + let mut schema_fields = Vec::new(); + let mut resolvers = Vec::new(); + + for InterfaceField { + name, + method: method_name, + desc, + ty, + args, + deprecation, + } in &interface_args.fields + { + let method_name = Ident::new( + method_name.as_ref().unwrap_or_else(|| &name), + Span::call_site(), + ); + let mut calls = Vec::new(); + let mut use_params = Vec::new(); + let mut decl_params = Vec::new(); + let mut get_params = Vec::new(); + let mut schema_args = Vec::new(); + + for InterfaceFieldArgument { + name, + desc, + ty, + default, + } in args + { + let ident = Ident::new(name, Span::call_site()); + decl_params.push(quote! { #ident: #ty }); + use_params.push(ident.clone()); + + let param_default = match &default { + Some(default) => { + let repr = build_value_repr(&crate_name, &default); + quote! {|| #repr } + } + None => quote! { || #crate_name::Value::Null }, + }; + get_params.push(quote! { + let #ident: #ty = ctx_field.param_value(#name, #param_default)?; + }); + + let desc = desc + .as_ref() + .map(|s| quote! {Some(#s)}) + .unwrap_or_else(|| quote! {None}); + let schema_default = default + .as_ref() + .map(|v| { + let s = v.to_string(); + quote! {Some(#s)} + }) + .unwrap_or_else(|| quote! {None}); + schema_args.push(quote! { + #crate_name::registry::InputValue { + name: #name, + description: #desc, + ty: <#ty as #crate_name::GQLType>::create_type_info(registry), + default_value: #schema_default, + } + }); + } + + for enum_name in &enum_names { + calls.push(quote! { + #ident::#enum_name(obj) => obj.#method_name(#(#use_params),*).await + }); + } + + methods.push(quote! { + async fn #method_name(&self, #(#decl_params),*) -> #ty { + match self { + #(#calls,)* + } + } + }); + + let desc = desc + .as_ref() + .map(|s| quote! {Some(#s)}) + .unwrap_or_else(|| quote! {None}); + let deprecation = deprecation + .as_ref() + .map(|s| quote! {Some(#s)}) + .unwrap_or_else(|| quote! {None}); + + let ty = OutputType::parse(ty)?; + let value_ty = ty.value_type(); + + schema_fields.push(quote! { + #crate_name::registry::Field { + name: #name, + description: #desc, + args: vec![#(#schema_args),*], + ty: <#value_ty as #crate_name::GQLType>::create_type_info(registry), + deprecation: #deprecation, + } + }); + + let resolve_obj = match &ty { + OutputType::Value(_) => quote! { + self.#method_name(#(#use_params),*).await + }, + OutputType::Result(_, _) => { + quote! { + self.#method_name(#(#use_params),*).await. + map_err(|err| err.with_position(field.position))? + } + } + }; + + resolvers.push(quote! { + if field.name.as_str() == #name { + #(#get_params)* + let ctx_obj = ctx.with_item(&field.selection_set); + return #crate_name::GQLOutputValue::resolve(&#resolve_obj, &ctx_obj).await. + map_err(|err| err.with_position(field.position).into()); + } + }); + } + + let expanded = quote! { + #(#attrs)* + #vis enum #ident #generics { #(#enum_items),* } + + #(#type_into_impls)* + + impl #generics #ident #generics { + #(#methods)* + } + + impl #generics #crate_name::GQLType for #ident #generics { + fn type_name() -> Cow<'static, str> { + Cow::Borrowed(#gql_typename) + } + + fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { + registry.create_type::(|registry| { + #(#registry_types)* + + async_graphql::registry::Type::Interface { + name: #gql_typename, + description: #desc, + fields: vec![#(#schema_fields),*], + possible_types: vec![#(#possible_types),*], + } + }) + } + } + + #[#crate_name::async_trait::async_trait] + impl #generics #crate_name::GQLObject for #ident #generics { + async fn resolve_field(&self, ctx: &#crate_name::Context<'_>, field: &#crate_name::graphql_parser::query::Field) -> #crate_name::Result<#crate_name::serde_json::Value> { + use #crate_name::ErrorWithPosition; + + #(#resolvers)* + + anyhow::bail!(#crate_name::QueryError::FieldNotFound { + field_name: field.name.clone(), + object: #gql_typename.to_string(), + } + .with_position(field.position)); + } + } + }; + Ok(expanded.into()) +} diff --git a/async-graphql-derive/src/lib.rs b/async-graphql-derive/src/lib.rs index e7cfde30..c9c52481 100644 --- a/async-graphql-derive/src/lib.rs +++ b/async-graphql-derive/src/lib.rs @@ -3,7 +3,9 @@ extern crate proc_macro; mod args; mod r#enum; mod input_object; +mod interface; mod object; +mod output_type; mod utils; use proc_macro::TokenStream; @@ -51,3 +53,17 @@ pub fn InputObject(args: TokenStream, input: TokenStream) -> TokenStream { Err(err) => err.to_compile_error().into(), } } + +#[proc_macro_attribute] +#[allow(non_snake_case)] +pub fn Interface(args: TokenStream, input: TokenStream) -> TokenStream { + let interface_args = match args::Interface::parse(parse_macro_input!(args as AttributeArgs)) { + Ok(interface_args) => interface_args, + Err(err) => return err.to_compile_error().into(), + }; + let input = parse_macro_input!(input as DeriveInput); + match interface::generate(&interface_args, &input) { + Ok(expanded) => expanded, + Err(err) => err.to_compile_error().into(), + } +} diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index 174d5b3a..307df6c7 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -1,16 +1,9 @@ use crate::args; +use crate::output_type::OutputType; use crate::utils::{build_value_repr, get_crate_name}; use proc_macro::TokenStream; use quote::quote; -use syn::{ - Error, FnArg, GenericArgument, ImplItem, ItemImpl, Pat, PathArguments, Result, ReturnType, - Type, TypeReference, -}; - -enum OutputType<'a> { - Value(&'a Type), - Result(&'a Type, &'a Type), -} +use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeReference}; pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result { let crate_name = get_crate_name(object_args.internal); @@ -58,49 +51,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< .map(|s| quote! {Some(#s)}) .unwrap_or_else(|| quote! {None}); let ty = match &method.sig.output { - ReturnType::Type(_, ty) => { - if let Type::Path(p) = ty.as_ref() { - if p.path.is_ident("Result") { - if let PathArguments::AngleBracketed(args) = - &p.path.segments[0].arguments - { - if args.args.len() == 0 { - return Err(Error::new_spanned( - &method.sig.output, - "Invalid type", - )); - } - let mut res = None; - for arg in &args.args { - if let GenericArgument::Type(value_ty) = arg { - res = Some(OutputType::Result(ty, value_ty)); - break; - } - } - if res.is_none() { - return Err(Error::new_spanned( - &method.sig.output, - "Invalid type", - )); - } - res.unwrap() - } else { - return Err(Error::new_spanned( - &method.sig.output, - "Invalid type", - )); - } - } else { - OutputType::Value(ty) - } - } else if let Type::Reference(_) = ty.as_ref() { - OutputType::Value(ty) - } else { - return Err(Error::new_spanned(&method.sig.output, "Invalid type")); - } - } + ReturnType::Type(_, ty) => OutputType::parse(ty)?, ReturnType::Default => { - return Err(Error::new_spanned(&method.sig.output, "Missing type")); + return Err(Error::new_spanned(&method.sig.output, "Missing type")) } }; @@ -194,7 +147,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< None => quote! { || #crate_name::Value::Null }, }; get_params.push(quote! { - let #ident: #ty = ctx_field.param_value(#name, #default)?; + let #ident: #ty = ctx.param_value(#name, #default)?; }); } @@ -213,7 +166,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); let ctx_field = match arg_ctx { - true => quote! { &ctx_field, }, + true => quote! { &ctx, }, false => quote! {}, }; @@ -233,10 +186,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< resolvers.push(quote! { if field.name.as_str() == #field_name { #(#get_params)* - let ctx_obj = ctx_field.with_item(&field.selection_set); - let value = #resolve_obj.resolve(&ctx_obj).await.map_err(|err| err.with_position(field.position))?; - result.insert(field.alias.clone().unwrap_or_else(|| field.name.clone()), value.into()); - continue; + let ctx_obj = ctx.with_item(&field.selection_set); + return #crate_name::GQLOutputValue::resolve(&#resolve_obj, &ctx_obj).await. + map_err(|err| err.with_position(field.position).into()); } }); @@ -257,49 +209,25 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< registry.create_type::(|registry| #crate_name::registry::Type::Object { name: #gql_typename, description: #desc, - fields: vec![#(#schema_fields),*] + fields: vec![#(#schema_fields),*], }) } } #[#crate_name::async_trait::async_trait] - impl #generics #crate_name::GQLOutputValue for #self_ty { - async fn resolve(&self, ctx: &#crate_name::ContextSelectionSet<'_>) -> #crate_name::Result<#crate_name::serde_json::Value> { + impl#generics #crate_name::GQLObject for #self_ty { + async fn resolve_field(&self, ctx: &#crate_name::Context<'_>, field: &#crate_name::graphql_parser::query::Field) -> #crate_name::Result<#crate_name::serde_json::Value> { use #crate_name::ErrorWithPosition; - if ctx.items.is_empty() { - #crate_name::anyhow::bail!(#crate_name::QueryError::MustHaveSubFields { - object: #gql_typename, - }.with_position(ctx.span.0)); - } + #(#resolvers)* - let mut result = #crate_name::serde_json::Map::::new(); - for field in ctx.fields(&*ctx) { - let field = field?; - let ctx_field = ctx.with_item(field); - if ctx_field.is_skip(&field.directives)? { - continue; - } - if field.name.as_str() == "__typename" { - let name = field.alias.clone().unwrap_or_else(|| field.name.clone()); - result.insert(name, #gql_typename.into()); - continue; - } - if field.name.as_str() == "__schema" { - continue; - } - #(#resolvers)* - #crate_name::anyhow::bail!(#crate_name::QueryError::FieldNotFound { - field_name: field.name.clone(), - object: #gql_typename, - }.with_position(field.position)); + anyhow::bail!(#crate_name::QueryError::FieldNotFound { + field_name: field.name.clone(), + object: #gql_typename.to_string(), } - - Ok(#crate_name::serde_json::Value::Object(result)) + .with_position(field.position)); } } - - impl#generics #crate_name::GQLObject for #self_ty {} }; Ok(expanded.into()) } diff --git a/async-graphql-derive/src/output_type.rs b/async-graphql-derive/src/output_type.rs new file mode 100644 index 00000000..3ca8451b --- /dev/null +++ b/async-graphql-derive/src/output_type.rs @@ -0,0 +1,47 @@ +use syn::{Error, GenericArgument, PathArguments, Result, Type}; + +pub enum OutputType<'a> { + Value(&'a Type), + Result(&'a Type, &'a Type), +} + +impl<'a> OutputType<'a> { + pub fn parse(input: &'a Type) -> Result { + let ty = if let Type::Path(p) = input { + if p.path.segments.last().unwrap().ident == "Result" { + if let PathArguments::AngleBracketed(args) = &p.path.segments[0].arguments { + if args.args.len() == 0 { + return Err(Error::new_spanned(input, "Invalid type")); + } + let mut res = None; + for arg in &args.args { + if let GenericArgument::Type(value_ty) = arg { + res = Some(OutputType::Result(input, value_ty)); + break; + } + } + if res.is_none() { + return Err(Error::new_spanned(input, "Invalid type")); + } + res.unwrap() + } else { + return Err(Error::new_spanned(input, "Invalid type")); + } + } else { + OutputType::Value(input) + } + } else if let Type::Reference(_) = input { + OutputType::Value(input) + } else { + return Err(Error::new_spanned(input, "Invalid type")); + }; + Ok(ty) + } + + pub fn value_type(&self) -> &Type { + match self { + OutputType::Value(ty) => ty, + OutputType::Result(_, ty) => ty, + } + } +} diff --git a/examples/starwars/mod.rs b/examples/starwars/mod.rs index 18428db9..da82d193 100644 --- a/examples/starwars/mod.rs +++ b/examples/starwars/mod.rs @@ -117,14 +117,6 @@ impl StarWars { } } - pub fn hero(&self, episode: Episode) -> usize { - if episode == Episode::EMPIRE { - self.luke - } else { - self.artoo - } - } - pub fn human(&self, id: &str) -> Option { self.human_data.get(id).cloned() } diff --git a/examples/starwars/model.rs b/examples/starwars/model.rs index 212bad56..5bbd4a9c 100644 --- a/examples/starwars/model.rs +++ b/examples/starwars/model.rs @@ -1,4 +1,6 @@ use super::StarWars; +use async_graphql::GQLType; +use std::borrow::Cow; #[async_graphql::Enum(desc = "One of the films in the Star Wars Trilogy")] #[allow(non_camel_case_types)] @@ -31,13 +33,16 @@ impl<'a> Human<'a> { } #[field(desc = "The friends of the human, or an empty list if they have none.")] - async fn friends(&self) -> Vec> { + async fn friends(&self) -> Vec> { self.starwars.chars[self.id] .friends .iter() - .map(|id| Human { - id: *id, - starwars: self.starwars, + .map(|id| { + Human { + id: *id, + starwars: self.starwars, + } + .into() }) .collect() } @@ -74,13 +79,16 @@ impl<'a> Droid<'a> { } #[field(desc = "The friends of the droid, or an empty list if they have none.")] - async fn friends(&self) -> Vec> { + async fn friends(&self) -> Vec> { self.starwars.chars[self.id] .friends .iter() - .map(|id| Droid { - id: *id, - starwars: self.starwars, + .map(|id| { + Droid { + id: *id, + starwars: self.starwars, + } + .into() }) .collect() } @@ -107,10 +115,19 @@ impl QueryRoot { desc = "If omitted, returns the hero of the whole saga. If provided, returns the hero of that particular episode." )] episode: Episode, - ) -> Human<'_> { - Human { - id: self.0.hero(episode), - starwars: &self.0, + ) -> Character<'_> { + if episode == Episode::EMPIRE { + Human { + id: self.0.luke, + starwars: &self.0, + } + .into() + } else { + Droid { + id: self.0.artoo, + starwars: &self.0, + } + .into() } } @@ -130,3 +147,13 @@ impl QueryRoot { }) } } + +//////////////////////////////////////////////////////////////////////////////////////////////// + +#[async_graphql::Interface( + field(name = "id", type = "&str"), + field(name = "name", type = "&str"), + field(name = "friends", type = "Vec>"), + field(name = "appearsIn", method = "appears_in", type = "&[Episode]") +)] +pub struct Character<'a>(Human<'a>, Droid<'a>); diff --git a/src/base.rs b/src/base.rs index e219858b..a90c3d9b 100644 --- a/src/base.rs +++ b/src/base.rs @@ -1,8 +1,7 @@ -use crate::{registry, ContextSelectionSet, Result}; -use graphql_parser::query::Value; +use crate::{registry, Context, ContextSelectionSet, Result}; +use graphql_parser::query::{Field, Value}; use std::borrow::Cow; -#[doc(hidden)] pub trait GQLType { fn type_name() -> Cow<'static, str>; @@ -13,25 +12,24 @@ pub trait GQLType { fn create_type_info(registry: &mut registry::Registry) -> String; } -#[doc(hidden)] pub trait GQLInputValue: GQLType + Sized { fn parse(value: &Value) -> Option; } -#[doc(hidden)] #[async_trait::async_trait] pub trait GQLOutputValue: GQLType { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result; + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result; } -#[doc(hidden)] +#[async_trait::async_trait] pub trait GQLObject: GQLOutputValue { fn is_empty() -> bool { return false; } + + async fn resolve_field(&self, ctx: &Context<'_>, field: &Field) -> Result; } -#[doc(hidden)] pub trait GQLInputObject: GQLInputValue {} pub trait GQLScalar: Sized + Send { @@ -43,28 +41,43 @@ pub trait GQLScalar: Sized + Send { fn to_json(&self) -> Result; } -impl GQLType for T { - fn type_name() -> Cow<'static, str> { - Cow::Borrowed(T::type_name()) - } +#[macro_export] +macro_rules! impl_scalar { + ($ty:ty) => { + impl crate::GQLType for $ty { + fn type_name() -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(<$ty as crate::GQLScalar>::type_name()) + } - fn create_type_info(registry: &mut registry::Registry) -> String { - registry.create_type::(|_| registry::Type::Scalar { - name: T::type_name().to_string(), - description: T::description(), - }) - } -} + fn create_type_info(registry: &mut crate::registry::Registry) -> String { + registry.create_type::<$ty, _>(|_| crate::registry::Type::Scalar { + name: <$ty as crate::GQLScalar>::type_name().to_string(), + description: <$ty>::description(), + }) + } + } -impl GQLInputValue for T { - fn parse(value: &Value) -> Option { - T::parse(value) - } + impl crate::GQLInputValue for $ty { + fn parse(value: &crate::Value) -> Option { + <$ty as crate::GQLScalar>::parse(value) + } + } + + #[async_trait::async_trait] + impl crate::GQLOutputValue for $ty { + async fn resolve( + value: &Self, + _: &crate::ContextSelectionSet<'_>, + ) -> crate::Result { + value.to_json() + } + } + }; } #[async_trait::async_trait] -impl GQLOutputValue for T { - async fn resolve(&self, _: &ContextSelectionSet<'_>) -> Result { - T::to_json(self) +impl GQLOutputValue for T { + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result { + crate::resolver::do_resolve(ctx, value).await } } diff --git a/src/context.rs b/src/context.rs index f8d8807f..d06782f9 100644 --- a/src/context.rs +++ b/src/context.rs @@ -120,7 +120,17 @@ impl<'a, T> Iterator for FieldIter<'a, T> { .into())); } } - Selection::InlineFragment(_) => {} + Selection::InlineFragment(inline_fragment) => { + let skip = match self.ctx.is_skip(&inline_fragment.directives) { + Ok(skip) => skip, + Err(err) => return Some(Err(err)), + }; + if skip { + continue; + } + // todo: check type + self.stack.push(inline_fragment.selection_set.items.iter()); + } } } else { self.stack.pop(); @@ -287,4 +297,8 @@ impl<'a> ContextBase<'a, &'a Field> { } } } + + pub fn result_name(&self) -> String { + self.item.alias.clone().unwrap_or_else(|| self.name.clone()) + } } diff --git a/src/error.rs b/src/error.rs index a7fbb202..aca795f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -22,16 +22,13 @@ pub enum QueryError { }, #[error("Cannot query field \"{field_name}\" on type \"{object}\".")] - FieldNotFound { - field_name: String, - object: &'static str, - }, + FieldNotFound { field_name: String, object: String }, #[error("Unknown operation named \"{name}\"")] UnknownOperationNamed { name: String }, #[error("Type \"{object}\" must have a selection of subfields.")] - MustHaveSubFields { object: &'static str }, + MustHaveSubFields { object: String }, #[error("Schema is not configured for mutations.")] NotConfiguredMutations, diff --git a/src/lib.rs b/src/lib.rs index fbdf6207..a5474d0b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -57,6 +57,7 @@ mod base; mod context; mod error; mod model; +mod resolver; mod scalars; mod schema; mod types; @@ -72,7 +73,7 @@ pub use serde_json; pub mod http; -pub use async_graphql_derive::{Enum, InputObject, Object}; +pub use async_graphql_derive::{Enum, InputObject, Interface, Object}; pub use base::{GQLInputObject, GQLInputValue, GQLObject, GQLOutputValue, GQLScalar, GQLType}; pub use context::{Context, ContextBase, Variables}; pub use error::{ErrorWithPosition, PositionError, QueryError, QueryParseError}; diff --git a/src/model/type.rs b/src/model/type.rs index 0f0c757a..756c9564 100644 --- a/src/model/type.rs +++ b/src/model/type.rs @@ -123,6 +123,23 @@ impl<'a> __Type<'a> { }) .collect(), ) + } else if let TypeDetail::Simple(Type::Interface { fields, .. }) = &self.detail { + Some( + fields + .iter() + .filter(|field| { + if include_deprecated { + true + } else { + field.deprecation.is_none() + } + }) + .map(|field| __Field { + registry: self.registry, + field, + }) + .collect(), + ) } else { None } @@ -130,8 +147,16 @@ impl<'a> __Type<'a> { #[field] async fn interfaces(&self) -> Option>> { - if let TypeDetail::Simple(Type::Object { .. }) = &self.detail { - Some(vec![]) + if let TypeDetail::Simple(Type::Object { name, .. }) = &self.detail { + Some( + self.registry + .implements + .get(*name) + .unwrap_or(&Default::default()) + .iter() + .map(|ty| __Type::new(self.registry, ty)) + .collect(), + ) } else { None } @@ -139,7 +164,16 @@ impl<'a> __Type<'a> { #[field(name = "possibleTypes")] async fn possible_types(&self) -> Option>> { - None + if let TypeDetail::Simple(Type::Interface { possible_types, .. }) = &self.detail { + Some( + possible_types + .iter() + .map(|ty| __Type::new(self.registry, ty)) + .collect(), + ) + } else { + None + } } #[field(name = "enumValues")] diff --git a/src/registry.rs b/src/registry.rs index 1817af85..ce3724a4 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -1,5 +1,5 @@ use crate::{model, GQLType}; -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; pub struct InputValue { pub name: &'static str, @@ -36,7 +36,7 @@ pub enum Type { name: &'static str, description: Option<&'static str>, fields: Vec, - possible_types: Vec, + possible_types: Vec, }, Union { name: &'static str, @@ -66,6 +66,7 @@ pub struct Directive { pub struct Registry { pub types: HashMap, pub directives: Vec, + pub implements: HashMap>, } impl Registry { @@ -88,4 +89,17 @@ impl Registry { pub fn add_directive(&mut self, directive: Directive) { self.directives.push(directive); } + + pub fn add_implements(&mut self, ty: &str, interface: &str) { + self.implements + .entry(ty.to_string()) + .and_modify(|interfaces| { + interfaces.insert(interface.to_string()); + }) + .or_insert({ + let mut interfaces = HashSet::new(); + interfaces.insert(interface.to_string()); + interfaces + }); + } } diff --git a/src/resolver.rs b/src/resolver.rs new file mode 100644 index 00000000..927bf6b3 --- /dev/null +++ b/src/resolver.rs @@ -0,0 +1,83 @@ +use crate::{ContextSelectionSet, ErrorWithPosition, GQLObject, QueryError, Result}; +use graphql_parser::query::Selection; +use std::future::Future; +use std::pin::Pin; + +struct Resolver<'a, T> { + ctx: &'a ContextSelectionSet<'a>, + obj: &'a T, + result: &'a mut serde_json::Map, +} + +impl<'a, T: GQLObject + Send + Sync> Resolver<'a, T> { + pub fn resolve(&'a mut self) -> Pin> + 'a + Send>> { + Box::pin(async move { + if self.ctx.items.is_empty() { + anyhow::bail!(QueryError::MustHaveSubFields { + object: T::type_name().to_string(), + } + .with_position(self.ctx.span.0)); + } + + for selection in &self.ctx.item.items { + match selection { + Selection::Field(field) => { + let ctx_field = self.ctx.with_item(field); + if ctx_field.is_skip(&field.directives)? { + continue; + } + + if field.name.as_str() == "__typename" { + self.result + .insert(ctx_field.result_name(), T::type_name().to_string().into()); + continue; + } + + self.result.insert( + ctx_field.result_name(), + self.obj.resolve_field(&ctx_field, field).await?, + ); + } + Selection::FragmentSpread(fragment_spread) => { + if let Some(fragment) = + self.ctx.fragments.get(&fragment_spread.fragment_name) + { + Resolver { + ctx: &self.ctx.with_item(&fragment.selection_set), + obj: self.obj, + result: self.result, + } + .resolve() + .await?; + } else { + return Err(QueryError::UnknownFragment { + name: fragment_spread.fragment_name.clone(), + } + .into()); + } + } + Selection::InlineFragment(_) => {} + } + } + + Ok(()) + }) + } +} + +pub async fn do_resolve<'a, T: GQLObject + Send + Sync>( + ctx: &'a ContextSelectionSet<'a>, + root: &'a T, +) -> Result { + let mut result = serde_json::Map::::new(); + + Resolver { + ctx, + obj: root, + result: &mut result, + } + .resolve() + .await?; + + Ok(serde_json::Value::Object(result)) +} diff --git a/src/scalars/bool.rs b/src/scalars/bool.rs index 90a73dbe..7d3203ae 100644 --- a/src/scalars/bool.rs +++ b/src/scalars/bool.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; impl GQLScalar for bool { fn type_name() -> &'static str { @@ -20,3 +20,5 @@ impl GQLScalar for bool { Ok((*self).into()) } } + +impl_scalar!(bool); diff --git a/src/scalars/datetime.rs b/src/scalars/datetime.rs index f1e4df96..412a4bd7 100644 --- a/src/scalars/datetime.rs +++ b/src/scalars/datetime.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; use chrono::{DateTime, TimeZone, Utc}; impl GQLScalar for DateTime { @@ -17,3 +17,5 @@ impl GQLScalar for DateTime { Ok(self.to_rfc3339().into()) } } + +impl_scalar!(DateTime); diff --git a/src/scalars/floats.rs b/src/scalars/floats.rs index 249c2d57..b7c3b341 100644 --- a/src/scalars/floats.rs +++ b/src/scalars/floats.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; macro_rules! impl_float_scalars { ($($ty:ty),*) => { @@ -24,6 +24,8 @@ macro_rules! impl_float_scalars { Ok((*self).into()) } } + + impl_scalar!($ty); )* }; } diff --git a/src/scalars/id.rs b/src/scalars/id.rs index f698d1b2..c34982c4 100644 --- a/src/scalars/id.rs +++ b/src/scalars/id.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; use std::ops::{Deref, DerefMut}; #[derive(Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)] @@ -35,3 +35,5 @@ impl GQLScalar for ID { Ok(self.0.clone().into()) } } + +impl_scalar!(ID); diff --git a/src/scalars/integers.rs b/src/scalars/integers.rs index f56c5627..7cfaa8ec 100644 --- a/src/scalars/integers.rs +++ b/src/scalars/integers.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; macro_rules! impl_integer_scalars { ($($ty:ty),*) => { @@ -23,6 +23,8 @@ macro_rules! impl_integer_scalars { Ok((*self).into()) } } + + impl_scalar!($ty); )* }; } diff --git a/src/scalars/string.rs b/src/scalars/string.rs index 0a5217f8..80d02219 100644 --- a/src/scalars/string.rs +++ b/src/scalars/string.rs @@ -1,5 +1,6 @@ -use crate::registry; -use crate::{ContextSelectionSet, GQLOutputValue, GQLScalar, GQLType, Result, Value}; +use crate::{ + impl_scalar, registry, ContextSelectionSet, GQLOutputValue, GQLScalar, GQLType, Result, Value, +}; use std::borrow::Cow; const STRING_DESC:&'static str = "The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text."; @@ -25,6 +26,8 @@ impl GQLScalar for String { } } +impl_scalar!(String); + impl<'a> GQLType for &'a str { fn type_name() -> Cow<'static, str> { Cow::Borrowed("String") @@ -40,7 +43,7 @@ impl<'a> GQLType for &'a str { #[async_trait::async_trait] impl<'a> GQLOutputValue for &'a str { - async fn resolve(&self, _: &ContextSelectionSet<'_>) -> Result { - Ok(self.to_string().into()) + async fn resolve(value: &Self, _: &ContextSelectionSet<'_>) -> Result { + Ok(value.to_string().into()) } } diff --git a/src/scalars/uuid.rs b/src/scalars/uuid.rs index 2b02329f..908d2772 100644 --- a/src/scalars/uuid.rs +++ b/src/scalars/uuid.rs @@ -1,4 +1,4 @@ -use crate::{Result, GQLScalar, Value}; +use crate::{impl_scalar, GQLScalar, Result, Value}; use uuid::Uuid; impl GQLScalar for Uuid { @@ -17,3 +17,5 @@ impl GQLScalar for Uuid { Ok(self.to_string().into()) } } + +impl_scalar!(Uuid); \ No newline at end of file diff --git a/src/schema.rs b/src/schema.rs index b1bd11c9..27a79d58 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -150,7 +150,7 @@ impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> { data: self.data, fragments: &fragments, }; - return self.query.resolve(&ctx).await; + return GQLOutputValue::resolve(self.query, &ctx).await; } } Definition::Operation(OperationDefinition::Query(query)) => { @@ -165,7 +165,7 @@ impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> { data: self.data, fragments: &fragments, }; - return self.query.resolve(&ctx).await; + return GQLOutputValue::resolve(self.query, &ctx).await; } } Definition::Operation(OperationDefinition::Mutation(mutation)) => { @@ -180,7 +180,7 @@ impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> { data: self.data, fragments: &fragments, }; - return self.mutation.resolve(&ctx).await; + return GQLOutputValue::resolve(self.mutation, &ctx).await; } } _ => {} diff --git a/src/types/empty_mutation.rs b/src/types/empty_mutation.rs index 75d68b5a..383a44ee 100644 --- a/src/types/empty_mutation.rs +++ b/src/types/empty_mutation.rs @@ -1,7 +1,5 @@ -use crate::{ - registry, ContextSelectionSet, ErrorWithPosition, GQLObject, GQLOutputValue, GQLType, - QueryError, Result, -}; +use crate::{registry, Context, GQLObject, GQLType, QueryError, Result}; +use graphql_parser::query::Field; use std::borrow::Cow; pub struct GQLEmptyMutation; @@ -21,14 +19,12 @@ impl GQLType for GQLEmptyMutation { } #[async_trait::async_trait] -impl GQLOutputValue for GQLEmptyMutation { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result { - anyhow::bail!(QueryError::NotConfiguredMutations.with_position(ctx.item.span.0)); - } -} - impl GQLObject for GQLEmptyMutation { fn is_empty() -> bool { return true; } + + async fn resolve_field(&self, _ctx: &Context<'_>, _name: &Field) -> Result { + return Err(QueryError::NotConfiguredMutations.into()); + } } diff --git a/src/types/enum.rs b/src/types/enum.rs index 29f7c567..d6e6e809 100644 --- a/src/types/enum.rs +++ b/src/types/enum.rs @@ -1,13 +1,11 @@ use crate::{GQLType, Result}; use graphql_parser::query::Value; -#[doc(hidden)] pub struct GQLEnumItem { pub name: &'static str, pub value: T, } -#[doc(hidden)] #[async_trait::async_trait] pub trait GQLEnum: GQLType + Sized + Eq + Send + Copy + Sized + 'static { fn items() -> &'static [GQLEnumItem]; diff --git a/src/types/list.rs b/src/types/list.rs index 16bc7c82..66b2fb92 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -33,10 +33,10 @@ impl GQLInputValue for Vec { #[async_trait::async_trait] impl GQLOutputValue for Vec { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result { + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result { let mut res = Vec::new(); - for item in self { - res.push(item.resolve(&ctx).await?); + for item in value { + res.push(GQLOutputValue::resolve(item, &ctx).await?); } Ok(res.into()) } @@ -54,10 +54,10 @@ impl GQLType for &[T] { #[async_trait::async_trait] impl GQLOutputValue for &[T] { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result { + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result { let mut res = Vec::new(); - for item in self.iter() { - res.push(item.resolve(&ctx).await?); + for item in value.iter() { + res.push(GQLOutputValue::resolve(item, &ctx).await?); } Ok(res.into()) } @@ -75,10 +75,10 @@ impl GQLType for &Vec { #[async_trait::async_trait] impl GQLOutputValue for &Vec { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result { + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result { let mut res = Vec::new(); - for item in self.iter() { - res.push(item.resolve(&ctx).await?); + for item in value.iter() { + res.push(GQLOutputValue::resolve(item, &ctx).await?); } Ok(res.into()) } diff --git a/src/types/optional.rs b/src/types/optional.rs index 72db02ac..52cd5ce0 100644 --- a/src/types/optional.rs +++ b/src/types/optional.rs @@ -27,9 +27,10 @@ impl GQLInputValue for Option { #[async_trait::async_trait] impl GQLOutputValue for Option { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result where { - if let Some(inner) = self { - inner.resolve(ctx).await + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result where + { + if let Some(inner) = value { + GQLOutputValue::resolve(inner, ctx).await } else { Ok(serde_json::Value::Null) } @@ -53,9 +54,10 @@ impl GQLType for &Option { #[async_trait::async_trait] impl GQLOutputValue for &Option { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result where { - if let Some(inner) = self { - inner.resolve(ctx).await + async fn resolve(value: &Self, ctx: &ContextSelectionSet<'_>) -> Result where + { + if let Some(inner) = value { + GQLOutputValue::resolve(inner, ctx).await } else { Ok(serde_json::Value::Null) } diff --git a/src/types/query_root.rs b/src/types/query_root.rs index 86bebb28..7c995017 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -1,6 +1,8 @@ -use crate::model::__Schema; -use crate::{registry, ContextSelectionSet, GQLOutputValue, GQLType, Result}; -use graphql_parser::query::Selection; +use crate::model::{__Schema, __Type}; +use crate::{ + registry, Context, ErrorWithPosition, GQLObject, GQLOutputValue, GQLType, Result, Value, +}; +use graphql_parser::query::Field; use std::borrow::Cow; pub struct QueryRoot { @@ -20,30 +22,34 @@ impl GQLType for QueryRoot { } #[async_trait::async_trait] -impl GQLOutputValue for QueryRoot { - async fn resolve(&self, ctx: &ContextSelectionSet<'_>) -> Result { - let mut res = self.inner.resolve(ctx).await?; - - if let serde_json::Value::Object(obj) = &mut res { - for item in &ctx.item.items { - if let Selection::Field(field) = item { - if field.name == "__schema" { - let ctx_obj = ctx.with_item(&field.selection_set); - obj.insert( - "__schema".to_string(), - __Schema { - registry: &ctx.registry, - query_type: &self.query_type, - mutation_type: self.mutation_type.as_deref(), - } - .resolve(&ctx_obj) - .await?, - ); - } - } - } +impl GQLObject for QueryRoot { + async fn resolve_field(&self, ctx: &Context<'_>, field: &Field) -> Result { + if field.name.as_str() == "__schema" { + let ctx_obj = ctx.with_item(&field.selection_set); + return GQLOutputValue::resolve( + &__Schema { + registry: &ctx.registry, + query_type: &self.query_type, + mutation_type: self.mutation_type.as_deref(), + }, + &ctx_obj, + ) + .await + .map_err(|err| err.with_position(field.position).into()); + } else if field.name.as_str() == "__type" { + let type_name: String = ctx.param_value("name", || Value::Null)?; + let ctx_obj = ctx.with_item(&field.selection_set); + return GQLOutputValue::resolve( + &ctx.registry + .types + .get(&type_name) + .map(|ty| __Type::new_simple(ctx.registry, ty)), + &ctx_obj, + ) + .await + .map_err(|err| err.with_position(field.position).into()); } - Ok(res) + return self.inner.resolve_field(ctx, field).await; } }