diff --git a/integrations/poem/src/subscription.rs b/integrations/poem/src/subscription.rs index e4f64323..18af0079 100644 --- a/integrations/poem/src/subscription.rs +++ b/integrations/poem/src/subscription.rs @@ -187,8 +187,7 @@ where future::ready(None) } }) - .map(Message::into_bytes) - .boxed(); + .map(Message::into_bytes); let mut stream = async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol.0) diff --git a/integrations/warp/src/lib.rs b/integrations/warp/src/lib.rs index 767fe615..5b572554 100644 --- a/integrations/warp/src/lib.rs +++ b/integrations/warp/src/lib.rs @@ -2,6 +2,7 @@ #![allow(clippy::type_complexity)] #![forbid(unsafe_code)] +#![warn(missing_docs)] mod batch_request; mod error; @@ -11,7 +12,4 @@ mod subscription; pub use batch_request::{graphql_batch, graphql_batch_opts, BatchResponse}; pub use error::BadRequest; pub use request::{graphql, graphql_opts, Response}; -pub use subscription::{ - graphql_protocol, graphql_subscription, graphql_subscription_upgrade, - graphql_subscription_upgrade_with_data, graphql_subscription_with_data, -}; +pub use subscription::{graphql_protocol, graphql_subscription, GraphQLWebSocket}; diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index f9195bfb..df645fc9 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -3,10 +3,10 @@ use std::str::FromStr; use async_graphql::http::{WebSocketProtocols, WsMessage}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; -use futures_util::sink::Sink; -use futures_util::stream::Stream; +use futures_util::future::Ready; use futures_util::{future, StreamExt}; use warp::filters::ws; +use warp::ws::WebSocket; use warp::{Filter, Rejection, Reply}; /// GraphQL subscription filter @@ -59,32 +59,16 @@ where Query: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static, Subscription: SubscriptionType + Send + Sync + 'static, -{ - graphql_subscription_with_data(schema, |_| async { Ok(Default::default()) }) -} - -/// GraphQL subscription filter -/// -/// Specifies that a function converts the init payload to data. -pub fn graphql_subscription_with_data( - schema: Schema, - initializer: F, -) -> impl Filter + Clone -where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Clone + Send + 'static, - R: Future> + Send + 'static, { warp::ws() .and(graphql_protocol()) .map(move |ws: ws::Ws, protocol| { let schema = schema.clone(); - let initializer = initializer.clone(); - let reply = ws.on_upgrade(move |websocket| { - graphql_subscription_upgrade_with_data(websocket, protocol, schema, initializer) + let reply = ws.on_upgrade(move |socket| { + GraphQLWebSocket::new(socket, schema, protocol) + .on_connection_init(default_on_connection_init) + .serve() }); warp::reply::with_header( @@ -109,10 +93,13 @@ pub fn graphql_protocol() -> impl Filter Ready>; + +fn default_on_connection_init(_: serde_json::Value) -> Ready> { + futures_util::future::ready(Ok(Data::default())) +} + +/// A Websocket connection for GraphQL subscription. /// /// # Examples /// @@ -150,70 +137,112 @@ pub fn graphql_protocol() -> impl Filter( - websocket: S, +pub struct GraphQLWebSocket { + socket: WebSocket, protocol: WebSocketProtocols, schema: Schema, -) where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, - S: Stream> + Sink, -{ - graphql_subscription_upgrade_with_data(websocket, protocol, schema, |_| async { - Ok(Default::default()) - }) - .await + data: Data, + on_init: OnInit, } -/// Handle the WebSocket subscription. -/// -/// Specifies that a function converts the init payload to data. -pub async fn graphql_subscription_upgrade_with_data( - websocket: S, - protocol: WebSocketProtocols, - schema: Schema, - initializer: F, -) where +impl + GraphQLWebSocket +where Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Send + 'static, - R: Future> + Send + 'static, - S: Stream> + Sink, { - let (ws_sender, ws_receiver) = websocket.split(); - let _ = async_graphql::http::WebSocket::with_data( - schema, - ws_receiver + /// Create a [`GraphQLWebSocket`] object. + pub fn new( + socket: WebSocket, + schema: Schema, + protocol: WebSocketProtocols, + ) -> Self { + GraphQLWebSocket { + socket, + protocol, + schema, + data: Data::default(), + on_init: default_on_connection_init, + } + } +} + +impl + GraphQLWebSocket +where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + OnConnInit: Fn(serde_json::Value) -> OnConnInitFut + Send + Sync + 'static, + OnConnInitFut: Future> + Send + 'static, +{ + /// Specify the initial subscription context data, usually you can get something from the + /// incoming request to create it. + pub fn with_data(self, data: Data) -> Self { + Self { data, ..self } + } + + /// Specify a callback function to be called when the connection is initialized. + /// + /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`]. + /// The data returned by this callback function will be merged with the data specified by [`with_data`]. + pub fn on_connection_init( + self, + callback: OnConnInit2, + ) -> GraphQLWebSocket + where + OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + GraphQLWebSocket { + socket: self.socket, + schema: self.schema, + data: self.data, + on_init: callback, + protocol: self.protocol, + } + } + + /// Processing subscription requests. + pub async fn serve(self) { + let (ws_sender, ws_receiver) = self.socket.split(); + + let stream = ws_receiver .take_while(|msg| future::ready(msg.is_ok())) .map(Result::unwrap) .filter(|msg| future::ready(msg.is_text() || msg.is_binary())) - .map(ws::Message::into_bytes), - initializer, - protocol, - ) - .map(|msg| match msg { - WsMessage::Text(text) => ws::Message::text(text), - WsMessage::Close(code, status) => ws::Message::close_with(code, status), - }) - .map(Ok) - .forward(ws_sender) - .await; + .map(ws::Message::into_bytes); + + let _ = async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol) + .connection_data(self.data) + .on_connection_init(self.on_init) + .map(|msg| match msg { + WsMessage::Text(text) => ws::Message::text(text), + WsMessage::Close(code, status) => ws::Message::close_with(code, status), + }) + .map(Ok) + .forward(ws_sender) + .await; + } }