diff --git a/CHANGELOG.md b/CHANGELOG.md index 2f939d77..1757f3be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +# [3.0.33] 2022-3-4 + +- Add support for oneof field on object. [#766](https://github.com/async-graphql/async-graphql/issues/766) + # [3.0.32] 2022-3-4 - Bump `Actix-web` from `4.0.0-rc.3` to `4.0.1`. diff --git a/Cargo.toml b/Cargo.toml index 01335a55..95ba08b8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,7 +24,6 @@ decimal = ["rust_decimal"] cbor = ["serde_cbor"] chrono-duration = ["chrono", "iso8601-duration"] password-strength-validator = ["zxcvbn"] -unstable_oneof = ["async-graphql-derive/unstable_oneof"] [dependencies] async-graphql-derive = { path = "derive", version = "3.0.32" } diff --git a/README.md b/README.md index d615fe1b..acbc7df0 100644 --- a/README.md +++ b/README.md @@ -81,7 +81,6 @@ This crate offers the following features, all of which are not activated by defa - `smol_str`: Integrate with the [`smol_str` crate](https://crates.io/crates/smol_str). - `hashbrown`: Integrate with the [`hashbrown` crate](https://github.com/rust-lang/hashbrown). - `time`: Integrate with the [`time` crate](https://github.com/time-rs/time). -- `unstable_oneof`: Enable the `OneofObject` macro to define the oneof input object. ## Apollo Studio diff --git a/derive/Cargo.toml b/derive/Cargo.toml index f5f3ada2..c31188a9 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -14,9 +14,6 @@ categories = ["network-programming", "asynchronous"] [lib] proc-macro = true -[features] -unstable_oneof = [] - [dependencies] async-graphql-parser = { path = "../parser", version = "3.0.32" } proc-macro2 = "1.0.24" diff --git a/derive/src/args.rs b/derive/src/args.rs index e519e1f7..e817206e 100644 --- a/derive/src/args.rs +++ b/derive/src/args.rs @@ -271,6 +271,7 @@ pub struct ObjectField { #[darling(default, multiple)] pub derived: Vec, pub flatten: bool, + pub oneof: bool, } #[derive(FromMeta, Default, Clone)] @@ -401,7 +402,6 @@ pub struct InputObject { pub complex: bool, } -#[cfg(feature = "unstable_oneof")] #[derive(FromVariant)] #[darling(attributes(graphql), forward_attrs(doc))] pub struct OneofObjectField { @@ -419,7 +419,6 @@ pub struct OneofObjectField { pub secret: bool, } -#[cfg(feature = "unstable_oneof")] #[derive(FromDeriveInput)] #[darling(attributes(graphql), forward_attrs(doc))] pub struct OneofObject { @@ -459,7 +458,7 @@ pub struct InterfaceFieldArgument { #[derive(FromMeta)] pub struct InterfaceField { - pub name: String, + pub name: SpannedValue, #[darling(rename = "type")] pub ty: LitStr, #[darling(default)] @@ -478,6 +477,8 @@ pub struct InterfaceField { pub requires: Option, #[darling(default)] pub visible: Option, + #[darling(default)] + pub oneof: bool, } #[derive(FromVariant)] @@ -553,6 +554,7 @@ pub struct SubscriptionField { pub guard: Option>, pub visible: Option, pub complexity: Option, + pub oneof: bool, } #[derive(FromField)] @@ -754,6 +756,7 @@ pub struct ComplexObjectField { #[darling(multiple)] pub derived: Vec, pub flatten: bool, + pub oneof: bool, } #[derive(FromMeta, Default)] diff --git a/derive/src/complex_object.rs b/derive/src/complex_object.rs index 12d23e43..393cb1be 100644 --- a/derive/src/complex_object.rs +++ b/derive/src/complex_object.rs @@ -126,7 +126,7 @@ pub fn generate( if method_args.flatten { // Only used to inject the context placeholder if required. - extract_input_args(&crate_name, method)?; + extract_input_args::(&crate_name, method)?; let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, @@ -190,7 +190,120 @@ pub fn generate( } }; - let args = extract_input_args(&crate_name, method)?; + let mut args = extract_input_args::(&crate_name, method)?; + let mut schema_args = Vec::new(); + let mut use_params = Vec::new(); + let mut get_params = Vec::new(); + let mut is_oneof_field = false; + + if method_args.oneof { + is_oneof_field = true; + + if args.len() != 1 { + return Err(Error::new_spanned( + &method, + "The `oneof` field requires exactly one argument.", + ) + .into()); + } + let (ident, ty, argument) = args.pop().unwrap(); + schema_args.push(quote! { + #crate_name::static_assertions::assert_impl_one!(#ty: #crate_name::OneofObjectType); + if let #crate_name::registry::MetaType::InputObject { input_fields, .. } = registry.create_fake_input_type::<#ty>() { + args.extend(input_fields); + } + }); + use_params.push(quote! { #ident }); + + let validators = argument + .validator + .clone() + .unwrap_or_default() + .create_validators( + &crate_name, + quote!(&#ident), + quote!(#ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + get_params.push(quote! { + #[allow(non_snake_case, unused_variables)] + let (__pos, #ident) = ctx.oneof_param_value::<#ty>()?; + #validators + }); + } else { + for ( + ident, + ty, + args::Argument { + name, + desc, + default, + default_with, + validator, + visible, + secret, + .. + }, + ) in &args + { + let name = name.clone().unwrap_or_else(|| { + object_args + .rename_args + .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) + }); + let desc = desc + .as_ref() + .map(|s| quote! {::std::option::Option::Some(#s)}) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let default = generate_default(default, default_with)?; + let schema_default = default + .as_ref() + .map(|value| { + quote! { + ::std::option::Option::Some(::std::string::ToString::to_string( + &<#ty as #crate_name::InputType>::to_value(&#value) + )) + } + }) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + + let visible = visible_fn(visible); + schema_args.push(quote! { + args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue { + name: #name, + description: #desc, + ty: <#ty as #crate_name::InputType>::create_type_info(registry), + default_value: #schema_default, + visible: #visible, + is_secret: #secret, + }); + }); + + let param_ident = &ident.ident; + use_params.push(quote! { #param_ident }); + + let default = match default { + Some(default) => { + quote! { ::std::option::Option::Some(|| -> #ty { #default }) } + } + None => quote! { ::std::option::Option::None }, + }; + + let validators = validator.clone().unwrap_or_default().create_validators( + &crate_name, + quote!(&#ident), + quote!(ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + + get_params.push(quote! { + #[allow(non_snake_case)] + let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; + #validators + }); + } + } + let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, ReturnType::Default => { @@ -201,83 +314,6 @@ pub fn generate( .into()) } }; - - let mut schema_args = Vec::new(); - let mut use_params = Vec::new(); - let mut get_params = Vec::new(); - - for ( - ident, - ty, - args::Argument { - name, - desc, - default, - default_with, - validator, - visible, - secret, - .. - }, - ) in &args - { - let name = name.clone().unwrap_or_else(|| { - object_args - .rename_args - .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) - }); - let desc = desc - .as_ref() - .map(|s| quote! {::std::option::Option::Some(#s)}) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - let default = generate_default(default, default_with)?; - let schema_default = default - .as_ref() - .map(|value| { - quote! { - ::std::option::Option::Some(::std::string::ToString::to_string( - &<#ty as #crate_name::InputType>::to_value(&#value) - )) - } - }) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - - let visible = visible_fn(visible); - schema_args.push(quote! { - args.insert(#name, #crate_name::registry::MetaInputValue { - name: #name, - description: #desc, - ty: <#ty as #crate_name::InputType>::create_type_info(registry), - default_value: #schema_default, - visible: #visible, - is_secret: #secret, - }); - }); - - let param_ident = &ident.ident; - use_params.push(quote! { #param_ident }); - - let default = match default { - Some(default) => { - quote! { ::std::option::Option::Some(|| -> #ty { #default }) } - } - None => quote! { ::std::option::Option::None }, - }; - - let validators = validator.clone().unwrap_or_default().create_validators( - &crate_name, - quote!(&#ident), - quote!(ty), - Some(quote!(.map_err(|err| err.into_server_error(__pos)))), - )?; - - get_params.push(quote! { - #[allow(non_snake_case)] - let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; - #validators - }); - } - let schema_ty = ty.value_type(); let visible = visible_fn(&method_args.visible); @@ -350,6 +386,7 @@ pub fn generate( requires: #requires, visible: #visible, compute_complexity: #complexity, + oneof: #is_oneof_field, })); }); diff --git a/derive/src/directive.rs b/derive/src/directive.rs index e7c3e916..322b7037 100644 --- a/derive/src/directive.rs +++ b/derive/src/directive.rs @@ -82,7 +82,7 @@ pub fn generate( let visible = visible_fn(&visible); schema_args.push(quote! { - args.insert(#name, #crate_name::registry::MetaInputValue { + args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue { name: #name, description: #desc, ty: <#arg_ty as #crate_name::InputType>::create_type_info(registry), diff --git a/derive/src/interface.rs b/derive/src/interface.rs index 5a224807..a8b5d0f4 100644 --- a/derive/src/interface.rs +++ b/derive/src/interface.rs @@ -136,8 +136,10 @@ pub fn generate(interface_args: &args::Interface) -> GeneratorResult GeneratorResult GeneratorResult }); use_params.push(quote! { ctx }); - for InterfaceFieldArgument { - name, - desc, - ty, - default, - default_with, - visible, - secret, - } in args - { + let mut is_oneof_field = false; + + if *oneof { + is_oneof_field = true; + + if args.len() != 1 { + return Err( + Error::new(name_span, "The `oneof` field requires one parameter.").into(), + ); + } + + let InterfaceFieldArgument { name, ty, .. } = &args[0]; let ident = Ident::new(name, Span::call_site()); - let name = interface_args - .rename_args - .rename(name, RenameTarget::Argument); let ty = match syn::parse_str::(&ty.value()) { Ok(ty) => ty, Err(_) => return Err(Error::new_spanned(&ty, "Expect type").into()), }; + decl_params.push(quote! { #ident: #ty }); use_params.push(quote! { #ident }); - - let default = generate_default(default, default_with)?; - let get_default = match &default { - Some(default) => quote! { ::std::option::Option::Some(|| -> #ty { #default }) }, - None => quote! { ::std::option::Option::None }, - }; get_params.push(quote! { - let (_, #ident) = ctx.param_value::<#ty>(#name, #get_default)?; + #[allow(non_snake_case, unused_variables)] + let (_, #ident) = ctx.oneof_param_value::<#ty>()?; }); - - let desc = desc - .as_ref() - .map(|s| quote! {::std::option::Option::Some(#s)}) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - let schema_default = default - .as_ref() - .map(|value| { - quote! { - ::std::option::Option::Some(::std::string::ToString::to_string( - &<#ty as #crate_name::InputType>::to_value(&#value) - )) - } - }) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - let visible = visible_fn(visible); schema_args.push(quote! { - args.insert(#name, #crate_name::registry::MetaInputValue { - name: #name, - description: #desc, - ty: <#ty as #crate_name::InputType>::create_type_info(registry), - default_value: #schema_default, - visible: #visible, - is_secret: #secret, - }); + #crate_name::static_assertions::assert_impl_one!(#ty: #crate_name::OneofObjectType); + if let #crate_name::registry::MetaType::InputObject { input_fields, .. } = registry.create_fake_input_type::<#ty>() { + args.extend(input_fields); + } }); + } else { + for InterfaceFieldArgument { + name, + desc, + ty, + default, + default_with, + visible, + secret, + } in args + { + let ident = Ident::new(name, Span::call_site()); + let name = interface_args + .rename_args + .rename(name, RenameTarget::Argument); + let ty = match syn::parse_str::(&ty.value()) { + Ok(ty) => ty, + Err(_) => return Err(Error::new_spanned(&ty, "Expect type").into()), + }; + decl_params.push(quote! { #ident: #ty }); + use_params.push(quote! { #ident }); + + let default = generate_default(default, default_with)?; + let get_default = match &default { + Some(default) => quote! { ::std::option::Option::Some(|| -> #ty { #default }) }, + None => quote! { ::std::option::Option::None }, + }; + get_params.push(quote! { + let (_, #ident) = ctx.param_value::<#ty>(#name, #get_default)?; + }); + + let desc = desc + .as_ref() + .map(|s| quote! {::std::option::Option::Some(#s)}) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let schema_default = default + .as_ref() + .map(|value| { + quote! { + ::std::option::Option::Some(::std::string::ToString::to_string( + &<#ty as #crate_name::InputType>::to_value(&#value) + )) + } + }) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let visible = visible_fn(visible); + schema_args.push(quote! { + args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue { + name: #name, + description: #desc, + ty: <#ty as #crate_name::InputType>::create_type_info(registry), + default_value: #schema_default, + visible: #visible, + is_secret: #secret, + }); + }); + } } for enum_name in &enum_names { @@ -275,6 +309,7 @@ pub fn generate(interface_args: &args::Interface) -> GeneratorResult TokenStream { } } -#[cfg(feature = "unstable_oneof")] -#[cfg_attr(docsrs, doc(cfg(feature = "unstable_oneof")))] #[proc_macro_derive(OneofObject, attributes(graphql))] pub fn derive_oneof_object(input: TokenStream) -> TokenStream { let object_args = diff --git a/derive/src/object.rs b/derive/src/object.rs index ef24aecf..4bcfdc35 100644 --- a/derive/src/object.rs +++ b/derive/src/object.rs @@ -141,7 +141,15 @@ pub fn generate( return Err(Error::new_spanned(&method, "Must be asynchronous").into()); } - let args = extract_input_args(&crate_name, method)?; + if method_args.oneof { + return Err(Error::new_spanned( + &method, + "The `entity` and `oneof` attributes cannot be specified at the same time.", + ) + .into()); + } + + let args = extract_input_args::(&crate_name, method)?; let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, @@ -262,7 +270,7 @@ pub fn generate( if method_args.flatten { // Only used to inject the context placeholder if required. - extract_input_args(&crate_name, method)?; + extract_input_args::(&crate_name, method)?; let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, @@ -326,10 +334,121 @@ pub fn generate( } }; - let args = extract_input_args(&crate_name, method)?; + let mut args = extract_input_args::(&crate_name, method)?; let mut schema_args = Vec::new(); let mut use_params = Vec::new(); let mut get_params = Vec::new(); + let mut is_oneof_field = false; + + if method_args.oneof { + is_oneof_field = true; + + if args.len() != 1 { + return Err(Error::new_spanned( + &method, + "The `oneof` field requires exactly one argument.", + ) + .into()); + } + let (ident, ty, argument) = args.pop().unwrap(); + + schema_args.push(quote! { + #crate_name::static_assertions::assert_impl_one!(#ty: #crate_name::OneofObjectType); + if let #crate_name::registry::MetaType::InputObject { input_fields, .. } = registry.create_fake_input_type::<#ty>() { + args.extend(input_fields); + } + }); + use_params.push(quote! { #ident }); + + let validators = argument + .validator + .clone() + .unwrap_or_default() + .create_validators( + &crate_name, + quote!(&#ident), + quote!(#ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + get_params.push(quote! { + #[allow(non_snake_case, unused_variables)] + let (__pos, #ident) = ctx.oneof_param_value::<#ty>()?; + #validators + }); + } else { + for ( + ident, + ty, + args::Argument { + name, + desc, + default, + default_with, + validator, + visible, + secret, + .. + }, + ) in &args + { + let name = name.clone().unwrap_or_else(|| { + object_args + .rename_args + .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) + }); + let desc = desc + .as_ref() + .map(|s| quote! {::std::option::Option::Some(#s)}) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let default = generate_default(default, default_with)?; + let schema_default = default + .as_ref() + .map(|value| { + quote! { + ::std::option::Option::Some(::std::string::ToString::to_string( + &<#ty as #crate_name::InputType>::to_value(&#value) + )) + } + }) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + + let visible = visible_fn(visible); + schema_args.push(quote! { + args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue { + name: #name, + description: #desc, + ty: <#ty as #crate_name::InputType>::create_type_info(registry), + default_value: #schema_default, + visible: #visible, + is_secret: #secret, + }); + }); + + let param_ident = &ident.ident; + use_params.push(quote! { #param_ident }); + + let default = match default { + Some(default) => { + quote! { ::std::option::Option::Some(|| -> #ty { #default }) } + } + None => quote! { ::std::option::Option::None }, + }; + + let validators = validator.clone().unwrap_or_default().create_validators( + &crate_name, + quote!(&#ident), + quote!(#ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + + get_params.push(quote! { + #[allow(non_snake_case, unused_variables)] + let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; + #validators + }); + } + } + let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, ReturnType::Default => { @@ -340,79 +459,6 @@ pub fn generate( .into()) } }; - - for ( - ident, - ty, - args::Argument { - name, - desc, - default, - default_with, - validator, - visible, - secret, - .. - }, - ) in &args - { - let name = name.clone().unwrap_or_else(|| { - object_args - .rename_args - .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) - }); - let desc = desc - .as_ref() - .map(|s| quote! {::std::option::Option::Some(#s)}) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - let default = generate_default(default, default_with)?; - let schema_default = default - .as_ref() - .map(|value| { - quote! { - ::std::option::Option::Some(::std::string::ToString::to_string( - &<#ty as #crate_name::InputType>::to_value(&#value) - )) - } - }) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - - let visible = visible_fn(visible); - schema_args.push(quote! { - args.insert(#name, #crate_name::registry::MetaInputValue { - name: #name, - description: #desc, - ty: <#ty as #crate_name::InputType>::create_type_info(registry), - default_value: #schema_default, - visible: #visible, - is_secret: #secret, - }); - }); - - let param_ident = &ident.ident; - use_params.push(quote! { #param_ident }); - - let default = match default { - Some(default) => { - quote! { ::std::option::Option::Some(|| -> #ty { #default }) } - } - None => quote! { ::std::option::Option::None }, - }; - - let validators = validator.clone().unwrap_or_default().create_validators( - &crate_name, - quote!(&#ident), - quote!(#ty), - Some(quote!(.map_err(|err| err.into_server_error(__pos)))), - )?; - - get_params.push(quote! { - #[allow(non_snake_case, unused_variables)] - let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; - #validators - }); - } - let schema_ty = ty.value_type(); let visible = visible_fn(&method_args.visible); @@ -485,6 +531,7 @@ pub fn generate( requires: #requires, visible: #visible, compute_complexity: #complexity, + oneof: #is_oneof_field, }); }); diff --git a/derive/src/oneof_object.rs b/derive/src/oneof_object.rs index 8017fd48..588997f9 100644 --- a/derive/src/oneof_object.rs +++ b/derive/src/oneof_object.rs @@ -148,7 +148,7 @@ pub fn generate(object_args: &args::OneofObject) -> GeneratorResult fn parse(value: ::std::option::Option<#crate_name::Value>) -> #crate_name::InputValueResult { if let ::std::option::Option::Some(#crate_name::Value::Object(mut obj)) = value { #(#parse_item)* - ::std::result::Result::Err(#crate_name::InputValueError::expected_type(async_graphql::Value::Object(obj))) + ::std::result::Result::Err(#crate_name::InputValueError::expected_type(#crate_name::Value::Object(obj))) } else { ::std::result::Result::Err(#crate_name::InputValueError::expected_type(value.unwrap_or_default())) } @@ -170,6 +170,9 @@ pub fn generate(object_args: &args::OneofObject) -> GeneratorResult ::std::option::Option::Some(self) } } + + impl #crate_name::InputObjectType for #ident {} + impl #crate_name::OneofObjectType for #ident {} } } else { let mut code = Vec::new(); @@ -195,7 +198,7 @@ pub fn generate(object_args: &args::OneofObject) -> GeneratorResult fn __internal_parse(value: ::std::option::Option<#crate_name::Value>) -> #crate_name::InputValueResult where Self: #crate_name::InputType { if let ::std::option::Option::Some(#crate_name::Value::Object(mut obj)) = value { #(#parse_item)* - ::std::result::Result::Err(#crate_name::InputValueError::expected_type(async_graphql::Value::Object(obj))) + ::std::result::Result::Err(#crate_name::InputValueError::expected_type(#crate_name::Value::Object(obj))) } else { ::std::result::Result::Err(#crate_name::InputValueError::expected_type(value.unwrap_or_default())) } @@ -247,6 +250,7 @@ pub fn generate(object_args: &args::OneofObject) -> GeneratorResult } impl #crate_name::InputObjectType for #concrete_type {} + impl #crate_name::OneofObjectType for #concrete_type {} }; code.push(expanded); } diff --git a/derive/src/simple_object.rs b/derive/src/simple_object.rs index f98f9965..c4ae6f71 100644 --- a/derive/src/simple_object.rs +++ b/derive/src/simple_object.rs @@ -167,6 +167,7 @@ pub fn generate(object_args: &args::SimpleObject) -> GeneratorResult(&crate_name, method)?; + + if field.oneof { + is_oneof_field = true; + + if args.len() != 1 { + return Err(Error::new_spanned( + &method, + "The `oneof` field requires exactly one argument.", + ) + .into()); + } + let (ident, ty, argument) = args.pop().unwrap(); + + schema_args.push(quote! { + #crate_name::static_assertions::assert_impl_one!(#ty: #crate_name::OneofObjectType); + if let #crate_name::registry::MetaType::InputObject { input_fields, .. } = registry.create_fake_input_type::<#ty>() { + args.extend(input_fields); + } + }); + use_params.push(quote! { #ident }); + + let validators = argument + .validator + .clone() + .unwrap_or_default() + .create_validators( + &crate_name, + quote!(&#ident), + quote!(#ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + get_params.push(quote! { + #[allow(non_snake_case, unused_variables)] + let (__pos, #ident) = ctx.oneof_param_value::<#ty>()?; + #validators + }); + } else { + for ( + ident, + ty, + args::SubscriptionFieldArgument { + name, + desc, + default, + default_with, + validator, + visible: arg_visible, + secret, + }, + ) in &args + { + let name = name.clone().unwrap_or_else(|| { + subscription_args + .rename_args + .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) + }); + let desc = desc + .as_ref() + .map(|s| quote! {::std::option::Option::Some(#s)}) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + let default = generate_default(default, default_with)?; + + let schema_default = default + .as_ref() + .map(|value| { + quote! { + ::std::option::Option::Some(::std::string::ToString::to_string( + &<#ty as #crate_name::InputType>::to_value(&#value) + )) + } + }) + .unwrap_or_else(|| quote! {::std::option::Option::None}); + + let visible = visible_fn(arg_visible); + schema_args.push(quote! { + args.insert(::std::borrow::ToOwned::to_owned(#name), #crate_name::registry::MetaInputValue { + name: #name, + description: #desc, + ty: <#ty as #crate_name::InputType>::create_type_info(registry), + default_value: #schema_default, + visible: #visible, + is_secret: #secret, + }); + }); + + use_params.push(quote! { #ident }); + + let default = match default { + Some(default) => { + quote! { ::std::option::Option::Some(|| -> #ty { #default }) } + } + None => quote! { ::std::option::Option::None }, + }; + let validators = validator.clone().unwrap_or_default().create_validators( + &crate_name, + quote!(&#ident), + quote!(#ty), + Some(quote!(.map_err(|err| err.into_server_error(__pos)))), + )?; + + get_params.push(quote! { + #[allow(non_snake_case)] + let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; + #validators + }); + } + } + let ty = match &method.sig.output { ReturnType::Type(_, ty) => OutputType::parse(ty)?, ReturnType::Default => { @@ -78,138 +189,6 @@ pub fn generate( .into()) } }; - - let mut create_ctx = true; - let mut args = Vec::new(); - - for (idx, arg) in method.sig.inputs.iter_mut().enumerate() { - if let FnArg::Receiver(receiver) = arg { - if idx != 0 { - return Err(Error::new_spanned( - receiver, - "The self receiver must be the first parameter.", - ) - .into()); - } - } else if let FnArg::Typed(pat) = arg { - if idx == 0 { - return Err(Error::new_spanned( - pat, - "The self receiver must be the first parameter.", - ) - .into()); - } - - match (&*pat.pat, &*pat.ty) { - (Pat::Ident(arg_ident), Type::Path(arg_ty)) => { - args.push(( - arg_ident.clone(), - arg_ty.clone(), - parse_graphql_attrs::(&pat.attrs)? - .unwrap_or_default(), - )); - remove_graphql_attrs(&mut pat.attrs); - } - (arg, Type::Reference(TypeReference { elem, .. })) => { - if let Type::Path(path) = elem.as_ref() { - if idx != 1 || path.path.segments.last().unwrap().ident != "Context" - { - return Err(Error::new_spanned( - arg, - "Only types that implement `InputType` can be used as input arguments.", - ) - .into()); - } else { - create_ctx = false; - } - } - } - _ => { - return Err(Error::new_spanned(arg, "Incorrect argument type").into()); - } - } - } else { - return Err(Error::new_spanned(arg, "Incorrect argument type").into()); - } - } - - if create_ctx { - let arg = syn::parse2::(quote! { _: &#crate_name::Context<'_> }).unwrap(); - method.sig.inputs.insert(1, arg); - } - - let mut schema_args = Vec::new(); - let mut use_params = Vec::new(); - let mut get_params = Vec::new(); - - for ( - ident, - ty, - args::SubscriptionFieldArgument { - name, - desc, - default, - default_with, - validator, - visible: arg_visible, - secret, - }, - ) in &args - { - let name = name.clone().unwrap_or_else(|| { - subscription_args - .rename_args - .rename(ident.ident.unraw().to_string(), RenameTarget::Argument) - }); - let desc = desc - .as_ref() - .map(|s| quote! {::std::option::Option::Some(#s)}) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - let default = generate_default(default, default_with)?; - - let schema_default = default - .as_ref() - .map(|value| { - quote! { - ::std::option::Option::Some(::std::string::ToString::to_string( - &<#ty as #crate_name::InputType>::to_value(&#value) - )) - } - }) - .unwrap_or_else(|| quote! {::std::option::Option::None}); - - let visible = visible_fn(arg_visible); - schema_args.push(quote! { - args.insert(#name, #crate_name::registry::MetaInputValue { - name: #name, - description: #desc, - ty: <#ty as #crate_name::InputType>::create_type_info(registry), - default_value: #schema_default, - visible: #visible, - is_secret: #secret, - }); - }); - - use_params.push(quote! { #ident }); - - let default = match default { - Some(default) => quote! { ::std::option::Option::Some(|| -> #ty { #default }) }, - None => quote! { ::std::option::Option::None }, - }; - let validators = validator.clone().unwrap_or_default().create_validators( - &crate_name, - quote!(&#ident), - quote!(#ty), - Some(quote!(.map_err(|err| err.into_server_error(__pos)))), - )?; - - get_params.push(quote! { - #[allow(non_snake_case)] - let (__pos, #ident) = ctx.param_value::<#ty>(#name, #default)?; - #validators - }); - } - let res_ty = ty.value_type(); let stream_ty = if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = &res_ty { let mut r = None; @@ -307,6 +286,7 @@ pub fn generate( provides: ::std::option::Option::None, visible: #visible, compute_complexity: #complexity, + oneof: #is_oneof_field, }); }); diff --git a/derive/src/utils.rs b/derive/src/utils.rs index 1b578496..871483d8 100644 --- a/derive/src/utils.rs +++ b/derive/src/utils.rs @@ -13,7 +13,7 @@ use syn::{ }; use thiserror::Error; -use crate::args::{self, Argument, Deprecation, Visible}; +use crate::args::{self, Deprecation, Visible}; #[derive(Error, Debug)] pub enum GeneratorError { @@ -142,7 +142,9 @@ pub fn get_cfg_attrs(attrs: &[Attribute]) -> Vec { .collect() } -pub fn parse_graphql_attrs(attrs: &[Attribute]) -> GeneratorResult> { +pub fn parse_graphql_attrs( + attrs: &[Attribute], +) -> GeneratorResult> { for attr in attrs { if attr.path.is_ident("graphql") { let meta = attr.parse_meta()?; @@ -238,10 +240,10 @@ pub fn gen_deprecation(deprecation: &Deprecation, crate_name: &TokenStream) -> T } } -pub fn extract_input_args( +pub fn extract_input_args( crate_name: &proc_macro2::TokenStream, method: &mut ImplItemMethod, -) -> GeneratorResult> { +) -> GeneratorResult> { let mut args = Vec::new(); let mut create_ctx = true; @@ -278,8 +280,7 @@ pub fn extract_input_args( args.push(( arg_ident.clone(), pat.ty.as_ref().clone(), - parse_graphql_attrs::(&pat.attrs)? - .unwrap_or_default(), + parse_graphql_attrs::(&pat.attrs)?.unwrap_or_default(), )); } else { create_ctx = false; @@ -290,7 +291,7 @@ pub fn extract_input_args( args.push(( arg_ident.clone(), ty.clone(), - parse_graphql_attrs::(&pat.attrs)?.unwrap_or_default(), + parse_graphql_attrs::(&pat.attrs)?.unwrap_or_default(), )); remove_graphql_attrs(&mut pat.attrs); } diff --git a/src/base.rs b/src/base.rs index f9474e32..c0e7338e 100644 --- a/src/base.rs +++ b/src/base.rs @@ -148,6 +148,9 @@ pub trait UnionType: ContainerType {} /// A GraphQL input object. pub trait InputObjectType: InputType {} +/// A GraphQL oneof input object. +pub trait OneofObjectType: InputObjectType {} + #[async_trait::async_trait] impl OutputType for Box { fn type_name() -> Cow<'static, str> { diff --git a/src/context.rs b/src/context.rs index 41aa822b..39970b15 100644 --- a/src/context.rs +++ b/src/context.rs @@ -19,8 +19,8 @@ use crate::parser::types::{ }; use crate::schema::SchemaEnv; use crate::{ - Error, InputType, Lookahead, Name, PathSegment, Pos, Positioned, Result, ServerError, - ServerResult, UploadValue, Value, + Error, InputType, Lookahead, Name, OneofObjectType, PathSegment, Pos, Positioned, Result, + ServerError, ServerResult, UploadValue, Value, }; /// Data related functions of the context. @@ -604,6 +604,22 @@ impl<'a> ContextBase<'a, &'a Positioned> { self.get_param_value(&self.item.node.arguments, name, default) } + #[doc(hidden)] + pub fn oneof_param_value(&self) -> ServerResult<(Pos, T)> { + use indexmap::IndexMap; + + let mut map = IndexMap::new(); + + for (name, value) in &self.item.node.arguments { + let value = self.resolve_input_value(value.clone())?; + map.insert(name.node.clone(), value); + } + + InputType::parse(Some(Value::Object(map))) + .map(|value| (self.item.pos, value)) + .map_err(|e| e.into_server_error(self.item.pos)) + } + /// Creates a uniform interface to inspect the forthcoming selections. /// /// # Examples diff --git a/src/docs/complex_object.md b/src/docs/complex_object.md index b2bca4a1..06649fc1 100644 --- a/src/docs/complex_object.md +++ b/src/docs/complex_object.md @@ -36,6 +36,7 @@ some simple fields, and use the `ComplexObject` macro to define some other field | complexity | Custom field complexity. | string | Y | | derived | Generate derived fields *[See also the Book](https://async-graphql.github.io/async-graphql/en/derived_fields.html).* | object | Y | | flatten | Similar to serde (flatten) | boolean | Y | +| oneof | Oneof field | bool | Y | # Field argument attributes diff --git a/src/docs/interface.md b/src/docs/interface.md index d630af42..6a1953c8 100644 --- a/src/docs/interface.md +++ b/src/docs/interface.md @@ -30,6 +30,7 @@ Define a GraphQL interface | requires | Annotate the required input fieldset from a base type for a resolver. It is used to develop a query plan where the required fields may not be needed by the client, but the service may need additional information from other services. | string | Y | | visible | If `false`, it will not be displayed in introspection. *[See also the Book](https://async-graphql.github.io/async-graphql/en/visibility.html).* | bool | Y | | visible | Call the specified function. If the return value is `false`, it will not be displayed in introspection. | string | Y | +| oneof | Oneof field | bool | Y | # Field argument attributes diff --git a/src/docs/object.md b/src/docs/object.md index 9faeb8d3..c93953fa 100644 --- a/src/docs/object.md +++ b/src/docs/object.md @@ -39,6 +39,7 @@ All methods are converted to camelCase. | complexity | Custom field complexity. | string | Y | | derived | Generate derived fields *[See also the Book](https://async-graphql.github.io/async-graphql/en/derived_fields.html).* | object | Y | | flatten | Similar to serde (flatten) | boolean | Y | +| oneof | Oneof field | bool | Y | # Field argument attributes diff --git a/src/docs/subscription.md b/src/docs/subscription.md index b9a86fc2..9152d85f 100644 --- a/src/docs/subscription.md +++ b/src/docs/subscription.md @@ -32,6 +32,7 @@ The filter function should be synchronous. | complexity | Custom field complexity. *[See also the Book](https://async-graphql.github.io/async-graphql/en/depth_and_complexity.html).* | bool | Y | | complexity | Custom field complexity. | string | Y | | secret | Mark this field as a secret, it will not output the actual value in the log. | bool | Y | +| oneof | Oneof field | bool | Y | # Field argument attributes diff --git a/src/lib.rs b/src/lib.rs index 50f1e35d..df612cc0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,6 @@ //! - `smol_str`: Integrate with the [`smol_str` crate](https://crates.io/crates/smol_str). //! - `hashbrown`: Integrate with the [`hashbrown` crate](https://github.com/rust-lang/hashbrown). //! - `time`: Integrate with the [`time` crate](https://github.com/time-rs/time). -//! - `unstable_oneof`: Enable the `OneofObject` macro to define the oneof input object. //! //! ## Integrations //! @@ -214,8 +213,8 @@ pub use async_graphql_value::{ SerializerError, Variables, }; pub use base::{ - ComplexObject, Description, InputObjectType, InputType, InterfaceType, ObjectType, OutputType, - UnionType, + ComplexObject, Description, InputObjectType, InputType, InterfaceType, ObjectType, + OneofObjectType, OutputType, UnionType, }; pub use custom_directive::{CustomDirective, CustomDirectiveFactory}; pub use error::{ @@ -267,8 +266,6 @@ pub use async_graphql_derive::MergedSubscription; pub use async_graphql_derive::NewType; #[doc = include_str!("docs/object.md")] pub use async_graphql_derive::Object; -#[cfg(feature = "unstable_oneof")] -#[cfg_attr(docsrs, doc(cfg(feature = "unstable_oneof")))] #[doc = include_str!("docs/oneof_object.md")] pub use async_graphql_derive::OneofObject; #[doc = include_str!("docs/scalar.md")] diff --git a/src/model/field.rs b/src/model/field.rs index 200ee3cd..e5b4ec17 100644 --- a/src/model/field.rs +++ b/src/model/field.rs @@ -50,4 +50,8 @@ impl<'a> __Field<'a> { async fn deprecation_reason(&self) -> Option<&str> { self.field.deprecation.reason() } + + async fn one_of(&self) -> bool { + self.field.oneof + } } diff --git a/src/model/type.rs b/src/model/type.rs index 991547da..160a3aa2 100644 --- a/src/model/type.rs +++ b/src/model/type.rs @@ -229,7 +229,6 @@ impl<'a> __Type<'a> { } } - #[cfg(feature = "unstable_oneof")] async fn one_of(&self) -> Option { if let TypeDetail::Named(registry::MetaType::InputObject { oneof, .. }) = &self.detail { Some(*oneof) diff --git a/src/registry/export_sdl.rs b/src/registry/export_sdl.rs index d70b6b8e..b19153bb 100644 --- a/src/registry/export_sdl.rs +++ b/src/registry/export_sdl.rs @@ -69,6 +69,10 @@ impl Registry { write!(sdl, "\t{}: {}", field.name, field.ty).ok(); } + if field.oneof { + write!(sdl, " @oneof").ok(); + } + if federation { if field.external { write!(sdl, " @external").ok(); @@ -202,7 +206,6 @@ impl Registry { name, input_fields, description, - #[cfg(feature = "unstable_oneof")] oneof, .. } => { @@ -210,7 +213,6 @@ impl Registry { writeln!(sdl, "\"\"\"\n{}\n\"\"\"", description.unwrap()).ok(); } write!(sdl, "input {} ", name).ok(); - #[cfg(feature = "unstable_oneof")] if *oneof { write!(sdl, "@oneof ").ok(); } diff --git a/src/registry/mod.rs b/src/registry/mod.rs index fc8dac36..e818e39e 100644 --- a/src/registry/mod.rs +++ b/src/registry/mod.rs @@ -157,7 +157,7 @@ impl Deprecation { pub struct MetaField { pub name: String, pub description: Option<&'static str>, - pub args: IndexMap<&'static str, MetaInputValue>, + pub args: IndexMap, pub ty: String, pub deprecation: Deprecation, pub cache_control: CacheControl, @@ -166,6 +166,7 @@ pub struct MetaField { pub provides: Option<&'static str>, pub visible: Option, pub compute_complexity: Option, + pub oneof: bool, } #[derive(Clone)] @@ -351,7 +352,7 @@ pub struct MetaDirective { pub name: &'static str, pub description: Option<&'static str>, pub locations: Vec, - pub args: IndexMap<&'static str, MetaInputValue>, + pub args: IndexMap, pub is_repeatable: bool, pub visible: Option, } @@ -566,6 +567,7 @@ impl Registry { provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); @@ -577,7 +579,7 @@ impl Registry { args: { let mut args = IndexMap::new(); args.insert( - "representations", + "representations".to_string(), MetaInputValue { name: "representations", description: None, @@ -597,6 +599,7 @@ impl Registry { provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); } @@ -627,6 +630,7 @@ impl Registry { provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); fields diff --git a/src/schema.rs b/src/schema.rs index 65002fdb..bfdd0920 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -318,7 +318,7 @@ where ], args: { let mut args = IndexMap::new(); - args.insert("if", MetaInputValue { + args.insert("if".to_string(), MetaInputValue { name: "if", description: Some("Included when true."), ty: "Boolean!".to_string(), @@ -342,7 +342,7 @@ where ], args: { let mut args = IndexMap::new(); - args.insert("if", MetaInputValue { + args.insert("if".to_string(), MetaInputValue { name: "if", description: Some("Skipped when true."), ty: "Boolean!".to_string(), diff --git a/src/types/connection/connection_type.rs b/src/types/connection/connection_type.rs index 1d7a1499..e58dd7f0 100644 --- a/src/types/connection/connection_type.rs +++ b/src/types/connection/connection_type.rs @@ -193,6 +193,7 @@ where provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); @@ -212,6 +213,7 @@ where provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); diff --git a/src/types/connection/edge.rs b/src/types/connection/edge.rs index 1955c394..b4704f05 100644 --- a/src/types/connection/edge.rs +++ b/src/types/connection/edge.rs @@ -101,6 +101,7 @@ where provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); @@ -118,6 +119,7 @@ where provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); diff --git a/src/types/query_root.rs b/src/types/query_root.rs index f40dd5b4..d984b965 100644 --- a/src/types/query_root.rs +++ b/src/types/query_root.rs @@ -112,6 +112,7 @@ impl OutputType for QueryRoot { provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); @@ -123,7 +124,7 @@ impl OutputType for QueryRoot { args: { let mut args = IndexMap::new(); args.insert( - "name", + "name".to_string(), registry::MetaInputValue { name: "name", description: None, @@ -143,6 +144,7 @@ impl OutputType for QueryRoot { provides: None, visible: None, compute_complexity: None, + oneof: false, }, ); } diff --git a/src/validation/mod.rs b/src/validation/mod.rs index fe8c030c..53b153e0 100644 --- a/src/validation/mod.rs +++ b/src/validation/mod.rs @@ -73,6 +73,7 @@ pub fn check_rules( .with(rules::KnownDirectives::default()) .with(rules::DirectivesUnique::default()) .with(rules::OverlappingFieldsCanBeMerged) + .with(rules::OneofFieldsHaveExactlyOneArgument) .with(rules::UploadFile) .with(visitors::CacheControlCalculate { cache_control: &mut cache_control, diff --git a/src/validation/rules/arguments_of_correct_type.rs b/src/validation/rules/arguments_of_correct_type.rs index 8e6d09ae..c4dc1382 100644 --- a/src/validation/rules/arguments_of_correct_type.rs +++ b/src/validation/rules/arguments_of_correct_type.rs @@ -11,7 +11,7 @@ use crate::{Name, Positioned, QueryPathSegment}; #[derive(Default)] pub struct ArgumentsOfCorrectType<'a> { - current_args: Option<&'a IndexMap<&'static str, MetaInputValue>>, + current_args: Option<&'a IndexMap>, } impl<'a> Visitor<'a> for ArgumentsOfCorrectType<'a> { @@ -983,4 +983,82 @@ mod tests { "#, ); } + + #[test] + fn oneof() { + expect_passes_rule!( + factory, + r#" + { + oneofArg(arg: {a: 10}) + } + "#, + ); + + expect_passes_rule!( + factory, + r#" + { + oneofArg(arg: {b: "abc"}) + } + "#, + ); + + expect_fails_rule!( + factory, + r#" + { + oneofArg(arg: {a: 10, b: "abc"}) + } + "#, + ); + } + + #[test] + fn oneof_opt() { + expect_passes_rule!( + factory, + r#" + { + oneofOpt(arg: {a: 10}) + } + "#, + ); + + expect_passes_rule!( + factory, + r#" + { + oneofOpt(arg: {b: "abc"}) + } + "#, + ); + + expect_passes_rule!( + factory, + r#" + { + oneofOpt + } + "#, + ); + + expect_passes_rule!( + factory, + r#" + { + oneofOpt(arg: null) + } + "#, + ); + + expect_fails_rule!( + factory, + r#" + { + oneofOpt(arg: {a: 10, b: "abc"}) + } + "#, + ); + } } diff --git a/src/validation/rules/known_argument_names.rs b/src/validation/rules/known_argument_names.rs index b80fe396..715193a1 100644 --- a/src/validation/rules/known_argument_names.rs +++ b/src/validation/rules/known_argument_names.rs @@ -17,7 +17,7 @@ enum ArgsType<'a> { #[derive(Default)] pub struct KnownArgumentNames<'a> { - current_args: Option<(&'a IndexMap<&'static str, MetaInputValue>, ArgsType<'a>)>, + current_args: Option<(&'a IndexMap, ArgsType<'a>)>, } impl<'a> KnownArgumentNames<'a> { @@ -26,7 +26,7 @@ impl<'a> KnownArgumentNames<'a> { " Did you mean", self.current_args .iter() - .map(|(args, _)| args.iter().map(|arg| *arg.0)) + .map(|(args, _)| args.iter().map(|arg| arg.0.as_str())) .flatten(), name, ) diff --git a/src/validation/rules/mod.rs b/src/validation/rules/mod.rs index e36f0391..b37d55f0 100644 --- a/src/validation/rules/mod.rs +++ b/src/validation/rules/mod.rs @@ -11,6 +11,7 @@ mod no_fragment_cycles; mod no_undefined_variables; mod no_unused_fragments; mod no_unused_variables; +mod oneof_fields_have_exactly_one_argument; mod overlapping_fields_can_be_merged; mod possible_fragment_spreads; mod provided_non_null_arguments; @@ -34,6 +35,7 @@ pub use no_fragment_cycles::NoFragmentCycles; pub use no_undefined_variables::NoUndefinedVariables; pub use no_unused_fragments::NoUnusedFragments; pub use no_unused_variables::NoUnusedVariables; +pub use oneof_fields_have_exactly_one_argument::OneofFieldsHaveExactlyOneArgument; pub use overlapping_fields_can_be_merged::OverlappingFieldsCanBeMerged; pub use possible_fragment_spreads::PossibleFragmentSpreads; pub use provided_non_null_arguments::ProvidedNonNullArguments; diff --git a/src/validation/rules/oneof_fields_have_exactly_one_argument.rs b/src/validation/rules/oneof_fields_have_exactly_one_argument.rs new file mode 100644 index 00000000..8a872cb7 --- /dev/null +++ b/src/validation/rules/oneof_fields_have_exactly_one_argument.rs @@ -0,0 +1,66 @@ +use crate::validation::visitor::{RuleError, Visitor}; +use crate::VisitorContext; +use async_graphql_parser::types::Field; +use async_graphql_parser::Positioned; + +pub struct OneofFieldsHaveExactlyOneArgument; + +impl<'a> Visitor<'a> for OneofFieldsHaveExactlyOneArgument { + fn enter_field(&mut self, ctx: &mut VisitorContext<'a>, field: &'a Positioned) { + if let Some(parent_type) = ctx.parent_type() { + if let Some(field_def) = parent_type + .fields() + .and_then(|fields| fields.get(field.node.name.node.as_str())) + { + if field_def.oneof && field.node.arguments.len() != 1 { + ctx.errors.push(RuleError::new( + vec![field.pos], + "Oneof fields requires have exactly one argument".to_string(), + )); + } + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn factory() -> OneofFieldsHaveExactlyOneArgument { + OneofFieldsHaveExactlyOneArgument + } + + #[test] + fn oneof_field() { + expect_passes_rule!( + factory, + r#" + query Foo { + oneofField(a: 10) + } + "#, + ); + } + + #[test] + fn oneof_not_exactly_one_argument() { + expect_fails_rule!( + factory, + r#" + query Foo { + oneofField(a: 10, b: "abc") + } + "#, + ); + + expect_fails_rule!( + factory, + r#" + query Foo { + oneofField + } + "#, + ); + } +} diff --git a/src/validation/suggestion.rs b/src/validation/suggestion.rs index a935d898..6e04566c 100644 --- a/src/validation/suggestion.rs +++ b/src/validation/suggestion.rs @@ -18,18 +18,20 @@ fn levenshtein_distance(s1: &str, s2: &str) -> usize { column[s1.len()] } -pub fn make_suggestion<'a, I>(prefix: &str, options: I, input: &str) -> Option +pub fn make_suggestion(prefix: &str, options: I, input: &str) -> Option where - I: Iterator, + I: IntoIterator, + A: AsRef, { let mut selected = Vec::new(); let mut distances = HashMap::new(); for opt in options { - let distance = levenshtein_distance(input, opt); + let opt = opt.as_ref().to_string(); + let distance = levenshtein_distance(input, &opt); let threshold = (input.len() / 2).max((opt.len() / 2).max(1)); if distance < threshold { - selected.push(opt); + selected.push(opt.clone()); distances.insert(opt, distance); } } diff --git a/src/validation/test_harness.rs b/src/validation/test_harness.rs index 5f97f505..fedb2d81 100644 --- a/src/validation/test_harness.rs +++ b/src/validation/test_harness.rs @@ -288,6 +288,13 @@ impl ComplicatedArgs { } } +#[derive(OneofObject)] +#[graphql(internal)] +enum OneofArg { + A(i32), + B(String), +} + pub struct Query; #[Object(internal)] @@ -335,6 +342,19 @@ impl Query { async fn complicated_args(&self) -> Option { unimplemented!() } + + async fn oneof_arg(&self, arg: OneofArg) -> String { + unimplemented!() + } + + async fn oneof_opt(&self, arg: Option) -> String { + unimplemented!() + } + + #[graphql(oneof)] + async fn oneof_field(&self, arg: OneofArg) -> String { + unimplemented!() + } } pub struct Mutation; diff --git a/src/validation/utils.rs b/src/validation/utils.rs index 5f1051c9..0c1214a4 100644 --- a/src/validation/utils.rs +++ b/src/validation/utils.rs @@ -120,9 +120,17 @@ pub fn is_valid_input_value( registry::MetaType::InputObject { input_fields, name: object_name, + oneof, .. } => match value { ConstValue::Object(values) => { + if *oneof && values.len() != 1 { + return Some(valid_error( + &path_node, + "Oneof input objects requires have exactly one field".to_string(), + )); + } + let mut input_names = values.keys().map(AsRef::as_ref).collect::>(); diff --git a/tests/complex_object.rs b/tests/complex_object.rs index 556599c2..f9e3d90f 100644 --- a/tests/complex_object.rs +++ b/tests/complex_object.rs @@ -492,3 +492,87 @@ async fn test_flatten_with_result() { }) ); } + +#[tokio::test] +async fn test_oneof_field() { + #[derive(OneofObject)] + enum TestArg { + A(i32), + B(String), + } + + #[derive(SimpleObject)] + #[graphql(complex)] + struct Query { + a: i32, + } + + #[ComplexObject] + impl Query { + #[graphql(oneof)] + async fn test(&self, arg: TestArg) -> String { + match arg { + TestArg::A(a) => format!("a:{}", a), + TestArg::B(b) => format!("b:{}", b), + } + } + } + + let schema = Schema::new(Query { a: 10 }, EmptyMutation, EmptySubscription); + let query = "{ test(a: 10) }"; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "test": "a:10" + }) + ); + + let query = r#"{ test(b: "abc") }"#; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "test": "b:abc" + }) + ); + + let query = r#"{ + __type(name: "Query") { + fields { + name + args { + name + type { + kind + name + } + } + } + } + }"#; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "__type": { + "fields": [{ + "name": "a", + "args": [] + }, { + "name": "test", + "args": [{ + "name": "a", + "type": { + "kind": "SCALAR", + "name": "Int" + } + }, { + "name": "b", + "type": { + "kind": "SCALAR", + "name": "String" + } + }] + }] + } + }) + ); +} diff --git a/tests/interface.rs b/tests/interface.rs index 6d43a71b..b57b1118 100644 --- a/tests/interface.rs +++ b/tests/interface.rs @@ -473,3 +473,118 @@ pub async fn test_issue_330() { }) ); } + +#[tokio::test] +pub async fn test_oneof() { + #[derive(OneofObject)] + enum TestArg { + A(i32), + B(String), + } + + struct A; + + #[Object] + impl A { + async fn test(&self, arg: TestArg) -> String { + match arg { + TestArg::A(a) => format!("A:a:{}", a), + TestArg::B(b) => format!("A:b:{}", b), + } + } + } + + struct B; + + #[Object] + impl B { + async fn test(&self, arg: TestArg) -> String { + match arg { + TestArg::A(a) => format!("B:a:{}", a), + TestArg::B(b) => format!("B:b:{}", b), + } + } + } + + #[derive(Interface)] + #[graphql(field( + name = "test", + type = "String", + oneof, + arg(name = "arg", type = "TestArg") + ))] + enum Obj { + A(A), + B(B), + } + + struct Query; + + #[Object] + impl Query { + async fn a(&self) -> Obj { + A.into() + } + + async fn b(&self) -> Obj { + B.into() + } + } + + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema + .execute("{ a { test(a: 10) } }") + .await + .into_result() + .unwrap() + .data, + value!({ + "a": { + "test": "A:a:10" + } + }) + ); + + assert_eq!( + schema + .execute("{ a { test(b: \"abc\") } }") + .await + .into_result() + .unwrap() + .data, + value!({ + "a": { + "test": "A:b:abc" + } + }) + ); + + assert_eq!( + schema + .execute("{ b { test(a: 10) } }") + .await + .into_result() + .unwrap() + .data, + value!({ + "b": { + "test": "B:a:10" + } + }) + ); + + assert_eq!( + schema + .execute("{ b { test(b: \"def\") } }") + .await + .into_result() + .unwrap() + .data, + value!({ + "b": { + "test": "B:b:def" + } + }) + ); +} diff --git a/tests/object.rs b/tests/object.rs index e11d740e..71377b15 100644 --- a/tests/object.rs +++ b/tests/object.rs @@ -117,3 +117,80 @@ async fn test_flatten_with_context() { }) ); } + +#[tokio::test] +async fn test_oneof_field() { + #[derive(OneofObject)] + enum TestArg { + A(i32), + B(String), + } + + struct Query; + + #[Object] + impl Query { + #[graphql(oneof)] + async fn test(&self, arg: TestArg) -> String { + match arg { + TestArg::A(a) => format!("a:{}", a), + TestArg::B(b) => format!("b:{}", b), + } + } + } + + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + let query = "{ test(a: 10) }"; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "test": "a:10" + }) + ); + + let query = r#"{ test(b: "abc") }"#; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "test": "b:abc" + }) + ); + + let query = r#"{ + __type(name: "Query") { + fields { + name + args { + name + type { + kind + name + } + } + } + } + }"#; + assert_eq!( + schema.execute(query).await.into_result().unwrap().data, + value!({ + "__type": { + "fields": [{ + "name": "test", + "args": [{ + "name": "a", + "type": { + "kind": "SCALAR", + "name": "Int" + } + }, { + "name": "b", + "type": { + "kind": "SCALAR", + "name": "String" + } + }] + }] + } + }) + ); +} diff --git a/tests/subscription.rs b/tests/subscription.rs index 4fa494b0..9efd3956 100644 --- a/tests/subscription.rs +++ b/tests/subscription.rs @@ -425,3 +425,46 @@ pub async fn test_subscription_fieldresult() { assert!(stream.next().await.is_none()); } + +#[tokio::test] +pub async fn test_oneof_field() { + #[derive(OneofObject)] + enum TestArg { + A(i32), + B(String), + } + + struct Subscription; + + #[Subscription] + impl Subscription { + #[graphql(oneof)] + async fn test(&self, arg: TestArg) -> impl Stream { + let value = match arg { + TestArg::A(a) => format!("a:{}", a), + TestArg::B(b) => format!("b:{}", b), + }; + futures_util::stream::once(async move { value }) + } + } + + let schema = Schema::new(Query, EmptyMutation, Subscription); + + let mut stream = schema.execute_stream(r#"subscription { test(a: 10) }"#); + assert_eq!( + Response::new(value!({ + "test": "a:10" + })), + stream.next().await.unwrap() + ); + assert!(stream.next().await.is_none()); + + let mut stream = schema.execute_stream(r#"subscription { test(b: "abc") }"#); + assert_eq!( + Response::new(value!({ + "test": "b:abc" + })), + stream.next().await.unwrap() + ); + assert!(stream.next().await.is_none()); +} diff --git a/tests/validators.rs b/tests/validators.rs index 655e4aed..9bc2317a 100644 --- a/tests/validators.rs +++ b/tests/validators.rs @@ -789,3 +789,103 @@ pub async fn test_list_both_max_items_and_max_length() { }) ); } + +#[tokio::test] +pub async fn test_custom_validator_on_oneof_field() { + struct MyValidator; + + impl CustomValidator for MyValidator { + fn check(&self, value: &TestArg) -> Result<(), String> { + match value { + TestArg::A(a) => { + if *a < 100 { + Ok(()) + } else { + Err("invalid a".into()) + } + } + TestArg::B(b) => { + if b.len() < 5 { + Ok(()) + } else { + Err("invalid b".into()) + } + } + } + } + } + + #[derive(OneofObject)] + enum TestArg { + A(i32), + B(String), + } + + struct Query; + + #[Object] + impl Query { + #[graphql(oneof)] + async fn test(&self, #[graphql(validator(custom = "MyValidator"))] arg: TestArg) -> String { + match arg { + TestArg::A(a) => format!("a:{}", a), + TestArg::B(b) => format!("b:{}", b), + } + } + } + + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + assert_eq!( + schema + .execute("{ test(a: 10) }") + .await + .into_result() + .unwrap() + .data, + value!({ + "test": "a:10" + }) + ); + + assert_eq!( + schema + .execute("{ test(a: 200) }") + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: r#"Failed to parse "TestArg": invalid a"#.to_string(), + source: None, + locations: vec![Pos { column: 3, line: 1 }], + path: vec![PathSegment::Field("test".to_string())], + extensions: None + }] + ); + + assert_eq!( + schema + .execute(r#"{ test(b: "abcd") }"#) + .await + .into_result() + .unwrap() + .data, + value!({ + "test": "b:abcd" + }) + ); + + assert_eq!( + schema + .execute(r#"{ test(b: "abcde") }"#) + .await + .into_result() + .unwrap_err(), + vec![ServerError { + message: r#"Failed to parse "TestArg": invalid b"#.to_string(), + source: None, + locations: vec![Pos { column: 3, line: 1 }], + path: vec![PathSegment::Field("test".to_string())], + extensions: None + }] + ); +}