diff --git a/derive/src/args.rs b/derive/src/args.rs index 6b753c22..e2ba5e86 100644 --- a/derive/src/args.rs +++ b/derive/src/args.rs @@ -255,6 +255,9 @@ pub struct ObjectField { pub struct DerivedField { pub name: Option, pub into: Option, + pub with: Option, + #[darling(default)] + pub owned: Option, } #[derive(FromDeriveInput)] diff --git a/derive/src/complex_object.rs b/derive/src/complex_object.rs index 8d282690..2926518a 100644 --- a/derive/src/complex_object.rs +++ b/derive/src/complex_object.rs @@ -40,6 +40,7 @@ pub fn generate( if derived.name.is_some() && derived.into.is_some() { let base_function_name = &method.sig.ident; let name = derived.name.unwrap(); + let with = derived.with; let into = Type::Verbatim( proc_macro2::TokenStream::from_str(&derived.into.unwrap()).unwrap(), ); @@ -93,11 +94,16 @@ pub fn generate( .into_iter(), ); - let new_block = quote!({ - { - ::std::result::Result::Ok(#self_ty::#base_function_name(&self, #other_atts).await?.into()) - } - }); + let new_block = match with { + Some(with) => quote!({ + ::std::result::Result::Ok(#with(#self_ty::#base_function_name(&self, #other_atts).await?)) + }), + None => quote!({ + { + ::std::result::Result::Ok(#self_ty::#base_function_name(&self, #other_atts).await?.into()) + } + }), + }; new_impl.block = syn::parse2::(new_block).expect("invalid block"); diff --git a/derive/src/object.rs b/derive/src/object.rs index 6cfe0b79..665a9976 100644 --- a/derive/src/object.rs +++ b/derive/src/object.rs @@ -55,6 +55,7 @@ pub fn generate( if derived.name.is_some() && derived.into.is_some() { let base_function_name = &method.sig.ident; let name = derived.name.unwrap(); + let with = derived.with; let into = Type::Verbatim( proc_macro2::TokenStream::from_str(&derived.into.unwrap()).unwrap(), ); @@ -108,11 +109,16 @@ pub fn generate( .into_iter(), ); - let new_block = quote!({ - { - ::std::result::Result::Ok(#self_ty::#base_function_name(&self, #other_atts).await?.into()) - } - }); + let new_block = match with { + Some(with) => quote!({ + ::std::result::Result::Ok(#with(#self_ty::#base_function_name(&self, #other_atts).await?)) + }), + None => quote!({ + { + ::std::result::Result::Ok(#self_ty::#base_function_name(&self, #other_atts).await?.into()) + } + }), + }; new_impl.block = syn::parse2::(new_block).expect("invalid block"); diff --git a/derive/src/simple_object.rs b/derive/src/simple_object.rs index 784f46d9..a170749d 100644 --- a/derive/src/simple_object.rs +++ b/derive/src/simple_object.rs @@ -3,16 +3,19 @@ use proc_macro::TokenStream; use quote::quote; use std::str::FromStr; use syn::ext::IdentExt; -use syn::{Error, Ident, Type}; +use syn::{Error, Ident, Path, Type}; use crate::args::{self, RenameRuleExt, RenameTarget, SimpleObjectField}; use crate::utils::{ gen_deprecation, generate_guards, get_crate_name, get_rustdoc, visible_fn, GeneratorResult, }; +#[derive(Debug)] struct DerivedFieldMetadata { ident: Ident, into: Type, + owned: Option, + with: Option, } struct SimpleObjectFieldGenerator<'a> { @@ -60,10 +63,25 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult( proc_macro2::TokenStream::from_str(&derived.into.clone().unwrap()).unwrap(), - ); - let derived = DerivedFieldMetadata { ident: name, into }; + ) { + Ok(e) => e, + _ => { + return Err(Error::new_spanned( + &name, + "derived into must be a valid type.", + ) + .into()); + } + }; + + let derived = DerivedFieldMetadata { + ident: name, + into, + owned: derived.owned, + with: derived.with.clone(), + }; processed_fields.push(SimpleObjectFieldGenerator { field, @@ -83,7 +101,6 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult return Err(Error::new_spanned(&ident, "All fields must be named.").into()), }; - let is_derived = derived.is_some(); let ident = if let Some(derived) = derived { &derived.ident } else { @@ -109,12 +126,19 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult quote! { ::std::option::Option::None }, }; let vis = &field.vis; + let ty = if let Some(derived) = derived { &derived.into } else { &field.ty }; + let owned = if let Some(derived) = derived { + derived.owned.unwrap_or(field.owned) + } else { + field.owned + }; + let cache_control = { let public = field.cache_control.is_public(); let max_age = field.cache_control.max_age; @@ -152,7 +176,9 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult quote! { &self.#base_ident }, @@ -161,16 +187,17 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult quote! { - #block + block = match (derived, with_function) { + (Some(_), Some(with)) => quote! { + #with(#block) }, - true => quote! { + (Some(_), None) => quote! { ::std::convert::Into::into(#block) }, + (_, _) => block, }; - let ty = match !field.owned { + let ty = match !owned { true => quote! { &#ty }, false => quote! { #ty }, }; diff --git a/docs/en/src/derived_fields.md b/docs/en/src/derived_fields.md index 23c1fa63..6214a58a 100644 --- a/docs/en/src/derived_fields.md +++ b/docs/en/src/derived_fields.md @@ -59,3 +59,45 @@ type Query { duration_rfc3339(arg: String): DateRFC3339! } ``` + +## Wrapper types + +A derived field won't be able to manage everythings easily: without the specialization from the Rust language, you won't be able to implement specialized trait like: +``` +impl From> for Vec { + ... +} +``` + +So you wouldn't be able to generate derived fields for existing wrapper type structures like `Vec` or `Option`. But when you implement a `From for T` you should be able to derived a `From> for Vec` and a `From> for Option`. +We included a `with` parameter to help you define a function to call instead of using the `Into` trait implementation between wrapper structures. + + +### Example + +```rust +#[derive(Serialize, Deserialize, Clone)] +struct ValueDerived(String); + +#[derive(Serialize, Deserialize, Clone)] +struct ValueDerived2(String); + +scalar!(ValueDerived); +scalar!(ValueDerived2); + +impl From for ValueDerived2 { + fn from(value: ValueDerived) -> Self { + ValueDerived2(value.0) + } +} + +fn option_to_option>(value: Option) -> Option { + value.map(|x| x.into()) +} + +#[derive(SimpleObject)] +struct TestObj { + #[graphql(derived(owned, name = "value2", into = "Option", with = "option_to_option"))] + pub value1: Option, +} +``` diff --git a/src/lib.rs b/src/lib.rs index 123672b0..1d80fb24 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -295,6 +295,7 @@ pub type FieldResult = Result; /// |--------------|------------------------------------------|------------ |----------| /// | name | Generated derived field name | string | N | /// | into | Type to derived an into | string | Y | +/// | with | Function to apply to manage advanced use cases | string| Y | /// /// # Valid field return types /// @@ -467,6 +468,8 @@ pub use async_graphql_derive::Object; /// |--------------|------------------------------------------|------------ |----------| /// | name | Generated derived field name | string | N | /// | into | Type to derived an into | string | Y | +/// | owned | Field resolver return a ownedship value | bool | Y | +/// | with | Function to apply to manage advanced use cases | string| Y | /// /// /// # Examples @@ -531,6 +534,7 @@ pub use async_graphql_derive::SimpleObject; /// |--------------|------------------------------------------|------------ |----------| /// | name | Generated derived field name | string | N | /// | into | Type to derived an into | string | Y | +/// | with | Function to apply to manage advanced use cases | string| Y | /// /// # Examples /// diff --git a/tests/derived_field.rs b/tests/derived_field.rs index 304802e1..36724c81 100644 --- a/tests/derived_field.rs +++ b/tests/derived_field.rs @@ -46,6 +46,60 @@ pub async fn test_derived_field_object() { ); } +#[tokio::test] +pub async fn test_derived_field_object_with() { + use serde::{Deserialize, Serialize}; + + struct Query; + + #[derive(Serialize, Deserialize)] + struct ValueDerived(String); + + scalar!(ValueDerived); + + impl From for ValueDerived { + fn from(value: i32) -> Self { + ValueDerived(format!("{}", value)) + } + } + + fn option_to_option>(value: Option) -> Option { + value.map(|x| x.into()) + } + + #[Object] + impl Query { + #[graphql(derived( + name = "value2", + into = "Option", + with = "option_to_option" + ))] + async fn value1(&self, #[graphql(default = 100)] input: i32) -> Option { + Some(input) + } + } + + let query = "{ value1 value2 }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema.execute(query).await.data, + value!({ + "value1": 100, + "value2": "100", + }) + ); + + let query = "{ value1(input: 1) value2(input: 2) }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema.execute(query).await.data, + value!({ + "value1": 1, + "value2": "2", + }) + ); +} + #[tokio::test] pub async fn test_derived_field_simple_object() { use serde::{Deserialize, Serialize}; @@ -102,6 +156,104 @@ pub async fn test_derived_field_simple_object() { ); } +#[tokio::test] +pub async fn test_derived_field_simple_object_option() { + use serde::{Deserialize, Serialize}; + + struct Query; + + #[derive(Serialize, Deserialize, Clone)] + struct ValueDerived(String); + + #[derive(Serialize, Deserialize, Clone)] + struct ValueDerived2(String); + + scalar!(ValueDerived); + scalar!(ValueDerived2); + + impl From for ValueDerived2 { + fn from(value: ValueDerived) -> Self { + ValueDerived2(value.0) + } + } + + fn option_to_option>(value: Option) -> Option { + value.map(|x| x.into()) + } + + fn vec_to_vec>(value: Vec) -> Vec { + value.into_iter().map(|x| x.into()).collect() + } + + fn vecopt_to_vecopt>(value: Vec>) -> Vec> { + value.into_iter().map(|x| x.map(|opt| opt.into())).collect() + } + + fn optvec_to_optvec>(value: Option>) -> Option> { + value.map(|x| x.into_iter().map(|y| y.into()).collect()) + } + + #[derive(SimpleObject)] + struct TestObj { + #[graphql(derived( + owned, + name = "value2", + into = "Option", + with = "option_to_option" + ))] + pub value1: Option, + #[graphql(derived( + owned, + name = "value_vec_2", + into = "Vec", + with = "vec_to_vec" + ))] + pub value_vec_1: Vec, + #[graphql(derived( + owned, + name = "value_opt_vec_2", + into = "Option>", + with = "optvec_to_optvec" + ))] + pub value_opt_vec_1: Option>, + #[graphql(derived( + owned, + name = "value_vec_opt_2", + into = "Vec>", + with = "vecopt_to_vecopt" + ))] + pub value_vec_opt_1: Vec>, + } + + #[Object] + impl Query { + async fn test(&self) -> TestObj { + TestObj { + value1: Some(ValueDerived("Test".to_string())), + value_vec_1: vec![ValueDerived("Test".to_string())], + value_opt_vec_1: Some(vec![ValueDerived("Test".to_string())]), + value_vec_opt_1: vec![Some(ValueDerived("Test".to_string()))], + } + } + } + + let query = "{ test { value1 value2 valueVec1 valueVec2 valueOptVec1 valueOptVec2 } }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema.execute(query).await.data, + value!({ + "test": { + "value1": "Test", + "value2": "Test", + "valueVec1": vec!["Test"], + "valueVec2": vec!["Test"], + "valueOptVec1": vec!["Test"], + "valueOptVec2": vec!["Test"], + } + }) + ); +} + #[tokio::test] pub async fn test_derived_field_complex_object() { use serde::{Deserialize, Serialize}; @@ -163,3 +315,68 @@ pub async fn test_derived_field_complex_object() { }) ); } + +#[tokio::test] +pub async fn test_derived_field_complex_object_derived() { + use serde::{Deserialize, Serialize}; + + #[derive(SimpleObject)] + #[graphql(complex)] + struct MyObj { + a: i32, + #[graphql(owned, derived(name = "f", into = "ValueDerived"))] + b: i32, + } + + #[derive(Serialize, Deserialize)] + struct ValueDerived(String); + + scalar!(ValueDerived); + + impl From for ValueDerived { + fn from(value: i32) -> Self { + ValueDerived(format!("{}", value)) + } + } + + fn option_to_option>(value: Option) -> Option { + value.map(|x| x.into()) + } + + #[ComplexObject] + impl MyObj { + async fn c(&self) -> i32 { + self.a + self.b + } + + #[graphql(derived(name = "e", into = "Option", with = "option_to_option"))] + async fn d(&self, v: i32) -> Option { + Some(self.a + self.b + v) + } + } + + struct Query; + + #[Object] + impl Query { + async fn obj(&self) -> MyObj { + MyObj { a: 10, b: 20 } + } + } + + let query = "{ obj { a b c d(v:100) e(v: 200) f } }"; + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema.execute(query).await.data, + value!({ + "obj": { + "a": 10, + "b": 20, + "c": 30, + "d": 130, + "e": "230", + "f": "20", + }, + }) + ); +}