From efad20d4c475850f5502081db5c7c69b180c6225 Mon Sep 17 00:00:00 2001 From: sunli Date: Mon, 6 Apr 2020 18:30:38 +0800 Subject: [PATCH] Add subscription test --- async-graphql-derive/src/subscription.rs | 2 +- src/query.rs | 2 +- src/registry.rs | 16 ++--- src/schema.rs | 5 +- src/subscription/connection.rs | 4 +- src/validation/rules/known_type_names.rs | 2 +- src/validation/rules/scalar_leafs.rs | 2 +- src/validation/rules/upload_file.rs | 4 +- .../rules/variables_are_input_types.rs | 2 +- src/validation/visitor.rs | 4 +- tests/subscription.rs | 70 +++++++++++++++++++ 11 files changed, 91 insertions(+), 22 deletions(-) create mode 100644 tests/subscription.rs diff --git a/async-graphql-derive/src/subscription.rs b/async-graphql-derive/src/subscription.rs index 544b9d31..986cd0bb 100644 --- a/async-graphql-derive/src/subscription.rs +++ b/async-graphql-derive/src/subscription.rs @@ -187,7 +187,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< let schema = schema.clone(); let pos = ctx.position; let environment = environment.clone(); - let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params)*).fuse(), move |msg| { + let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params),*).fuse(), move |msg| { let environment = environment.clone(); let field_selection_set = field_selection_set.clone(); let schema = schema.clone(); diff --git a/src/query.rs b/src/query.rs index f29ba2f4..c3c693fb 100644 --- a/src/query.rs +++ b/src/query.rs @@ -104,7 +104,7 @@ impl QueryBuilder .schema .0 .registry - .basic_type_by_parsed_type(&d.var_type) + .concrete_type_by_parsed_type(&d.var_type) { if ty.name() == "Upload" { return true; diff --git a/src/registry.rs b/src/registry.rs index 6c8de736..b709665c 100644 --- a/src/registry.rs +++ b/src/registry.rs @@ -48,10 +48,10 @@ impl<'a> TypeName<'a> { } } - pub fn get_basic_typename(type_name: &str) -> &str { + pub fn concrete_typename(type_name: &str) -> &str { match TypeName::create(type_name) { - TypeName::List(type_name) => Self::get_basic_typename(type_name), - TypeName::NonNull(type_name) => Self::get_basic_typename(type_name), + TypeName::List(type_name) => Self::concrete_typename(type_name), + TypeName::NonNull(type_name) => Self::concrete_typename(type_name), TypeName::Named(type_name) => type_name, } } @@ -373,14 +373,14 @@ impl Registry { }); } - pub fn basic_type_by_typename(&self, type_name: &str) -> Option<&Type> { - self.types.get(TypeName::get_basic_typename(type_name)) + pub fn concrete_type_by_name(&self, type_name: &str) -> Option<&Type> { + self.types.get(TypeName::concrete_typename(type_name)) } - pub fn basic_type_by_parsed_type(&self, query_type: &ParsedType) -> Option<&Type> { + pub fn concrete_type_by_parsed_type(&self, query_type: &ParsedType) -> Option<&Type> { match query_type { - ParsedType::NonNullType(ty) => self.basic_type_by_parsed_type(ty), - ParsedType::ListType(ty) => self.basic_type_by_parsed_type(ty), + ParsedType::NonNullType(ty) => self.concrete_type_by_parsed_type(ty), + ParsedType::ListType(ty) => self.concrete_type_by_parsed_type(ty), ParsedType::NamedType(name) => self.types.get(name.as_str()), } } diff --git a/src/schema.rs b/src/schema.rs index ac715301..5c7b86f6 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -18,7 +18,6 @@ use graphql_parser::query::{Definition, OperationDefinition}; use itertools::Itertools; use std::any::Any; use std::collections::HashMap; -use std::pin::Pin; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -270,7 +269,7 @@ where source: &str, operation_name: Option<&str>, variables: Variables, - ) -> Result>>> { + ) -> Result> { let document = parse_query(source).map_err(Into::::into)?; check_rules(&self.0.registry, &document)?; @@ -317,7 +316,7 @@ where let mut streams = Vec::new(); create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)?; - Ok(Box::pin(futures::stream::select_all(streams))) + Ok(futures::stream::select_all(streams)) } /// Create subscription connection, returns `Sink` and `Stream`. diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs index e39ce3dd..d3b9568b 100644 --- a/src/subscription/connection.rs +++ b/src/subscription/connection.rs @@ -14,8 +14,8 @@ pub struct SubscriptionStreams { #[allow(missing_docs)] impl SubscriptionStreams { - pub fn add(&mut self, stream: Pin>>) -> usize { - self.streams.insert(stream) + pub fn add + 'static>(&mut self, stream: S) -> usize { + self.streams.insert(Box::pin(stream)) } pub fn remove(&mut self, id: usize) { diff --git a/src/validation/rules/known_type_names.rs b/src/validation/rules/known_type_names.rs index 1cb5d7cf..61c55aa9 100644 --- a/src/validation/rules/known_type_names.rs +++ b/src/validation/rules/known_type_names.rs @@ -25,7 +25,7 @@ impl<'a> Visitor<'a> for KnownTypeNames { ) { validate_type( ctx, - TypeName::get_basic_typename(&variable_definition.var_type.to_string()), + TypeName::concrete_typename(&variable_definition.var_type.to_string()), variable_definition.position, ); } diff --git a/src/validation/rules/scalar_leafs.rs b/src/validation/rules/scalar_leafs.rs index c2ef4a36..806df063 100644 --- a/src/validation/rules/scalar_leafs.rs +++ b/src/validation/rules/scalar_leafs.rs @@ -8,7 +8,7 @@ impl<'a> Visitor<'a> for ScalarLeafs { fn enter_field(&mut self, ctx: &mut VisitorContext<'a>, field: &'a Field) { if let Some(ty) = ctx.parent_type() { if let Some(schema_field) = ty.field_by_name(&field.name) { - if let Some(ty) = ctx.registry.basic_type_by_typename(&schema_field.ty) { + if let Some(ty) = ctx.registry.concrete_type_by_name(&schema_field.ty) { if ty.is_leaf() && !field.selection_set.items.is_empty() { ctx.report_error(vec![field.position], format!( "Field \"{}\" must not have a selection since type \"{}\" has no subfields", diff --git a/src/validation/rules/upload_file.rs b/src/validation/rules/upload_file.rs index c3197483..f07c34c0 100644 --- a/src/validation/rules/upload_file.rs +++ b/src/validation/rules/upload_file.rs @@ -12,7 +12,7 @@ impl<'a> Visitor<'a> for UploadFile { ) { if let OperationDefinition::Query(query) = operation_definition { for var in &query.variable_definitions { - if let Some(ty) = ctx.registry.basic_type_by_parsed_type(&var.var_type) { + if let Some(ty) = ctx.registry.concrete_type_by_parsed_type(&var.var_type) { if ty.name() == "Upload" { ctx.report_error( vec![var.position], @@ -23,7 +23,7 @@ impl<'a> Visitor<'a> for UploadFile { } } else if let OperationDefinition::Subscription(subscription) = operation_definition { for var in &subscription.variable_definitions { - if let Some(ty) = ctx.registry.basic_type_by_parsed_type(&var.var_type) { + if let Some(ty) = ctx.registry.concrete_type_by_parsed_type(&var.var_type) { if ty.name() == "Upload" { ctx.report_error( vec![var.position], diff --git a/src/validation/rules/variables_are_input_types.rs b/src/validation/rules/variables_are_input_types.rs index 53fcff2f..d5e7d65d 100644 --- a/src/validation/rules/variables_are_input_types.rs +++ b/src/validation/rules/variables_are_input_types.rs @@ -12,7 +12,7 @@ impl<'a> Visitor<'a> for VariablesAreInputTypes { ) { if let Some(ty) = ctx .registry - .basic_type_by_parsed_type(&variable_definition.var_type) + .concrete_type_by_parsed_type(&variable_definition.var_type) { if !ty.is_input() { ctx.report_error( diff --git a/src/validation/visitor.rs b/src/validation/visitor.rs index 1fd974e5..3d1367ca 100644 --- a/src/validation/visitor.rs +++ b/src/validation/visitor.rs @@ -513,7 +513,7 @@ fn visit_selection<'a, V: Visitor<'a>>( ctx.current_type() .and_then(|ty| ty.field_by_name(&field.name)) .and_then(|schema_field| { - ctx.registry.basic_type_by_typename(&schema_field.ty) + ctx.registry.concrete_type_by_name(&schema_field.ty) }), |ctx| { visit_field(v, ctx, field); @@ -583,7 +583,7 @@ fn visit_input_value<'a, V: Visitor<'a>>( if let Some(ty) = ctx .registry .types - .get(TypeName::get_basic_typename(expected_ty)) + .get(TypeName::concrete_typename(expected_ty)) { if let Type::InputObject { input_fields, .. } = ty { for (item_key, item_value) in values { diff --git a/tests/subscription.rs b/tests/subscription.rs new file mode 100644 index 00000000..a780908c --- /dev/null +++ b/tests/subscription.rs @@ -0,0 +1,70 @@ +use async_graphql::*; +use futures::{Stream, StreamExt}; + +#[async_std::test] +pub async fn test_subscription() { + struct QueryRoot; + + #[SimpleObject] + struct Event { + #[field] + a: i32, + + #[field] + b: i32, + } + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + #[field] + fn values(&self, start: i32, end: i32) -> impl Stream { + futures::stream::iter(start..end) + } + + #[field] + fn events(&self, start: i32, end: i32) -> impl Stream { + futures::stream::iter((start..end).map(|n| Event { a: n, b: n * 10 })) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + + { + let mut stream = schema + .create_subscription_stream( + "subscription { values(start: 10, end: 20) }", + None, + Default::default(), + ) + .unwrap(); + for i in 10..20 { + assert_eq!( + Some(serde_json::json!({ "values": i })), + stream.next().await + ); + } + assert!(stream.next().await.is_none()); + } + + { + let mut stream = schema + .create_subscription_stream( + "subscription { events(start: 10, end: 20) { a b } }", + None, + Default::default(), + ) + .unwrap(); + for i in 10..20 { + assert_eq!( + Some(serde_json::json!({ "events": {"a": i, "b": i * 10} })), + stream.next().await + ); + } + assert!(stream.next().await.is_none()); + } +}