From e991ffb8143da2aced8bda1c2ef52d23f75018b2 Mon Sep 17 00:00:00 2001 From: Sunli Date: Wed, 24 Mar 2021 17:24:12 +0800 Subject: [PATCH] Fix internal context missing when there's no connection_init frame. #451 --- src/http/websocket.rs | 24 +++++++------ tests/subscription_websocket_graphql_ws.rs | 34 +++++++++++++++++++ ...on_websocket_subscriptions_transport_ws.rs | 34 +++++++++++++++++++ 3 files changed, 82 insertions(+), 10 deletions(-) diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 9c4ff34e..feb2fafe 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -61,7 +61,7 @@ pin_project! { pub struct WebSocket { data_initializer: Option, init_fut: Option>>, - data: Arc, + data: Option>, schema: Schema, streams: HashMap + Send>>>, #[pin] @@ -105,7 +105,7 @@ impl WebSocket { - this.streams.insert( - id, - Box::pin( - this.schema - .execute_stream_with_ctx_data(request, Arc::clone(this.data)), - ), - ); + if let Some(data) = this.data.clone() { + this.streams.insert( + id, + Box::pin(this.schema.execute_stream_with_ctx_data(request, data)), + ); + } else { + return Poll::Ready(Some(WsMessage::Close( + 1011, + "The handshake is not completed.".to_string(), + ))); + } } ClientMessage::Stop { id } => { if this.streams.remove(id).is_some() { @@ -202,7 +206,7 @@ where *this.init_fut = None; return match res { Ok(data) => { - *this.data = Arc::new(data); + *this.data = Some(Arc::new(data)); Poll::Ready(Some(WsMessage::Text( serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), ))) diff --git a/tests/subscription_websocket_graphql_ws.rs b/tests/subscription_websocket_graphql_ws.rs index 13f21d7a..d8434447 100644 --- a/tests/subscription_websocket_graphql_ws.rs +++ b/tests/subscription_websocket_graphql_ws.rs @@ -430,3 +430,37 @@ pub async fn test_query_over_websocket() { serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() ); } + +#[tokio::test] +pub async fn test_start_before_connection_init() { + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 999 + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "query { value }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + stream.next().await.unwrap().unwrap_close(), + (1011, "The handshake is not completed.".to_string()) + ); +} diff --git a/tests/subscription_websocket_subscriptions_transport_ws.rs b/tests/subscription_websocket_subscriptions_transport_ws.rs index 57d74be0..25a3258f 100644 --- a/tests/subscription_websocket_subscriptions_transport_ws.rs +++ b/tests/subscription_websocket_subscriptions_transport_ws.rs @@ -391,3 +391,37 @@ pub async fn test_query_over_websocket() { serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() ); } + +#[tokio::test] +pub async fn test_start_before_connection_init() { + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 999 + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let (mut tx, rx) = mpsc::unbounded(); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "query { value }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + stream.next().await.unwrap().unwrap_close(), + (1011, "The handshake is not completed.".to_string()) + ); +}