diff --git a/derive/src/subscription.rs b/derive/src/subscription.rs index 40d10dd2..b5d93964 100644 --- a/derive/src/subscription.rs +++ b/derive/src/subscription.rs @@ -20,6 +20,8 @@ pub fn generate( ) -> GeneratorResult { let crate_name = get_crate_name(subscription_args.internal); let (self_ty, self_name) = get_type_path_and_name(item_impl.self_ty.as_ref())?; + let generics = &item_impl.generics; + let where_clause = &item_impl.generics.where_clause; let extends = subscription_args.extends; let gql_typename = subscription_args @@ -426,7 +428,7 @@ pub fn generate( #item_impl #[allow(clippy::all, clippy::pedantic)] - impl #crate_name::Type for #self_ty { + impl #generics #crate_name::Type for #self_ty #where_clause { fn type_name() -> ::std::borrow::Cow<'static, ::std::primitive::str> { ::std::borrow::Cow::Borrowed(#gql_typename) } @@ -451,7 +453,7 @@ pub fn generate( #[allow(clippy::all, clippy::pedantic)] #[allow(unused_braces, unused_variables)] - impl #crate_name::SubscriptionType for #self_ty { + impl #generics #crate_name::SubscriptionType for #self_ty #where_clause { fn create_field_stream<'__life>( &'__life self, ctx: &'__life #crate_name::Context<'_>, diff --git a/tests/generic_types.rs b/tests/generic_types.rs index 7ec36e82..bcf7f17f 100644 --- a/tests/generic_types.rs +++ b/tests/generic_types.rs @@ -1,4 +1,5 @@ use async_graphql::*; +use futures_util::stream::{Stream, StreamExt}; #[async_std::test] pub async fn test_generic_object() { @@ -122,7 +123,7 @@ pub async fn test_input_object_generic() { "__type": { "fields": [ { - "name": "q1", + "name": "q1", "args": [{ "name": "input", "type": { @@ -132,7 +133,7 @@ pub async fn test_input_object_generic() { }] }, { - "name": "q2", + "name": "q2", "args": [{ "name": "input", "type": { @@ -249,3 +250,45 @@ pub async fn test_generic_simple_object() { }) ); } + +#[async_std::test] +pub async fn test_generic_subscription() { + struct MySubscription { + values: Vec, + } + + #[Subscription] + impl MySubscription + where + T: Clone + Send + Sync + Unpin, + { + async fn values(&self) -> Result + '_> { + Ok(async_stream::stream! { + for value in self.values.iter().cloned() { + yield value + } + }) + } + } + + struct Query; + + #[Object] + impl Query { + async fn dummy(&self) -> bool { + false + } + } + + let schema = Schema::new(Query, EmptyMutation, MySubscription { values: vec![1, 2] }); + { + let mut stream = schema + .execute_stream("subscription { values }") + .map(|resp| resp.into_result().unwrap().data) + .boxed(); + for i in 1..=2 { + assert_eq!(value!({ "values": i }), stream.next().await.unwrap()); + } + assert!(stream.next().await.is_none()); + } +}