diff --git a/async-graphql-derive/src/interface.rs b/async-graphql-derive/src/interface.rs index fe38045f..82c5d6a2 100644 --- a/async-graphql-derive/src/interface.rs +++ b/async-graphql-derive/src/interface.rs @@ -38,6 +38,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result let mut registry_types = Vec::new(); let mut possible_types = Vec::new(); let mut collect_inline_fields = Vec::new(); + let mut get_introspection_typename = Vec::new(); for field in &fields.unnamed { if let Type::Path(p) = &field.ty { @@ -66,6 +67,9 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result unreachable!() } }); + get_introspection_typename.push(quote! { + #ident::#enum_name(obj) => <#p as #crate_name::Type>::type_name() + }) } else { return Err(Error::new_spanned(field, "Invalid type")); } @@ -225,6 +229,12 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result std::borrow::Cow::Borrowed(#gql_typename) } + fn introspection_type_name(&self) -> std::borrow::Cow<'static, str> { + match self { + #(#get_introspection_typename),* + } + } + fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { registry.create_type::(|registry| { #(#registry_types)* diff --git a/async-graphql-derive/src/union.rs b/async-graphql-derive/src/union.rs index c986f9cb..61e9a4c4 100644 --- a/async-graphql-derive/src/union.rs +++ b/async-graphql-derive/src/union.rs @@ -35,6 +35,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result let mut registry_types = Vec::new(); let mut possible_types = Vec::new(); let mut collect_inline_fields = Vec::new(); + let mut get_introspection_typename = Vec::new(); for field in &fields.unnamed { if let Type::Path(p) = &field.ty { @@ -62,6 +63,9 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result unreachable!() } }); + get_introspection_typename.push(quote! { + #ident::#enum_name(obj) => <#p as #crate_name::Type>::type_name() + }) } else { return Err(Error::new_spanned(field, "Invalid type")); } @@ -78,6 +82,12 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result std::borrow::Cow::Borrowed(#gql_typename) } + fn introspection_type_name(&self) -> std::borrow::Cow<'static, str> { + match self { + #(#get_introspection_typename),* + } + } + fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String { registry.create_type::(|registry| { #(#registry_types)* diff --git a/src/base.rs b/src/base.rs index 196cccaf..b15ffa84 100644 --- a/src/base.rs +++ b/src/base.rs @@ -18,6 +18,13 @@ pub trait Type { format!("{}!", Self::type_name()) } + /// Introspection type name + /// + /// Is the return value of field `__type`, the interface and union should return the current type, and the others return `Type::type_name`. + fn introspection_type_name(&self) -> Cow<'static, str> { + Self::type_name() + } + /// Create type information in the registry and return qualified typename. fn create_type_info(registry: &mut registry::Registry) -> String; diff --git a/src/registry.rs b/src/registry.rs index 55dfe0ea..8a4807a9 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -282,20 +282,7 @@ impl Registry { cache_control: Default::default(), }, ); - let mut ty = f(self); - if let Type::Object { fields, .. } = &mut ty { - fields.insert( - "__typename".to_string(), - Field { - name: "__typename".to_string(), - description: None, - args: Default::default(), - ty: "String!".to_string(), - deprecation: None, - cache_control: Default::default(), - }, - ); - } + let ty = f(self); self.types.insert(name.to_string(), ty); } T::qualified_type_name() diff --git a/src/resolver.rs b/src/resolver.rs index 31047267..84feadf0 100644 --- a/src/resolver.rs +++ b/src/resolver.rs @@ -45,8 +45,10 @@ pub fn collect_fields<'a, T: ObjectType + Send + Sync>( let ctx_field = ctx.with_field(field); let field_name = ctx_field.result_name().to_string(); futures.push(Box::pin( - future::ok::(T::type_name().to_string().into()) - .map_ok(move |value| (field_name, value)), + future::ok::( + root.introspection_type_name().to_string().into(), + ) + .map_ok(move |value| (field_name, value)), )); continue; } diff --git a/src/validation/rules/fields_on_correct_type.rs b/src/validation/rules/fields_on_correct_type.rs index 09a4b3e6..8087df6d 100644 --- a/src/validation/rules/fields_on_correct_type.rs +++ b/src/validation/rules/fields_on_correct_type.rs @@ -13,7 +13,9 @@ impl<'a> Visitor<'a> for FieldsOnCorrectType { .field_by_name(&field.name) .is_none() { - if let Some(registry::Type::Union { .. }) = ctx.parent_type() { + if let Some(registry::Type::Union { .. }) | Some(registry::Type::Interface { .. }) = + ctx.parent_type() + { if field.name == "__typename" { return; } diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs index 20ec26ad..3d319350 100644 --- a/src/validation/visitor.rs +++ b/src/validation/visitor.rs @@ -480,24 +480,26 @@ fn visit_selection<'a, V: Visitor<'a>>( v.enter_selection(ctx, selection); match selection { Selection::Field(field) => { - if let Some(schema_field) = ctx.current_type().field_by_name(&field.name) { - ctx.with_type( - ctx.registry - .basic_type_by_typename(&schema_field.ty) - .unwrap(), - |ctx| { - visit_field(v, ctx, field); - }, - ); - } else { - ctx.report_error( - vec![field.position], - format!( - "Cannot query field \"{}\" on type \"{}\".", - field.name, - ctx.current_type().name() - ), - ); + if field.name != "__typename" { + if let Some(schema_field) = ctx.current_type().field_by_name(&field.name) { + ctx.with_type( + ctx.registry + .basic_type_by_typename(&schema_field.ty) + .unwrap(), + |ctx| { + visit_field(v, ctx, field); + }, + ); + } else { + ctx.report_error( + vec![field.position], + format!( + "Cannot query field \"{}\" on type \"{}\".", + field.name, + ctx.current_type().name() + ), + ); + } } } Selection::FragmentSpread(fragment_spread) => {