diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 16455294..cdb22b5e 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -1,19 +1,26 @@ +use std::str::FromStr; +use std::time::{Duration, Instant}; + use actix::{ Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, WrapStream, }; -use actix_http::ws; +use actix_http::error::PayloadError; +use actix_http::{ws, Error}; +use actix_web::web::Bytes; +use actix_web::{HttpRequest, HttpResponse}; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::http::WebSocket; +use async_graphql::http::{WebSocket, WebSocketProtocols}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; -use std::time::{Duration, Instant}; +use futures_util::stream::Stream; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); /// Actor for subscription via websocket pub struct WSSubscription { - schema: Option>, + schema: Schema, + protocol: WebSocketProtocols, last_heartbeat: Instant, messages: Option>>, initializer: Option Result + Send + Sync>>, @@ -26,15 +33,41 @@ where Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - /// Create an actor for subscription connection via websocket. - pub fn new(schema: Schema) -> Self { - Self { - schema: Some(schema), - last_heartbeat: Instant::now(), - messages: None, - initializer: None, - continuation: Vec::new(), - } + /// Start an actor for subscription connection via websocket. + pub fn start( + schema: Schema, + request: &HttpRequest, + stream: T, + ) -> Result + where + T: Stream> + 'static, + { + let protocol = match request + .headers() + .get("sec-websocket-protocol") + .and_then(|value| value.to_str().ok()) + .and_then(|value| WebSocketProtocols::from_str(value).ok()) + { + Some(protocol) => protocol, + None => { + // default to the prior standard + WebSocketProtocols::SubscriptionsTransportWS + } + }; + + actix_web_actors::ws::start_with_protocols( + Self { + schema, + protocol, + last_heartbeat: Instant::now(), + messages: None, + initializer: None, + continuation: Vec::new(), + }, + &["graphql-transport-ws", "graphql-ws"], + request, + stream, + ) } /// Set a context data initialization function. @@ -71,13 +104,18 @@ where let (tx, rx) = async_channel::unbounded(); - WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take()) - .into_actor(self) - .map(|response, _act, ctx| { - ctx.text(response); - }) - .finish() - .spawn(ctx); + WebSocket::with_data( + self.schema.clone(), + rx, + self.initializer.take(), + self.protocol, + ) + .into_actor(self) + .map(|response, _act, ctx| { + ctx.text(response); + }) + .finish() + .spawn(ctx); self.messages = Some(tx); } diff --git a/tests/subscription_websocket.rs b/tests/subscription_websocket.rs index 23cf7edf..3209b1ac 100644 --- a/tests/subscription_websocket.rs +++ b/tests/subscription_websocket.rs @@ -1,3 +1,4 @@ +use async_graphql::http::WebSocketProtocols; use async_graphql::*; use futures_util::stream::{Stream, StreamExt}; @@ -23,7 +24,7 @@ pub async fn test_subscription_ws_transport() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({ @@ -115,6 +116,7 @@ pub async fn test_subscription_ws_transport_with_token() { data.insert(Token(payload.token)); Ok(data) }), + WebSocketProtocols::SubscriptionsTransportWS, ); tx.send( @@ -204,7 +206,7 @@ pub async fn test_subscription_ws_transport_error() { let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({ @@ -276,7 +278,7 @@ pub async fn test_query_over_websocket() { let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); let (tx, rx) = async_channel::unbounded(); - let mut stream = http::WebSocket::new(schema, rx); + let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS); tx.send( serde_json::to_string(&value!({