diff --git a/derive/src/args.rs b/derive/src/args.rs index 9beffb47..6b753c22 100644 --- a/derive/src/args.rs +++ b/derive/src/args.rs @@ -143,6 +143,8 @@ pub struct SimpleObjectField { pub guard: Option, #[darling(default)] pub visible: Option, + #[darling(default, multiple)] + pub derived: Vec, } #[derive(FromDeriveInput)] diff --git a/derive/src/complex_object.rs b/derive/src/complex_object.rs index d88d8922..105d2ee8 100644 --- a/derive/src/complex_object.rs +++ b/derive/src/complex_object.rs @@ -1,7 +1,13 @@ use proc_macro::TokenStream; +use proc_macro2::Ident; use quote::quote; +use std::iter::FromIterator; +use std::str::FromStr; use syn::ext::IdentExt; -use syn::{Block, Error, ImplItem, ItemImpl, ReturnType}; +use syn::{ + punctuated::Punctuated, Block, Error, FnArg, ImplItem, ItemImpl, Pat, ReturnType, Token, Type, + TypeReference, +}; use crate::args::{self, ComplexityType, RenameRuleExt, RenameTarget}; use crate::output_type::OutputType; @@ -23,6 +29,86 @@ pub fn generate( let mut resolvers = Vec::new(); let mut schema_fields = Vec::new(); + // Computation of the derivated fields + let mut derived_impls = vec![]; + for item in &mut item_impl.items { + if let ImplItem::Method(method) = item { + let method_args: args::ObjectField = + parse_graphql_attrs(&method.attrs)?.unwrap_or_default(); + + for derived in method_args.derived { + if derived.name.is_some() && derived.into.is_some() { + let base_function_name = &method.sig.ident; + let name = derived.name.unwrap(); + let into = Type::Verbatim( + proc_macro2::TokenStream::from_str(&derived.into.unwrap()).unwrap(), + ); + + let mut new_impl = method.clone(); + new_impl.sig.ident = name; + new_impl.sig.output = + syn::parse2::(quote! { -> #crate_name::Result<#into> }) + .expect("invalid result type"); + + let should_create_context = new_impl + .sig + .inputs + .iter() + .nth(1) + .map(|x| { + if let FnArg::Typed(pat) = x { + if let Type::Reference(TypeReference { elem, .. }) = &*pat.ty { + if let Type::Path(path) = elem.as_ref() { + return path.path.segments.last().unwrap().ident + != "Context"; + } + } + }; + true + }) + .unwrap_or(true); + + if should_create_context { + let arg_ctx = syn::parse2::(quote! { ctx: &Context<'_> }) + .expect("invalid arg type"); + new_impl.sig.inputs.insert(1, arg_ctx); + } + + let other_atts: Punctuated = Punctuated::from_iter( + new_impl + .sig + .inputs + .iter() + .filter_map(|x| match x { + FnArg::Typed(pat) => match &*pat.pat { + Pat::Ident(ident) => Some(Ok(ident.ident.clone())), + _ => Some(Err(Error::new_spanned( + &pat, + "Must be a simple argument", + ) + .into())), + }, + FnArg::Receiver(_) => None, + }) + .collect::, Error>>()? + .into_iter(), + ); + + let new_block = quote!({ + { + ::std::result::Result::Ok(#self_ty::#base_function_name(&self, #other_atts).await?.into()) + } + }); + + new_impl.block = syn::parse2::(new_block).expect("invalid block"); + + derived_impls.push(ImplItem::Method(new_impl)); + } + } + } + } + item_impl.items.append(&mut derived_impls); + for item in &mut item_impl.items { if let ImplItem::Method(method) = item { let method_args: args::ObjectField = diff --git a/derive/src/simple_object.rs b/derive/src/simple_object.rs index d0665abc..1750e0b8 100644 --- a/derive/src/simple_object.rs +++ b/derive/src/simple_object.rs @@ -1,14 +1,25 @@ use darling::ast::Data; use proc_macro::TokenStream; use quote::quote; +use std::str::FromStr; use syn::ext::IdentExt; -use syn::Error; +use syn::{Error, Ident, Type}; -use crate::args::{self, RenameRuleExt, RenameTarget}; +use crate::args::{self, RenameRuleExt, RenameTarget, SimpleObjectField}; use crate::utils::{ gen_deprecation, generate_guards, get_crate_name, get_rustdoc, visible_fn, GeneratorResult, }; +struct DerivedFieldMetadata { + ident: Ident, + into: Type, +} + +struct SimpleObjectFieldGenerator<'a> { + field: &'a SimpleObjectField, + derived: Option, +} + pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult { let crate_name = get_crate_name(object_args.internal); let ident = &object_args.ident; @@ -37,15 +48,48 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult = vec![]; + + // Before processing the fields, we generate the derivated fields for field in &s.fields { + processed_fields.push(SimpleObjectFieldGenerator { + field: &field, + derived: None, + }); + + for derived in &field.derived { + if derived.name.is_some() && derived.into.is_some() { + let name = derived.name.clone().unwrap(); + let into = Type::Verbatim( + proc_macro2::TokenStream::from_str(&derived.into.clone().unwrap()).unwrap(), + ); + let derived = DerivedFieldMetadata { ident: name, into }; + + processed_fields.push(SimpleObjectFieldGenerator { + field: &field, + derived: Some(derived), + }) + } + } + } + + for SimpleObjectFieldGenerator { field, derived } in &processed_fields { if field.skip { continue; } - let ident = match &field.ident { + + let base_ident = match &field.ident { Some(ident) => ident, None => return Err(Error::new_spanned(&ident, "All fields must be named.").into()), }; + let is_derived = derived.is_some(); + let ident = if let Some(derived) = derived { + &derived.ident + } else { + base_ident + }; + let field_name = field.name.clone().unwrap_or_else(|| { object_args .rename_fields @@ -65,7 +109,11 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult quote! { ::std::option::Option::None }, }; let vis = &field.vis; - let ty = &field.ty; + let ty = if let Some(derived) = derived { + &derived.into + } else { + &field.ty + }; let cache_control = { let public = field.cache_control.is_public(); @@ -104,23 +152,38 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult quote! { + &self.#base_ident + }, + false => quote! { + ::std::clone::Clone::clone(&self.#base_ident) + }, + }; + + block = match is_derived { + false => quote! { + #block + }, + true => quote! { + ::std::convert::Into::into(#block) + }, + }; + + let ty = match !field.owned { + true => quote! { &#ty }, + false => quote! { #ty }, + }; + + getters.push( quote! { #[inline] #[allow(missing_docs)] - #vis async fn #ident(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::Result<&#ty> { - ::std::result::Result::Ok(&self.#ident) + #vis async fn #ident(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::Result<#ty> { + ::std::result::Result::Ok(#block) } } - } else { - quote! { - #[inline] - #[allow(missing_docs)] - #vis async fn #ident(&self, ctx: &#crate_name::Context<'_>) -> #crate_name::Result<#ty> { - ::std::result::Result::Ok(::std::clone::Clone::clone(&self.#ident)) - } - } - }); + ); resolvers.push(quote! { if ctx.item.node.name.node == #field_name { diff --git a/src/lib.rs b/src/lib.rs index 9b83da75..123672b0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -451,6 +451,7 @@ pub use async_graphql_derive::Object; /// | name | Field name | string | Y | /// | deprecation | Field deprecated | bool | Y | /// | deprecation | Field deprecation reason | string | Y | +/// | derived | Generate derived fields *[See also the Book](https://async-graphql.github.io/async-graphql/en/derived_fields.html).* | object | Y | /// | owned | Field resolver return a ownedship value | bool | Y | /// | cache_control | Field cache control | [`CacheControl`](struct.CacheControl.html) | Y | /// | external | Mark a field as owned by another service. This allows service A to use fields from service B while also knowing at runtime the types of that field. | bool | Y | @@ -460,6 +461,14 @@ pub use async_graphql_derive::Object; /// | visible | If `false`, it will not be displayed in introspection. *[See also the Book](https://async-graphql.github.io/async-graphql/en/visibility.html).* | bool | Y | /// | visible | Call the specified function. If the return value is `false`, it will not be displayed in introspection. | string | Y | /// +/// # Derived argument parameters +/// +/// | Attribute | description | Type | Optional | +/// |--------------|------------------------------------------|------------ |----------| +/// | name | Generated derived field name | string | N | +/// | into | Type to derived an into | string | Y | +/// +/// /// # Examples /// /// ```rust @@ -506,6 +515,7 @@ pub use async_graphql_derive::SimpleObject; /// | name | Field name | string | Y | /// | deprecation | Field deprecated | bool | Y | /// | deprecation | Field deprecation reason | string | Y | +/// | derived | Generate derived fields *[See also the Book](https://async-graphql.github.io/async-graphql/en/derived_fields.html).* | object | Y | /// | cache_control | Field cache control | [`CacheControl`](struct.CacheControl.html) | Y | /// | external | Mark a field as owned by another service. This allows service A to use fields from service B while also knowing at runtime the types of that field. | bool | Y | /// | provides | Annotate the expected returned fieldset from a field on a base type that is guaranteed to be selectable by the gateway. | string | Y | @@ -515,6 +525,13 @@ pub use async_graphql_derive::SimpleObject; /// | visible | Call the specified function. If the return value is `false`, it will not be displayed in introspection. | string | Y | /// | secret | Mark this field as a secret, it will not output the actual value in the log. | bool | Y | /// +/// # Derived argument parameters +/// +/// | Attribute | description | Type | Optional | +/// |--------------|------------------------------------------|------------ |----------| +/// | name | Generated derived field name | string | N | +/// | into | Type to derived an into | string | Y | +/// /// # Examples /// /// ```rust diff --git a/tests/derived_field.rs b/tests/derived_field.rs index ee1860e6..304802e1 100644 --- a/tests/derived_field.rs +++ b/tests/derived_field.rs @@ -1,7 +1,7 @@ use async_graphql::*; #[tokio::test] -pub async fn test_derived_field() { +pub async fn test_derived_field_object() { use serde::{Deserialize, Serialize}; struct Query; @@ -45,3 +45,121 @@ pub async fn test_derived_field() { }) ); } + +#[tokio::test] +pub async fn test_derived_field_simple_object() { + use serde::{Deserialize, Serialize}; + + struct Query; + + #[derive(Serialize, Deserialize)] + struct ValueDerived(String); + + scalar!(ValueDerived); + + impl From for ValueDerived { + fn from(value: i32) -> Self { + ValueDerived(format!("{}", value)) + } + } + + #[derive(SimpleObject)] + struct TestObj { + #[graphql(owned, derived(name = "value2", into = "ValueDerived"))] + pub value1: i32, + } + + #[Object] + impl Query { + async fn test(&self, #[graphql(default = 100)] input: i32) -> TestObj { + TestObj { value1: input } + } + } + + let query = "{ test { value1 value2 } }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema.execute(query).await.data, + value!({ + "test": { + "value1": 100, + "value2": "100", + } + }) + ); + + let query = "{ test(input: 2) { value1 value2 }}"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + dbg!(schema.execute(query).await); + assert_eq!( + schema.execute(query).await.data, + value!({ + "test": { + "value1": 2, + "value2": "2", + } + }) + ); +} + +#[tokio::test] +pub async fn test_derived_field_complex_object() { + use serde::{Deserialize, Serialize}; + + #[derive(SimpleObject)] + #[graphql(complex)] + struct MyObj { + a: i32, + #[graphql(owned, derived(name = "f", into = "ValueDerived"))] + b: i32, + } + + #[derive(Serialize, Deserialize)] + struct ValueDerived(String); + + scalar!(ValueDerived); + + impl From for ValueDerived { + fn from(value: i32) -> Self { + ValueDerived(format!("{}", value)) + } + } + + #[ComplexObject] + impl MyObj { + async fn c(&self) -> i32 { + self.a + self.b + } + + #[graphql(derived(name = "e", into = "ValueDerived"))] + async fn d(&self, v: i32) -> i32 { + self.a + self.b + v + } + } + + struct Query; + + #[Object] + impl Query { + async fn obj(&self) -> MyObj { + MyObj { a: 10, b: 20 } + } + } + + let query = "{ obj { a b c d(v:100) e(v: 200) f } }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + dbg!(schema.execute(query).await); + assert_eq!( + schema.execute(query).await.data, + value!({ + "obj": { + "a": 10, + "b": 20, + "c": 30, + "d": 130, + "e": "230", + "f": "20", + }, + }) + ); +}