diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 4a4f6aea..30ee88e9 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -7,9 +7,9 @@ use actix::{ }; use actix::{ActorFutureExt, ActorStreamExt}; use actix_http::error::PayloadError; -use actix_http::ws; +use actix_http::{ws, Payload, StatusCode}; use actix_web::web::Bytes; -use actix_web::{HttpRequest, HttpResponse}; +use actix_web::{http, Error, FromRequest, HttpRequest, HttpResponse}; use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext}; use futures_util::future::Ready; use futures_util::stream::Stream; @@ -20,6 +20,33 @@ use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); +/// A GraphQL protocol extractor. +/// +/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct GraphQLProtocol(WebSocketProtocols); + +#[async_trait::async_trait] +impl FromRequest for GraphQLProtocol { + type Error = Error; + type Future = Ready; + + fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { + Ready( + req.headers() + .get(http::header::SEC_WEBSOCKET_PROTOCOL) + .and_then(|value| value.to_str().ok()) + .and_then(|protocols| { + protocols + .split(',') + .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) + }) + .map(Self) + .ok_or_else(|| Error::), + ) + } +} + /// Actor for subscription via websocket pub struct WSSubscription { schema: Schema, @@ -212,4 +239,4 @@ where .spawn(ctx) } } -} \ No newline at end of file +}