diff --git a/integrations/actix-web/Cargo.toml b/integrations/actix-web/Cargo.toml index 3e7b12ae..b3749e73 100644 --- a/integrations/actix-web/Cargo.toml +++ b/integrations/actix-web/Cargo.toml @@ -22,6 +22,7 @@ futures-util = { version = "0.3.17", default-features = false } serde_json = "1.0.64" serde_urlencoded = "0.7.0" futures-channel = "0.3.13" +thiserror = "1.0.30" [dev-dependencies] actix-rt = "2.2.0" diff --git a/integrations/actix-web/src/lib.rs b/integrations/actix-web/src/lib.rs index a0c4e227..021e3f59 100644 --- a/integrations/actix-web/src/lib.rs +++ b/integrations/actix-web/src/lib.rs @@ -3,6 +3,8 @@ #![allow(clippy::upper_case_acronyms)] #![warn(missing_docs)] +mod subscription; + use std::future::Future; use std::io::{self, ErrorKind}; use std::pin::Pin; @@ -16,16 +18,14 @@ use futures_util::{StreamExt, TryStreamExt}; use async_graphql::http::MultipartOptions; use async_graphql::ParseRequestError; -pub use subscription::WSSubscription; - -mod subscription; +pub use subscription::GraphQLSubscription; /// Extractor for GraphQL request. /// /// `async_graphql::http::MultipartOptions` allows to configure extraction process. -pub struct Request(pub async_graphql::Request); +pub struct GraphQLRequest(pub async_graphql::Request); -impl Request { +impl GraphQLRequest { /// Unwraps the value to `async_graphql::Request`. #[must_use] pub fn into_inner(self) -> async_graphql::Request { @@ -34,14 +34,14 @@ impl Request { } type BatchToRequestMapper = - fn(<::Future as Future>::Output) -> Result; + fn(<::Future as Future>::Output) -> Result; -impl FromRequest for Request { +impl FromRequest for GraphQLRequest { type Error = Error; - type Future = future::Map<::Future, BatchToRequestMapper>; + type Future = future::Map<::Future, BatchToRequestMapper>; fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - BatchRequest::from_request(req, payload).map(|res| { + GraphQLBatchRequest::from_request(req, payload).map(|res| { Ok(Self( res?.0 .into_single() @@ -54,9 +54,9 @@ impl FromRequest for Request { /// Extractor for GraphQL batch request. /// /// `async_graphql::http::MultipartOptions` allows to configure extraction process. -pub struct BatchRequest(pub async_graphql::BatchRequest); +pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest); -impl BatchRequest { +impl GraphQLBatchRequest { /// Unwraps the value to `async_graphql::BatchRequest`. #[must_use] pub fn into_inner(self) -> async_graphql::BatchRequest { @@ -64,9 +64,9 @@ impl BatchRequest { } } -impl FromRequest for BatchRequest { +impl FromRequest for GraphQLBatchRequest { type Error = Error; - type Future = Pin>>>; + type Future = Pin>>>; fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { let config = req @@ -97,7 +97,7 @@ impl FromRequest for BatchRequest { }); Box::pin(async move { - Ok(BatchRequest( + Ok(GraphQLBatchRequest( async_graphql::http::receive_batch_body( content_type, rx.map_err(|e| match e { @@ -146,21 +146,21 @@ impl FromRequest for BatchRequest { /// /// This contains a batch response, but since regular responses are a type of batch response it /// works for both. -pub struct Response(pub async_graphql::BatchResponse); +pub struct GraphQLResponse(pub async_graphql::BatchResponse); -impl From for Response { +impl From for GraphQLResponse { fn from(resp: async_graphql::Response) -> Self { Self(resp.into()) } } -impl From for Response { +impl From for GraphQLResponse { fn from(resp: async_graphql::BatchResponse) -> Self { Self(resp) } } -impl Responder for Response { +impl Responder for GraphQLResponse { fn respond_to(self, _req: &HttpRequest) -> HttpResponse { let mut res = HttpResponse::build(StatusCode::OK); res.content_type("application/json"); @@ -174,4 +174,4 @@ impl Responder for Response { } res.body(serde_json::to_string(&self.0).unwrap()) } -} \ No newline at end of file +} diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 30ee88e9..aae627c5 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -7,9 +7,9 @@ use actix::{ }; use actix::{ActorFutureExt, ActorStreamExt}; use actix_http::error::PayloadError; -use actix_http::{ws, Payload, StatusCode}; +use actix_http::ws; use actix_web::web::Bytes; -use actix_web::{http, Error, FromRequest, HttpRequest, HttpResponse}; +use actix_web::{Error, HttpRequest, HttpResponse}; use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext}; use futures_util::future::Ready; use futures_util::stream::Stream; @@ -20,86 +20,80 @@ use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); -/// A GraphQL protocol extractor. -/// -/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct GraphQLProtocol(WebSocketProtocols); +#[derive(thiserror::Error, Debug)] +#[error("failed to parse graphql protocol")] +pub struct ParseGraphQLProtocolError; -#[async_trait::async_trait] -impl FromRequest for GraphQLProtocol { - type Error = Error; - type Future = Ready; +type DefaultOnConnInitType = fn(serde_json::Value) -> Ready>; - fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { - Ready( - req.headers() - .get(http::header::SEC_WEBSOCKET_PROTOCOL) - .and_then(|value| value.to_str().ok()) - .and_then(|protocols| { - protocols - .split(',') - .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) - }) - .map(Self) - .ok_or_else(|| Error::), - ) - } +fn default_on_connection_init(_: serde_json::Value) -> Ready> { + futures_util::future::ready(Ok(Data::default())) } -/// Actor for subscription via websocket -pub struct WSSubscription { +/// A builder for websocket subscription actor. +pub struct GraphQLSubscription { schema: Schema, - protocol: WebSocketProtocols, - last_heartbeat: Instant, - messages: Option>>, - initializer: Option, - continuation: Vec, + data: Data, + on_connection_init: OnInit, } impl - WSSubscription Ready>> + GraphQLSubscription where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, { - /// Start an actor for subscription connection via websocket. - pub fn start( - schema: Schema, - request: &HttpRequest, - stream: T, - ) -> Result - where - T: Stream> + 'static, - { - Self::start_with_initializer(schema, request, stream, |_| { - futures_util::future::ready(Ok(Default::default())) - }) + /// Create a GraphQL subscription builder. + pub fn new(schema: Schema) -> Self { + Self { + schema, + data: Default::default(), + on_connection_init: default_on_connection_init, + } } } -impl WSSubscription +impl + GraphQLSubscription where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, - R: Future> + Send + 'static, + OnInit: Fn(serde_json::Value) -> OnInitFut + Unpin + Send + 'static, + OnInitFut: Future> + Send + 'static, { - /// Start an actor for subscription connection via websocket with an initialization function. - pub fn start_with_initializer( - schema: Schema, - request: &HttpRequest, - stream: T, - initializer: F, - ) -> Result + /// Specify the initial subscription context data, usually you can get something from the + /// incoming request to create it. + pub fn with_data(self, data: Data) -> Self { + Self { data, ..self } + } + + /// Specify a callback function to be called when the connection is initialized. + /// + /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`]. + /// The data returned by this callback function will be merged with the data specified by [`with_data`]. + pub fn on_connection_init( + self, + callback: OnConnInit2, + ) -> GraphQLSubscription where - T: Stream> + 'static, - F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, - R: Future> + Send + 'static, + OnConnInit2: Fn(serde_json::Value) -> Fut + Unpin + Send + 'static, + Fut: Future> + Send + 'static, { - let protocol = match request + GraphQLSubscription { + schema: self.schema, + data: self.data, + on_connection_init: callback, + } + } + + /// Start the subscription actor. + pub fn start(self, request: &HttpRequest, stream: S) -> Result + where + S: Stream> + 'static, + { + let protocol = request .headers() .get("sec-websocket-protocol") .and_then(|value| value.to_str().ok()) @@ -107,29 +101,42 @@ where protocols .split(',') .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) - }) { - Some(protocol) => protocol, - None => { - // default to the prior standard - WebSocketProtocols::SubscriptionsTransportWS - } + }) + .ok_or_else(|| actix_web::error::ErrorBadRequest(ParseGraphQLProtocolError))?; + + let actor = GraphQLSubscriptionActor { + schema: self.schema, + data: Some(self.data), + protocol, + last_heartbeat: Instant::now(), + messages: None, + on_connection_init: Some(self.on_connection_init), + continuation: Vec::new(), }; - actix_web_actors::ws::start_with_protocols( - Self { - schema, - protocol, - last_heartbeat: Instant::now(), - messages: None, - initializer: Some(initializer), - continuation: Vec::new(), - }, - &ALL_WEBSOCKET_PROTOCOLS, - request, - stream, - ) + actix_web_actors::ws::start_with_protocols(actor, &ALL_WEBSOCKET_PROTOCOLS, request, stream) } +} +struct GraphQLSubscriptionActor { + schema: Schema, + data: Option, + protocol: WebSocketProtocols, + last_heartbeat: Instant, + messages: Option>>, + on_connection_init: Option, + continuation: Vec, +} + +impl + GraphQLSubscriptionActor +where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static, + OnInitFut: Future> + Send + 'static, +{ fn send_heartbeats(&self, ctx: &mut WebsocketContext) { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT { @@ -140,13 +147,14 @@ where } } -impl Actor for WSSubscription +impl Actor + for GraphQLSubscriptionActor where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, - R: Future> + Send + 'static, + OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static, + OnInitFut: Future> + Send + 'static, { type Context = WebsocketContext; @@ -155,35 +163,32 @@ where let (tx, rx) = async_channel::unbounded(); - WebSocket::with_data( - self.schema.clone(), - rx, - self.initializer.take().unwrap(), - self.protocol, - ) - .into_actor(self) - .map(|response, _act, ctx| match response { - WsMessage::Text(text) => ctx.text(text), - WsMessage::Close(code, msg) => ctx.close(Some(CloseReason { - code: code.into(), - description: Some(msg), - })), - }) - .finish() - .spawn(ctx); + WebSocket::new(self.schema.clone(), rx, self.protocol) + .connection_data(self.data.take().unwrap()) + .on_connection_init(self.on_connection_init.take().unwrap()) + .into_actor(self) + .map(|response, _act, ctx| match response { + WsMessage::Text(text) => ctx.text(text), + WsMessage::Close(code, msg) => ctx.close(Some(CloseReason { + code: code.into(), + description: Some(msg), + })), + }) + .finish() + .spawn(ctx); self.messages = Some(tx); } } -impl StreamHandler> - for WSSubscription +impl StreamHandler> + for GraphQLSubscriptionActor where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, - R: Future> + Send + 'static, + OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static, + OnInitFut: Future> + Send + 'static, { fn handle(&mut self, msg: Result, ctx: &mut Self::Context) { let msg = match msg { diff --git a/integrations/actix-web/tests/test_utils.rs b/integrations/actix-web/tests/test_utils.rs index 9a41a85b..f9b4c539 100644 --- a/integrations/actix-web/tests/test_utils.rs +++ b/integrations/actix-web/tests/test_utils.rs @@ -3,7 +3,7 @@ use async_graphql::http::{playground_source, GraphQLPlaygroundConfig}; use async_graphql::{ Context, EmptyMutation, EmptySubscription, Object, ObjectType, Schema, SubscriptionType, }; -use async_graphql_actix_web::{Request, Response}; +use async_graphql_actix_web::{GraphQLRequest, GraphQLResponse}; use async_mutex::Mutex; pub async fn gql_playgound() -> HttpResponse { @@ -69,16 +69,16 @@ pub async fn gql_handle_schema< S: SubscriptionType + 'static, >( schema: web::Data>, - req: Request, -) -> Response { + req: GraphQLRequest, +) -> GraphQLResponse { schema.execute(req.into_inner()).await.into() } pub async fn gql_handle_schema_with_header( schema: actix_web::web::Data>, req: HttpRequest, - gql_request: Request, -) -> Response { + gql_request: GraphQLRequest, +) -> GraphQLResponse { let name = req .headers() .get("Name")