diff --git a/src/schema.rs b/src/schema.rs index 7a1b9b0c..daf07a16 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -453,8 +453,11 @@ where resp } - /// Execute an GraphQL subscription. - pub fn execute_stream(&self, request: impl Into) -> impl Stream { + pub(crate) fn execute_stream_with_ctx_data( + &self, + request: impl Into, + ctx_data: Arc, + ) -> impl Stream { let schema = self.clone(); async_stream::stream! { let request = request.into(); @@ -480,7 +483,7 @@ where extensions, request.variables, document, - Arc::new(request.ctx_data), + ctx_data, ); let ctx = env.create_context( @@ -507,4 +510,11 @@ where } } } + + /// Execute an GraphQL subscription. + pub fn execute_stream(&self, request: impl Into) -> impl Stream { + let mut request = request.into(); + let ctx_data = std::mem::replace(&mut request.ctx_data, Default::default()); + self.execute_stream_with_ctx_data(request, Arc::new(ctx_data)) + } } diff --git a/src/subscription/transports/websocket/mod.rs b/src/subscription/transports/websocket/mod.rs index ecb2f0eb..24298956 100644 --- a/src/subscription/transports/websocket/mod.rs +++ b/src/subscription/transports/websocket/mod.rs @@ -1,6 +1,6 @@ //! WebSocket transport for subscription -use crate::{http, Data, FieldResult, ObjectType, Response, Schema, SubscriptionType}; +use crate::{http, Data, FieldResult, ObjectType, Request, Response, Schema, SubscriptionType}; use futures::channel::mpsc; use futures::task::{Context, Poll}; use futures::{Future, Stream, StreamExt}; @@ -181,8 +181,11 @@ where payload: None, }, ); + break; + } + Poll::Pending => { + break; } - Poll::Pending => break, } } } @@ -191,6 +194,10 @@ where ctx.streams.remove(&id); } } + + if !ctx.send_buf.is_empty() { + continue; + } } return Poll::Pending; @@ -227,7 +234,10 @@ where "start" => { if let (Some(id), Some(payload)) = (msg.id, msg.payload) { if let Ok(request) = serde_json::from_value::(payload) { - let stream = schema.execute_stream(request).boxed(); + let request = Request::from(request); + let stream = schema + .execute_stream_with_ctx_data(request, ctx.ctx_data.clone()) + .boxed(); ctx.streams.insert(id, stream); } } diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket.rs index 511472c2..95db221b 100644 --- a/tests/subscription_websocket.rs +++ b/tests/subscription_websocket.rs @@ -1,291 +1,307 @@ -// use async_graphql::*; -// use futures::{SinkExt, Stream, StreamExt}; -// -// #[async_std::test] -// pub async fn test_subscription_ws_transport() { -// struct QueryRoot; -// -// #[Object] -// impl QueryRoot {} -// -// struct SubscriptionRoot; -// -// #[Subscription] -// impl SubscriptionRoot { -// async fn values(&self) -> impl Stream { -// futures::stream::iter(0..10) -// } -// } -// -// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); -// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default()); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "connection_init", -// "payload": { "token": "123456" } -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "connection_ack", -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "start", -// "id": "1", -// "payload": { -// "query": "subscription { values }" -// }, -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// for i in 0..10 { -// assert_eq!( -// Some(serde_json::json!({ -// "type": "data", -// "id": "1", -// "payload": { "data": { "values": i } }, -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// } -// } -// -// #[async_std::test] -// pub async fn test_subscription_ws_transport_with_token() { -// struct Token(String); -// -// struct QueryRoot; -// -// #[Object] -// impl QueryRoot {} -// -// struct SubscriptionRoot; -// -// #[Subscription] -// impl SubscriptionRoot { -// async fn values(&self, ctx: &Context<'_>) -> FieldResult> { -// if ctx.data_unchecked::().0 != "123456" { -// return Err("forbidden".into()); -// } -// Ok(futures::stream::iter(0..10)) -// } -// } -// -// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); -// -// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| { -// #[derive(serde::Deserialize)] -// struct Payload { -// token: String, -// } -// -// let payload: Payload = serde_json::from_value(value).unwrap(); -// let mut data = Data::default(); -// data.insert(Token(payload.token)); -// Ok(data) -// })); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "connection_init", -// "payload": { "token": "123456" } -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "connection_ack", -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "start", -// "id": "1", -// "payload": { -// "query": "subscription { values }" -// }, -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// for i in 0..10 { -// assert_eq!( -// Some(serde_json::json!({ -// "type": "data", -// "id": "1", -// "payload": { "data": { "values": i } }, -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// } -// } -// -// #[async_std::test] -// pub async fn test_subscription_ws_transport_error() { -// struct QueryRoot; -// -// struct Event { -// value: i32, -// } -// -// #[Object] -// impl Event { -// async fn value(&self) -> FieldResult { -// if self.value < 5 { -// Ok(self.value) -// } else { -// Err("TestError".into()) -// } -// } -// } -// -// #[Object] -// impl QueryRoot {} -// -// struct SubscriptionRoot; -// -// #[Subscription] -// impl SubscriptionRoot { -// async fn events(&self) -> impl Stream { -// futures::stream::iter((0..10).map(|n| Event { value: n })) -// } -// } -// -// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); -// -// let (mut sink, mut stream) = -// schema.subscription_connection(WebSocketTransport::new(|_| Ok(Data::default()))); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "connection_init" -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "connection_ack", -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "start", -// "id": "1", -// "payload": { -// "query": "subscription { events { value } }" -// }, -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// for i in 0i32..5 { -// assert_eq!( -// Some(serde_json::json!({ -// "type": "data", -// "id": "1", -// "payload": { "data": { "events": { "value": i } } }, -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// } -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "error", -// "id": "1", -// "payload": [{ -// "message": "TestError", -// "locations": [{"line": 1, "column": 25}], -// "path": ["events", "value"], -// }], -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// } -// -// #[async_std::test] -// pub async fn test_query_over_websocket() { -// struct QueryRoot; -// -// #[Object] -// impl QueryRoot { -// async fn value(&self) -> i32 { -// 999 -// } -// } -// -// let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); -// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default()); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "connection_init", -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "connection_ack", -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// -// sink.send( -// serde_json::to_vec(&serde_json::json!({ -// "type": "start", -// "id": "1", -// "payload": { -// "query": "query { value }" -// }, -// })) -// .unwrap(), -// ) -// .await -// .unwrap(); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "data", -// "id": "1", -// "payload": { "data": { "value": 999 } }, -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// -// assert_eq!( -// Some(serde_json::json!({ -// "type": "complete", -// "id": "1", -// })), -// serde_json::from_slice(&stream.next().await.unwrap()).unwrap() -// ); -// } +use async_graphql::*; +use futures::{SinkExt, Stream, StreamExt}; + +#[async_std::test] +pub async fn test_subscription_ws_transport() { + struct QueryRoot; + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn values(&self) -> impl Stream { + futures::stream::iter(0..10) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (mut sink, stream) = transports::websocket::create(&schema); + futures::pin_mut!(stream); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "values": i } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + } + + assert_eq!( + Some(serde_json::json!({ + "type": "complete", + "id": "1", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); +} + +#[async_std::test] +pub async fn test_subscription_ws_transport_with_token() { + struct Token(String); + + struct QueryRoot; + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn values(&self, ctx: &Context<'_>) -> FieldResult> { + if ctx.data_unchecked::().0 != "123456" { + return Err("forbidden".into()); + } + Ok(futures::stream::iter(0..10)) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (mut sink, stream) = transports::websocket::create_with_initializer(&schema, |value| { + #[derive(serde::Deserialize)] + struct Payload { + token: String, + } + + let payload: Payload = serde_json::from_value(value).unwrap(); + let mut data = Data::default(); + data.insert(Token(payload.token)); + Ok(data) + }); + futures::pin_mut!(stream); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init", + "payload": { "token": "123456" } + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "values": i } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + } + + assert_eq!( + Some(serde_json::json!({ + "type": "complete", + "id": "1", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); +} + +#[async_std::test] +pub async fn test_subscription_ws_transport_error() { + struct QueryRoot; + + struct Event { + value: i32, + } + + #[Object] + impl Event { + async fn value(&self) -> FieldResult { + if self.value < 5 { + Ok(self.value) + } else { + Err("TestError".into()) + } + } + } + + #[Object] + impl QueryRoot {} + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn events(&self) -> impl Stream { + futures::stream::iter((0..10).map(|n| Event { value: n })) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (mut sink, stream) = transports::websocket::create(&schema); + futures::pin_mut!(stream); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init" + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { events { value } }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0i32..5 { + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "events": { "value": i } } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + } + + assert_eq!( + Some(serde_json::json!({ + "type": "error", + "id": "1", + "payload": [{ + "message": "TestError", + "locations": [{"line": 1, "column": 25}], + "path": ["events", "value"], + }], + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); +} + +#[async_std::test] +pub async fn test_query_over_websocket() { + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 999 + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let (mut sink, stream) = transports::websocket::create(&schema); + futures::pin_mut!(stream); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "connection_ack", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + sink.send( + serde_json::to_vec(&serde_json::json!({ + "type": "start", + "id": "1", + "payload": { + "query": "query { value }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(serde_json::json!({ + "type": "data", + "id": "1", + "payload": { "data": { "value": 999 } }, + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); + + assert_eq!( + Some(serde_json::json!({ + "type": "complete", + "id": "1", + })), + serde_json::from_slice(&stream.next().await.unwrap()).unwrap() + ); +}