From 5f25b296b0a91c3e5894c555513fc555008b7c69 Mon Sep 17 00:00:00 2001 From: Patrick Fernie Date: Wed, 2 Dec 2020 08:13:00 -0500 Subject: [PATCH 1/7] Add `graphql-ws` feature for subprotcol selection On 2020-09-14 a revised version of the [GraphQL over WebSocket](https://the-guild.dev/blog/graphql-over-websockets) spec was released as [graphql-ws](https://github.com/enisdenjo/graphql-ws), which differs from the de facto standard implemented by [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws). This adds a new `cargo` (non-default) feature, `graphql_ws`, which when enabled utilizes the new protocol spec. Original [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md) Revised [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md) --- Cargo.toml | 2 ++ integrations/warp/Cargo.toml | 4 ++++ integrations/warp/src/subscription.rs | 13 ++++++++++++- src/http/websocket.rs | 3 +++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 63a15b51..75dd9eee 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,8 @@ unblock = ["blocking"] string_number = ["num-traits"] # Used for doc(cfg()) nightly = [] +# To enable the use of [graphql-ws spec](https://github.com/enisdenjo/graphql-ws) +graphql_ws = [] [dependencies] async-graphql-derive = { path = "derive", version = "=2.1.4" } diff --git a/integrations/warp/Cargo.toml b/integrations/warp/Cargo.toml index 44461368..4e3c49f1 100644 --- a/integrations/warp/Cargo.toml +++ b/integrations/warp/Cargo.toml @@ -11,6 +11,10 @@ repository = "https://github.com/async-graphql/async-graphql" keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] +[features] +# To enable the use of [graphql-ws spec](https://github.com/enisdenjo/graphql-ws) +graphql_ws = [] + [dependencies] async-graphql = { path = "../..", version = "=2.1.4" } diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index 60199e5e..4ba17483 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -86,6 +86,17 @@ where .await; } }); - warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws") + + #[cfg(not(feature = "graphql_ws"))] + { + // confusingly, the old subprotocol name for the susbscription-transport-ws spec was + // `graphql-ws` + warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws") + } + #[cfg(feature = "graphql_ws")] + { + // ...and the new one `graphql-transport-ws` + warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-transport-ws") + } }) } diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 91096a40..dd5402ed 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -169,7 +169,9 @@ where #[serde(tag = "type", rename_all = "snake_case")] enum ClientMessage<'a> { ConnectionInit { payload: Option }, + #[cfg_attr(feature = "graphql_ws", serde(rename = "subscribe"))] Start { id: String, payload: Request }, + #[cfg_attr(feature = "graphql_ws", serde(rename = "complete"))] Stop { id: &'a str }, ConnectionTerminate, } @@ -179,6 +181,7 @@ enum ClientMessage<'a> { enum ServerMessage<'a> { ConnectionError { payload: Error }, ConnectionAck, + #[cfg_attr(feature = "graphql_ws", serde(rename = "next"))] Data { id: &'a str, payload: Box }, // Not used by this library, as it's not necessary to send // Error { From 9d12337972247130b60db11a42915d340adf49d2 Mon Sep 17 00:00:00 2001 From: Patrick Fernie Date: Wed, 2 Dec 2020 08:41:12 -0500 Subject: [PATCH 2/7] rustfmt fixes --- src/http/websocket.rs | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/src/http/websocket.rs b/src/http/websocket.rs index dd5402ed..15389dc3 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -168,27 +168,41 @@ where #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ClientMessage<'a> { - ConnectionInit { payload: Option }, + ConnectionInit { + payload: Option, + }, #[cfg_attr(feature = "graphql_ws", serde(rename = "subscribe"))] - Start { id: String, payload: Request }, + Start { + id: String, + payload: Request, + }, #[cfg_attr(feature = "graphql_ws", serde(rename = "complete"))] - Stop { id: &'a str }, + Stop { + id: &'a str, + }, ConnectionTerminate, } #[derive(Serialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ServerMessage<'a> { - ConnectionError { payload: Error }, + ConnectionError { + payload: Error, + }, ConnectionAck, #[cfg_attr(feature = "graphql_ws", serde(rename = "next"))] - Data { id: &'a str, payload: Box }, + Data { + id: &'a str, + payload: Box, + }, // Not used by this library, as it's not necessary to send // Error { // id: &'a str, // payload: serde_json::Value, // }, - Complete { id: &'a str }, + Complete { + id: &'a str, + }, // Not used by this library // #[serde(rename = "ka")] // KeepAlive From f7b9fa03f634b3b707382eba52a54b76cdf217e5 Mon Sep 17 00:00:00 2001 From: Patrick Fernie Date: Thu, 3 Dec 2020 16:16:07 -0500 Subject: [PATCH 3/7] use `Sec-WebSocket-Protocol` header to negotiate protocol --- Cargo.toml | 2 - integrations/warp/Cargo.toml | 4 -- integrations/warp/src/subscription.rs | 73 +++++++++++++----------- src/http/mod.rs | 2 +- src/http/websocket.rs | 80 ++++++++++++++++++++++++--- 5 files changed, 113 insertions(+), 48 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 75dd9eee..63a15b51 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,8 +34,6 @@ unblock = ["blocking"] string_number = ["num-traits"] # Used for doc(cfg()) nightly = [] -# To enable the use of [graphql-ws spec](https://github.com/enisdenjo/graphql-ws) -graphql_ws = [] [dependencies] async-graphql-derive = { path = "derive", version = "=2.1.4" } diff --git a/integrations/warp/Cargo.toml b/integrations/warp/Cargo.toml index 4e3c49f1..44461368 100644 --- a/integrations/warp/Cargo.toml +++ b/integrations/warp/Cargo.toml @@ -11,10 +11,6 @@ repository = "https://github.com/async-graphql/async-graphql" keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] -[features] -# To enable the use of [graphql-ws spec](https://github.com/enisdenjo/graphql-ws) -graphql_ws = [] - [dependencies] async-graphql = { path = "../..", version = "=2.1.4" } diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index 4ba17483..171d7aea 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -64,39 +64,48 @@ where Subscription: SubscriptionType + Send + Sync + 'static, F: FnOnce(serde_json::Value) -> Result + Send + Sync + Clone + 'static, { - warp::ws().map(move |ws: ws::Ws| { - let schema = schema.clone(); - let initializer = initializer.clone(); + use async_graphql::http::WebSocketProtocols; - let reply = ws.on_upgrade(move |websocket| { - let (ws_sender, ws_receiver) = websocket.split(); + warp::ws() + .and(warp::header::optional::( + "sec-websocket-protocol", + )) + .map(move |ws: ws::Ws, protocol| { + let schema = schema.clone(); + let initializer = initializer.clone(); - async move { - let _ = async_graphql::http::WebSocket::with_data( - schema, - ws_receiver - .take_while(|msg| future::ready(msg.is_ok())) - .map(Result::unwrap) - .map(ws::Message::into_bytes), - initializer, - ) - .map(ws::Message::text) - .map(Ok) - .forward(ws_sender) - .await; - } - }); + let protocol = match protocol { + Some(protocol) => protocol, + None => { + // default to the prior standard + WebSocketProtocols::SubscriptionsTransportWS + } + }; - #[cfg(not(feature = "graphql_ws"))] - { - // confusingly, the old subprotocol name for the susbscription-transport-ws spec was - // `graphql-ws` - warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws") - } - #[cfg(feature = "graphql_ws")] - { - // ...and the new one `graphql-transport-ws` - warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-transport-ws") - } - }) + let reply = ws.on_upgrade(move |websocket| { + let (ws_sender, ws_receiver) = websocket.split(); + + async move { + let _ = async_graphql::http::WebSocket::with_data( + schema, + ws_receiver + .take_while(|msg| future::ready(msg.is_ok())) + .map(Result::unwrap) + .map(ws::Message::into_bytes), + initializer, + protocol, + ) + .map(ws::Message::text) + .map(Ok) + .forward(ws_sender) + .await; + } + }); + + warp::reply::with_header( + reply, + "Sec-WebSocket-Protocol", + protocol.sec_websocket_protocol(), + ) + }) } diff --git a/src/http/mod.rs b/src/http/mod.rs index 069cd610..eacb4794 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -14,7 +14,7 @@ pub use graphiql_source::graphiql_source; #[cfg(feature = "multipart")] pub use multipart::MultipartOptions; pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; -pub use websocket::WebSocket; +pub use websocket::{Protocols as WebSocketProtocols, WebSocket}; #[cfg(feature = "multipart")] #[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))] diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 15389dc3..853ee16c 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -22,6 +22,7 @@ pin_project! { streams: HashMap + Send>>>, #[pin] stream: S, + protocol: Protocols, } } @@ -30,13 +31,18 @@ impl { /// Create a new websocket. #[must_use] - pub fn new(schema: Schema, stream: S) -> Self { + pub fn new( + schema: Schema, + stream: S, + protocol: Protocols, + ) -> Self { Self { data_initializer: None, data: Arc::default(), schema, streams: HashMap::new(), stream, + protocol, } } } @@ -52,6 +58,7 @@ impl WebSocket, stream: S, data_initializer: Option, + protocol: Protocols, ) -> Self { Self { data_initializer, @@ -59,6 +66,7 @@ impl WebSocket server` message; rather, disconnection is + // handled by disconnecting the websocket ClientMessage::ConnectionTerminate => return Poll::Ready(None), } } @@ -143,11 +154,7 @@ where match Pin::new(stream).poll_next(cx) { Poll::Ready(Some(payload)) => { return Poll::Ready(Some( - serde_json::to_string(&ServerMessage::Data { - id, - payload: Box::new(payload), - }) - .unwrap(), + serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(), )); } Poll::Ready(None) => { @@ -165,18 +172,68 @@ where } } +/// Specification of which GraphQL Over WebSockets protocol is being utilized +#[derive(Copy, Clone)] +pub enum Protocols { + /// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). + SubscriptionsTransportWS, + /// [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). + GraphQLWS, +} + +impl Protocols { + /// Returns the `Sec-WebSocket-Protocol` header value for the protocol + pub fn sec_websocket_protocol(&self) -> &str { + match self { + Protocols::SubscriptionsTransportWS => "graphql-ws", + Protocols::GraphQLWS => "graphql-transport-ws", + } + } + + #[inline] + fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> { + match self { + Protocols::SubscriptionsTransportWS => ServerMessage::Data { + id, + payload: Box::new(payload), + }, + Protocols::GraphQLWS => ServerMessage::Next { + id, + payload: Box::new(payload), + }, + } + } +} + +impl std::str::FromStr for Protocols { + type Err = Error; + + fn from_str(protocol: &str) -> Result { + if protocol.eq_ignore_ascii_case("graphql-ws") { + Ok(Protocols::SubscriptionsTransportWS) + } else if protocol.eq_ignore_ascii_case("graphql-transport-ws") { + Ok(Protocols::GraphQLWS) + } else { + Err(Error::new(format!( + "Unsupported Sec-WebSocket-Protocol: {}", + protocol + ))) + } + } +} + #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ClientMessage<'a> { ConnectionInit { payload: Option, }, - #[cfg_attr(feature = "graphql_ws", serde(rename = "subscribe"))] + #[serde(alias = "subscribe")] Start { id: String, payload: Request, }, - #[cfg_attr(feature = "graphql_ws", serde(rename = "complete"))] + #[serde(alias = "complete")] Stop { id: &'a str, }, @@ -190,11 +247,16 @@ enum ServerMessage<'a> { payload: Error, }, ConnectionAck, - #[cfg_attr(feature = "graphql_ws", serde(rename = "next"))] + /// subscriptions-transport-ws protocol next payload Data { id: &'a str, payload: Box, }, + /// graphql-ws protocol next payload + Next { + id: &'a str, + payload: Box, + }, // Not used by this library, as it's not necessary to send // Error { // id: &'a str, From c9790087c5b8d71db04d2c26a5ae1ab9a5a5aa3c Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 4 Dec 2020 12:13:52 +0800 Subject: [PATCH 4/7] Add websocket subprotocol for actix-web. --- integrations/actix-web/src/subscription.rs | 78 ++++++++++++++++------ tests/subscription_websocket.rs | 8 ++- 2 files changed, 63 insertions(+), 23 deletions(-) diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 16455294..cdb22b5e 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -1,19 +1,26 @@ +use std::str::FromStr; +use std::time::{Duration, Instant}; + use actix::{ Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, WrapStream, }; -use actix_http::ws; +use actix_http::error::PayloadError; +use actix_http::{ws, Error}; +use actix_web::web::Bytes; +use actix_web::{HttpRequest, HttpResponse}; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::http::WebSocket; +use async_graphql::http::{WebSocket, WebSocketProtocols}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; -use std::time::{Duration, Instant}; +use futures_util::stream::Stream; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); /// Actor for subscription via websocket pub struct WSSubscription { - schema: Option>, + schema: Schema, + protocol: WebSocketProtocols, last_heartbeat: Instant, messages: Option>>, initializer: Option Result + Send + Sync>>, @@ -26,15 +33,41 @@ where Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - /// Create an actor for subscription connection via websocket. - pub fn new(schema: Schema) -> Self { - Self { - schema: Some(schema), - last_heartbeat: Instant::now(), - messages: None, - initializer: None, - continuation: Vec::new(), - } + /// Start an actor for subscription connection via websocket. + pub fn start( + schema: Schema, + request: &HttpRequest, + stream: T, + ) -> Result + where + T: Stream> + 'static, + { + let protocol = match request + .headers() + .get("sec-websocket-protocol") + .and_then(|value| value.to_str().ok()) + .and_then(|value| WebSocketProtocols::from_str(value).ok()) + { + Some(protocol) => protocol, + None => { + // default to the prior standard + WebSocketProtocols::SubscriptionsTransportWS + } + }; + + actix_web_actors::ws::start_with_protocols( + Self { + schema, + protocol, + last_heartbeat: Instant::now(), + messages: None, + initializer: None, + continuation: Vec::new(), + }, + &["graphql-transport-ws", "graphql-ws"], + request, + stream, + ) } /// Set a context data initialization function. @@ -71,13 +104,18 @@ where let (tx, rx) = async_channel::unbounded(); - WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take()) - .into_actor(self) - .map(|response, _act, ctx| { - ctx.text(response); - }) - .finish() - .spawn(ctx); + WebSocket::with_data( + self.schema.clone(), + rx, + self.initializer.take(), + self.protocol, + ) + .into_actor(self) + .map(|response, _act, ctx| { + ctx.text(response); + }) + .finish() + .spawn(ctx); self.messages = Some(tx); } diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket.rs index 23cf7edf..3209b1ac 100644 --- a/tests/subscription_websocket.rs +++ b/tests/subscription_websocket.rs @@ -1,3 +1,4 @@ +use async_graphql::http::WebSocketProtocols; use async_graphql::*; use futures_util::stream::{Stream, StreamExt}; @@ -23,7 +24,7 @@ pub async fn test_subscription_ws_transport() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({ @@ -115,6 +116,7 @@ pub async fn test_subscription_ws_transport_with_token() { data.insert(Token(payload.token)); Ok(data) }), + WebSocketProtocols::SubscriptionsTransportWS, ); tx.send( @@ -204,7 +206,7 @@ pub async fn test_subscription_ws_transport_error() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({ @@ -276,7 +278,7 @@ pub async fn test_query_over_websocket() { let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({ From 6884b6e105901068e1e5eac7bb50d0a45d462fed Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 4 Dec 2020 12:35:35 +0800 Subject: [PATCH 5/7] Add `WSSubscription::start_with_initializer` and update examples. --- examples | 2 +- integrations/actix-web/src/subscription.rs | 27 ++++++++++++---------- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/examples b/examples index 71d7318c..b51536fb 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit 71d7318ca5bd5f9b77ffc7136990c80db35a2b46 +Subproject commit b51536fb0a933d25d9fcc91a08c2a53daf74daf8 diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index cdb22b5e..0d3fe3fc 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -41,6 +41,20 @@ where ) -> Result where T: Stream> + 'static, + { + Self::start_with_initializer(schema, request, stream, |_| Ok(Default::default())) + } + + /// Start an actor for subscription connection via websocket with an initialization function. + pub fn start_with_initializer( + schema: Schema, + request: &HttpRequest, + stream: T, + initializer: F, + ) -> Result + where + T: Stream> + 'static, + F: FnOnce(serde_json::Value) -> Result + Send + Sync + 'static, { let protocol = match request .headers() @@ -61,7 +75,7 @@ where protocol, last_heartbeat: Instant::now(), messages: None, - initializer: None, + initializer: Some(Box::new(initializer)), continuation: Vec::new(), }, &["graphql-transport-ws", "graphql-ws"], @@ -70,17 +84,6 @@ where ) } - /// Set a context data initialization function. - pub fn initializer(self, f: F) -> Self - where - F: FnOnce(serde_json::Value) -> Result + Send + Sync + 'static, - { - Self { - initializer: Some(Box::new(f)), - ..self - } - } - fn send_heartbeats(&self, ctx: &mut WebsocketContext) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT { From 7796b8c24887af9ee89d87e9b8cc7a978eb94aee Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 4 Dec 2020 16:44:48 +0800 Subject: [PATCH 6/7] Add tests for `graphql-ws`. --- tests/subscription_websocket_graphql_ws.rs | 328 ++++++++++++++++++ ...n_websocket_subscriptions_transport_ws.rs} | 0 2 files changed, 328 insertions(+) create mode 100644 tests/subscription_websocket_graphql_ws.rs rename tests/{subscription_websocket.rs => subscription_websocket_subscriptions_transport_ws.rs} (100%) diff --git a/tests/subscription_websocket_graphql_ws.rs b/tests/subscription_websocket_graphql_ws.rs new file mode 100644 index 00000000..d4b28724 --- /dev/null +++ b/tests/subscription_websocket_graphql_ws.rs @@ -0,0 +1,328 @@ +use async_graphql::http::WebSocketProtocols; +use async_graphql::*; +use futures_util::stream::{Stream, StreamExt}; + +#[async_std::test] +pub async fn test_subscription_ws_transport() { + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 10 + } + } + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn values(&self) -> impl Stream { + futures_util::stream::iter(0..10) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (tx, rx) = async_channel::unbounded(); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS); + + tx.send( + serde_json::to_string(&value!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "connection_ack", + }), + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "next", + "id": "1", + "payload": { "data": { "values": i } }, + }), + ); + } + + assert_eq!( + serde_json::from_str::(&stream.next().await.unwrap()).unwrap(), + serde_json::json!({ + "type": "complete", + "id": "1", + }), + ); +} + +#[async_std::test] +pub async fn test_subscription_ws_transport_with_token() { + struct Token(String); + + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 10 + } + } + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn values(&self, ctx: &Context<'_>) -> Result> { + if ctx.data_unchecked::().0 != "123456" { + return Err("forbidden".into()); + } + Ok(futures_util::stream::iter(0..10)) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (tx, rx) = async_channel::unbounded(); + let mut stream = http::WebSocket::with_data( + schema, + rx, + Some(|value| { + #[derive(serde::Deserialize)] + struct Payload { + token: String, + } + + let payload: Payload = serde_json::from_value(value).unwrap(); + let mut data = Data::default(); + data.insert(Token(payload.token)); + Ok(data) + }), + WebSocketProtocols::GraphQLWS, + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "connection_init", + "payload": { "token": "123456" } + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(value!({ + "type": "connection_ack", + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { values }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0..10 { + assert_eq!( + Some(value!({ + "type": "next", + "id": "1", + "payload": { "data": { "values": i } }, + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + } + + assert_eq!( + Some(value!({ + "type": "complete", + "id": "1", + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); +} + +#[async_std::test] +pub async fn test_subscription_ws_transport_error() { + struct Event { + value: i32, + } + + #[Object] + impl Event { + async fn value(&self) -> Result { + if self.value < 5 { + Ok(self.value) + } else { + Err("TestError".into()) + } + } + } + + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 10 + } + } + + struct SubscriptionRoot; + + #[Subscription] + impl SubscriptionRoot { + async fn events(&self) -> impl Stream { + futures_util::stream::iter((0..10).map(|n| Event { value: n })) + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); + let (tx, rx) = async_channel::unbounded(); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS); + + tx.send( + serde_json::to_string(&value!({ + "type": "connection_init" + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(value!({ + "type": "connection_ack", + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "subscription { events { value } }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + for i in 0i32..5 { + assert_eq!( + Some(value!({ + "type": "next", + "id": "1", + "payload": { "data": { "events": { "value": i } } }, + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + } + + assert_eq!( + Some(value!({ + "type": "next", + "id": "1", + "payload": { + "data": null, + "errors": [{ + "message": "TestError", + "locations": [{"line": 1, "column": 25}], + "path": ["events", "value"], + }], + }, + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); +} + +#[async_std::test] +pub async fn test_query_over_websocket() { + struct QueryRoot; + + #[Object] + impl QueryRoot { + async fn value(&self) -> i32 { + 999 + } + } + + let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); + let (tx, rx) = async_channel::unbounded(); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS); + + tx.send( + serde_json::to_string(&value!({ + "type": "connection_init", + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(value!({ + "type": "connection_ack", + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + + tx.send( + serde_json::to_string(&value!({ + "type": "start", + "id": "1", + "payload": { + "query": "query { value }" + }, + })) + .unwrap(), + ) + .await + .unwrap(); + + assert_eq!( + Some(value!({ + "type": "next", + "id": "1", + "payload": { "data": { "value": 999 } }, + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); + + assert_eq!( + Some(value!({ + "type": "complete", + "id": "1", + })), + serde_json::from_str(&stream.next().await.unwrap()).unwrap() + ); +} diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket_subscriptions_transport_ws.rs similarity index 100% rename from tests/subscription_websocket.rs rename to tests/subscription_websocket_subscriptions_transport_ws.rs From a3a3ba38f0124598ab151f2ae38cbe04ba4ee178 Mon Sep 17 00:00:00 2001 From: Patrick Fernie Date: Fri, 4 Dec 2020 12:16:14 -0500 Subject: [PATCH 7/7] support client specifying multiple protocols in Sec-WebSocket-Protocol negotiation --- integrations/actix-web/src/subscription.rs | 7 +++++-- integrations/warp/src/subscription.rs | 21 ++++++++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 0d3fe3fc..be123bd1 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -60,8 +60,11 @@ where .headers() .get("sec-websocket-protocol") .and_then(|value| value.to_str().ok()) - .and_then(|value| WebSocketProtocols::from_str(value).ok()) - { + .and_then(|protocols| { + protocols + .split(',') + .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) + }) { Some(protocol) => protocol, None => { // default to the prior standard diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index 171d7aea..6f56d1a3 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -65,22 +65,21 @@ where F: FnOnce(serde_json::Value) -> Result + Send + Sync + Clone + 'static, { use async_graphql::http::WebSocketProtocols; + use std::str::FromStr; warp::ws() - .and(warp::header::optional::( - "sec-websocket-protocol", - )) - .map(move |ws: ws::Ws, protocol| { + .and(warp::header::optional::("sec-websocket-protocol")) + .map(move |ws: ws::Ws, protocols: Option| { let schema = schema.clone(); let initializer = initializer.clone(); - let protocol = match protocol { - Some(protocol) => protocol, - None => { - // default to the prior standard - WebSocketProtocols::SubscriptionsTransportWS - } - }; + let protocol = protocols + .and_then(|protocols| { + protocols + .split(',') + .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) + }) + .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS); let reply = ws.on_upgrade(move |websocket| { let (ws_sender, ws_receiver) = websocket.split();