diff --git a/async-graphql-derive/src/args.rs b/async-graphql-derive/src/args.rs index 3cac51a1..a80f5ac8 100644 --- a/async-graphql-derive/src/args.rs +++ b/async-graphql-derive/src/args.rs @@ -848,22 +848,53 @@ impl DataSource { pub struct Scalar { pub internal: bool, + pub name: Option, + pub desc: Option, } impl Scalar { pub fn parse(args: AttributeArgs) -> Result { let mut internal = false; + let mut name = None; + let mut desc = None; for arg in args { match arg { - NestedMeta::Meta(Meta::Path(p)) if p.is_ident("internal") => { - internal = true; + NestedMeta::Meta(Meta::Path(p)) => { + if p.is_ident("internal") { + internal = true; + } + } + NestedMeta::Meta(Meta::NameValue(nv)) => { + if nv.path.is_ident("name") { + if let syn::Lit::Str(lit) = nv.lit { + name = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'name' should be a string.", + )); + } + } else if nv.path.is_ident("desc") { + if let syn::Lit::Str(lit) = nv.lit { + desc = Some(lit.value()); + } else { + return Err(Error::new_spanned( + &nv.lit, + "Attribute 'desc' should be a string.", + )); + } + } } _ => {} } } - Ok(Self { internal }) + Ok(Self { + internal, + name, + desc, + }) } } diff --git a/async-graphql-derive/src/lib.rs b/async-graphql-derive/src/lib.rs index 98ef4d69..0f171478 100644 --- a/async-graphql-derive/src/lib.rs +++ b/async-graphql-derive/src/lib.rs @@ -8,6 +8,7 @@ mod input_object; mod interface; mod object; mod output_type; +mod scalar; mod simple_object; mod subscription; mod union; @@ -141,46 +142,9 @@ pub fn Scalar(args: TokenStream, input: TokenStream) -> TokenStream { Ok(scalar_args) => scalar_args, Err(err) => return err.to_compile_error().into(), }; - let input2: proc_macro2::TokenStream = input.clone().into(); - let item_impl = parse_macro_input!(input as ItemImpl); - let self_ty = &item_impl.self_ty; - let generic = &item_impl.generics; - let where_clause = &item_impl.generics.where_clause; - let crate_name = get_crate_name(scalar_args.internal); - let expanded = quote! { - #input2 - - impl #generic #crate_name::Type for #self_ty #where_clause { - fn type_name() -> std::borrow::Cow<'static, str> { - std::borrow::Cow::Borrowed(<#self_ty as #crate_name::ScalarType>::type_name()) - } - - fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { - registry.create_type::<#self_ty, _>(|_| #crate_name::registry::MetaType::Scalar { - name: <#self_ty as #crate_name::ScalarType>::type_name().to_string(), - description: <#self_ty>::description(), - is_valid: |value| <#self_ty as #crate_name::ScalarType>::is_valid(value), - }) - } - } - - impl #generic #crate_name::InputValueType for #self_ty #where_clause { - fn parse(value: #crate_name::Value) -> #crate_name::InputValueResult { - <#self_ty as #crate_name::ScalarType>::parse(value) - } - } - - #[allow(clippy::ptr_arg)] - #[#crate_name::async_trait::async_trait] - impl #generic #crate_name::OutputValueType for #self_ty #where_clause { - async fn resolve( - &self, - _: &#crate_name::ContextSelectionSet<'_>, - _pos: #crate_name::Pos, - ) -> #crate_name::Result<#crate_name::serde_json::Value> { - self.to_json() - } - } - }; - expanded.into() + let mut item_impl = parse_macro_input!(input as ItemImpl); + match scalar::generate(&scalar_args, &mut item_impl) { + Ok(expanded) => expanded, + Err(err) => err.to_compile_error().into(), + } } diff --git a/async-graphql-derive/src/scalar.rs b/async-graphql-derive/src/scalar.rs new file mode 100644 index 00000000..6856a083 --- /dev/null +++ b/async-graphql-derive/src/scalar.rs @@ -0,0 +1,68 @@ +use crate::args; +use crate::utils::{check_reserved_name, get_crate_name, get_rustdoc}; +use proc_macro::TokenStream; +use quote::quote; +use syn::{Error, ItemImpl, Result, Type}; + +pub fn generate(scalar_args: &args::Scalar, item_impl: &mut ItemImpl) -> Result { + let self_name = match item_impl.self_ty.as_ref() { + Type::Path(path) => path + .path + .segments + .last() + .map(|s| s.ident.to_string()) + .unwrap(), + _ => return Err(Error::new_spanned(&item_impl.self_ty, "Invalid type")), + }; + let gql_typename = scalar_args + .name + .clone() + .unwrap_or_else(|| self_name.clone()); + check_reserved_name(&gql_typename, scalar_args.internal)?; + let desc = scalar_args + .desc + .clone() + .or_else(|| get_rustdoc(&item_impl.attrs).ok().flatten()) + .map(|s| quote! { Some(#s) }) + .unwrap_or_else(|| quote! {None}); + let self_ty = &item_impl.self_ty; + let generic = &item_impl.generics; + let where_clause = &item_impl.generics.where_clause; + let crate_name = get_crate_name(scalar_args.internal); + let expanded = quote! { + #item_impl + + impl #generic #crate_name::Type for #self_ty #where_clause { + fn type_name() -> std::borrow::Cow<'static, str> { + std::borrow::Cow::Borrowed(#gql_typename) + } + + fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { + registry.create_type::<#self_ty, _>(|_| #crate_name::registry::MetaType::Scalar { + name: #gql_typename.to_string(), + description: #desc, + is_valid: |value| <#self_ty as #crate_name::ScalarType>::is_valid(value), + }) + } + } + + impl #generic #crate_name::InputValueType for #self_ty #where_clause { + fn parse(value: #crate_name::Value) -> #crate_name::InputValueResult { + <#self_ty as #crate_name::ScalarType>::parse(value) + } + } + + #[allow(clippy::ptr_arg)] + #[#crate_name::async_trait::async_trait] + impl #generic #crate_name::OutputValueType for #self_ty #where_clause { + async fn resolve( + &self, + _: &#crate_name::ContextSelectionSet<'_>, + _pos: #crate_name::Pos, + ) -> #crate_name::Result<#crate_name::serde_json::Value> { + self.to_json() + } + } + }; + Ok(expanded.into()) +} diff --git a/docs/en/src/custom_scalars.md b/docs/en/src/custom_scalars.md index e6f2498d..c191bf66 100644 --- a/docs/en/src/custom_scalars.md +++ b/docs/en/src/custom_scalars.md @@ -14,11 +14,6 @@ struct StringNumber(i64); #[Scalar] impl ScalarType for StringNumber { - fn type_name() -> &'static str { - // Name of type - "StringNumber" - } - fn parse(value: Value) -> InputValueResult { if let Value::String(value) = value { // Parse the integer value diff --git a/docs/zh-CN/src/custom_scalars.md b/docs/zh-CN/src/custom_scalars.md index 1605ee5d..546f3007 100644 --- a/docs/zh-CN/src/custom_scalars.md +++ b/docs/zh-CN/src/custom_scalars.md @@ -14,11 +14,6 @@ struct StringNumber(i64); #[Scalar] impl ScalarType for StringNumber { - fn type_name() -> &'static str { - // 类型名 - "StringNumber" - } - fn parse(value: Value) -> InputValueResult { if let Value::String(value) = value { // 解析整数 diff --git a/src/base.rs b/src/base.rs index 82ac02eb..d416b4a0 100644 --- a/src/base.rs +++ b/src/base.rs @@ -124,10 +124,6 @@ pub trait InputObjectType: InputValueType {} /// /// #[Scalar] /// impl ScalarType for MyInt { -/// fn type_name() -> &'static str { -/// "MyInt" -/// } -/// /// fn parse(value: Value) -> InputValueResult { /// if let Value::Int(n) = value { /// Ok(MyInt(n as i32)) @@ -142,14 +138,6 @@ pub trait InputObjectType: InputValueType {} /// } /// ``` pub trait ScalarType: Sized + Send { - /// The type name of a scalar. - fn type_name() -> &'static str; - - /// The description of a scalar. - fn description() -> Option<&'static str> { - None - } - /// Parse a scalar value, return `Some(Self)` if successful, otherwise return `None`. fn parse(value: Value) -> InputValueResult; diff --git a/src/lib.rs b/src/lib.rs index 9b450bbd..52e3032c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -656,4 +656,12 @@ pub use async_graphql_derive::Subscription; pub use async_graphql_derive::DataSource; /// Define a Scalar +/// +/// # Macro parameters +/// +/// | Attribute | description | Type | Optional | +/// |-------------|---------------------------|----------|----------| +/// | name | Scalar name | string | Y | +/// | desc | Scalar description | string | Y | +/// pub use async_graphql_derive::Scalar; diff --git a/src/scalars/any.rs b/src/scalars/any.rs index 15d03dc2..99679428 100644 --- a/src/scalars/any.rs +++ b/src/scalars/any.rs @@ -8,16 +8,9 @@ use serde::de::DeserializeOwned; #[derive(Clone, PartialEq, Debug)] pub struct Any(pub Value); -#[Scalar(internal)] +/// The `_Any` scalar is used to pass representations of entities from external services into the root `_entities` field for execution. +#[Scalar(internal, name = "_Any")] impl ScalarType for Any { - fn type_name() -> &'static str { - "_Any" - } - - fn description() -> Option<&'static str> { - Some("The `_Any` scalar is used to pass representations of entities from external services into the root `_entities` field for execution.") - } - fn parse(value: Value) -> InputValueResult { Ok(Self(value)) } diff --git a/src/scalars/bool.rs b/src/scalars/bool.rs index 0bb94e89..30aa84d9 100644 --- a/src/scalars/bool.rs +++ b/src/scalars/bool.rs @@ -1,16 +1,9 @@ use crate::{InputValueError, InputValueResult, Result, ScalarType, Value}; use async_graphql_derive::Scalar; -#[Scalar(internal)] +/// The `Boolean` scalar type represents `true` or `false`. +#[Scalar(internal, name = "Boolean")] impl ScalarType for bool { - fn type_name() -> &'static str { - "Boolean" - } - - fn description() -> Option<&'static str> { - Some("The `Boolean` scalar type represents `true` or `false`.") - } - fn parse(value: Value) -> InputValueResult { match value { Value::Boolean(n) => Ok(n), diff --git a/src/scalars/bson.rs b/src/scalars/bson.rs index 27014abc..bf77c23b 100644 --- a/src/scalars/bson.rs +++ b/src/scalars/bson.rs @@ -5,10 +5,6 @@ use chrono::{DateTime, Utc}; #[Scalar(internal)] impl ScalarType for ObjectId { - fn type_name() -> &'static str { - "ObjectId" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(ObjectId::with_string(&s)?), @@ -21,12 +17,8 @@ impl ScalarType for ObjectId { } } -#[Scalar(internal)] +#[Scalar(internal, name = "DateTime")] impl ScalarType for UtcDateTime { - fn type_name() -> &'static str { - "DateTime" - } - fn parse(value: Value) -> InputValueResult { DateTime::::parse(value).map(UtcDateTime::from) } diff --git a/src/scalars/chrono_tz.rs b/src/scalars/chrono_tz.rs index 7790040f..eb2d53f3 100644 --- a/src/scalars/chrono_tz.rs +++ b/src/scalars/chrono_tz.rs @@ -3,12 +3,8 @@ use async_graphql_derive::Scalar; use chrono_tz::Tz; use std::str::FromStr; -#[Scalar(internal)] +#[Scalar(internal, name = "TimeZone")] impl ScalarType for Tz { - fn type_name() -> &'static str { - "TimeZone" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Tz::from_str(&s)?), diff --git a/src/scalars/datetime.rs b/src/scalars/datetime.rs index d82463cd..c5c8df26 100644 --- a/src/scalars/datetime.rs +++ b/src/scalars/datetime.rs @@ -5,12 +5,8 @@ use chrono::{DateTime, TimeZone, Utc}; /// Implement the DateTime scalar /// /// The input/output is a string in RFC3339 format. -#[Scalar(internal)] +#[Scalar(internal, name = "DateTime")] impl ScalarType for DateTime { - fn type_name() -> &'static str { - "DateTime" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Utc.datetime_from_str(&s, "%+")?), diff --git a/src/scalars/floats.rs b/src/scalars/floats.rs index 15531a90..3b9c3b9c 100644 --- a/src/scalars/floats.rs +++ b/src/scalars/floats.rs @@ -4,16 +4,9 @@ use async_graphql_derive::Scalar; macro_rules! impl_float_scalars { ($($ty:ty),*) => { $( - #[Scalar(internal)] + /// The `Float` scalar type represents signed double-precision fractional values as specified by [IEEE 754](https://en.wikipedia.org/wiki/IEEE_floating_point). + #[Scalar(internal, name = "Float")] impl ScalarType for $ty { - fn type_name() -> &'static str { - "Float" - } - - fn description() -> Option<&'static str> { - Some("The `Float` scalar type represents signed double-precision fractional values as specified by [IEEE 754](https://en.wikipedia.org/wiki/IEEE_floating_point).") - } - fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(n as Self), diff --git a/src/scalars/id.rs b/src/scalars/id.rs index 14ce2229..3a7e4ca6 100644 --- a/src/scalars/id.rs +++ b/src/scalars/id.rs @@ -75,10 +75,6 @@ impl PartialEq<&str> for ID { #[Scalar(internal)] impl ScalarType for ID { - fn type_name() -> &'static str { - "ID" - } - fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(ID(n.to_string())), diff --git a/src/scalars/integers.rs b/src/scalars/integers.rs index 40b5b90d..5a8f8c85 100644 --- a/src/scalars/integers.rs +++ b/src/scalars/integers.rs @@ -4,16 +4,9 @@ use async_graphql_derive::Scalar; macro_rules! impl_integer_scalars { ($($ty:ty),*) => { $( - #[Scalar(internal)] + /// The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1. + #[Scalar(internal, name = "Int")] impl ScalarType for $ty { - fn type_name() -> &'static str { - "Int" - } - - fn description() -> Option<&'static str> { - Some("The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.") - } - fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(n as Self), @@ -41,16 +34,9 @@ impl_integer_scalars!(i8, i16, i32, u8, u16, u32); macro_rules! impl_int64_scalars { ($($ty:ty),*) => { $( - #[Scalar(internal)] + /// The `Int64` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^64) and 2^64 - 1. + #[Scalar(internal, name = "Int64")] impl ScalarType for $ty { - fn type_name() -> &'static str { - "Int64" - } - - fn description() -> Option<&'static str> { - Some("The `Int64` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^64) and 2^64 - 1.") - } - fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(n as Self), diff --git a/src/scalars/json.rs b/src/scalars/json.rs index d8ad77a6..2a0628ab 100644 --- a/src/scalars/json.rs +++ b/src/scalars/json.rs @@ -22,12 +22,9 @@ impl DerefMut for Json { } } -#[Scalar(internal)] +/// A scalar that can represent any JSON value. +#[Scalar(internal, name = "JSON")] impl ScalarType for Json { - fn type_name() -> &'static str { - "JSON" - } - fn parse(value: Value) -> InputValueResult { Ok(serde_json::from_value(value.into()).map(Json)?) } diff --git a/src/scalars/naive_date.rs b/src/scalars/naive_date.rs index 563a6b76..d7d25a24 100644 --- a/src/scalars/naive_date.rs +++ b/src/scalars/naive_date.rs @@ -2,13 +2,8 @@ use crate::{InputValueError, InputValueResult, Result, ScalarType, Value}; use async_graphql_derive::Scalar; use chrono::NaiveDate; -/// Implement the NaiveDate scalar #[Scalar(internal)] impl ScalarType for NaiveDate { - fn type_name() -> &'static str { - "NaiveDate" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(NaiveDate::parse_from_str(&s, "%Y-%m-%d")?), diff --git a/src/scalars/naive_time.rs b/src/scalars/naive_time.rs index 2a7c469e..4cc8de38 100644 --- a/src/scalars/naive_time.rs +++ b/src/scalars/naive_time.rs @@ -2,13 +2,8 @@ use crate::{InputValueError, InputValueResult, Result, ScalarType, Value}; use async_graphql_derive::Scalar; use chrono::NaiveTime; -/// Implement the NaiveTime scalar #[Scalar(internal)] impl ScalarType for NaiveTime { - fn type_name() -> &'static str { - "NaiveTime" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(NaiveTime::parse_from_str(&s, "%H:%M:%S")?), diff --git a/src/scalars/string.rs b/src/scalars/string.rs index 37cf8c94..60bc5df0 100644 --- a/src/scalars/string.rs +++ b/src/scalars/string.rs @@ -6,18 +6,9 @@ use crate::{ use async_graphql_derive::Scalar; use std::borrow::Cow; -const STRING_DESC: &str = "The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text."; - +/// The `String` scalar type represents textual data, represented as UTF-8 character sequences. The String type is most often used by GraphQL to represent free-form human-readable text. #[Scalar(internal)] impl ScalarType for String { - fn type_name() -> &'static str { - "String" - } - - fn description() -> Option<&'static str> { - Some(STRING_DESC) - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(s), @@ -43,14 +34,7 @@ impl<'a> Type for &'a str { } fn create_type_info(registry: &mut registry::Registry) -> String { - registry.create_type::(|_| registry::MetaType::Scalar { - name: Self::type_name().to_string(), - description: Some(STRING_DESC), - is_valid: |value| match value { - Value::String(_) => true, - _ => false, - }, - }) + ::create_type_info(registry) } } diff --git a/src/scalars/url.rs b/src/scalars/url.rs index 5c5d3c53..ef9e48a3 100644 --- a/src/scalars/url.rs +++ b/src/scalars/url.rs @@ -4,10 +4,6 @@ use url::Url; #[Scalar(internal)] impl ScalarType for Url { - fn type_name() -> &'static str { - "Url" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Url::parse(&s)?), diff --git a/src/scalars/uuid.rs b/src/scalars/uuid.rs index 738738e5..f50c8cb9 100644 --- a/src/scalars/uuid.rs +++ b/src/scalars/uuid.rs @@ -2,12 +2,8 @@ use crate::{InputValueError, InputValueResult, Result, ScalarType, Value}; use async_graphql_derive::Scalar; use uuid::Uuid; -#[Scalar(internal)] +#[Scalar(internal, name = "UUID")] impl ScalarType for Uuid { - fn type_name() -> &'static str { - "UUID" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Uuid::parse_str(&s)?), diff --git a/src/types/connection/cursor.rs b/src/types/connection/cursor.rs index abd4351f..06b05cce 100644 --- a/src/types/connection/cursor.rs +++ b/src/types/connection/cursor.rs @@ -40,10 +40,6 @@ impl From for Cursor { #[Scalar(internal)] impl ScalarType for Cursor { - fn type_name() -> &'static str { - "Cursor" - } - fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Cursor(s)), diff --git a/tests/introspection.rs b/tests/introspection.rs index 2bfe511f..009aa591 100644 --- a/tests/introspection.rs +++ b/tests/introspection.rs @@ -75,10 +75,6 @@ struct TestScalar(i32); #[Scalar(desc = "Test scalar")] impl ScalarType for TestScalar { - fn type_name() -> &'static str { - "TestScalar" - } - fn parse(_value: Value) -> InputValueResult { Ok(TestScalar(42)) } @@ -689,7 +685,7 @@ pub async fn test_introspection_scalar() { "__type": { "kind": "SCALAR", "name": "TestScalar", - "description": null, + "description": "Test scalar", } });