diff --git a/integrations/tide/src/lib.rs b/integrations/tide/src/lib.rs index 3a121257..319b4f29 100644 --- a/integrations/tide/src/lib.rs +++ b/integrations/tide/src/lib.rs @@ -23,7 +23,7 @@ use tide::{ }; #[cfg(feature = "websocket")] -pub use subscription::Subscription; +pub use subscription::SubscriptionBuilder; /// Create a new GraphQL endpoint with the schema. /// diff --git a/integrations/tide/src/subscription.rs b/integrations/tide/src/subscription.rs index 11037cf5..0c3e04d4 100644 --- a/integrations/tide/src/subscription.rs +++ b/integrations/tide/src/subscription.rs @@ -1,63 +1,122 @@ use std::future::Future; +use std::marker::PhantomData; use std::pin::Pin; use std::str::FromStr; use async_graphql::http::{WebSocket as AGWebSocket, WebSocketProtocols, WsMessage}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; +use futures_util::future::Ready; use futures_util::{future, StreamExt}; -use tide::{Endpoint, Request, Response}; +use tide::{Endpoint, Request, Response, StatusCode}; use tide_websockets::Message; -/// GraphQL subscription endpoint. +type DefaultOnConnCreateType = fn(&Request) -> Ready>; + +fn default_on_connection_create(_: &Request) -> Ready> { + futures_util::future::ready(Ok(Data::default())) +} + +type DefaultOnConnInitType = fn(serde_json::Value) -> Ready>; + +fn default_on_connection_init(_: serde_json::Value) -> Ready> { + futures_util::future::ready(Ok(Data::default())) +} + +/// GraphQL subscription builder. #[cfg_attr(docsrs, doc(cfg(feature = "websocket")))] -pub struct Subscription { - inner: Pin>>, +pub struct SubscriptionBuilder { + schema: Schema, + on_connection_create: OnCreate, + on_connection_init: OnInit, + _mark: PhantomData, } -#[async_trait::async_trait] -impl Endpoint for Subscription -where - S: Send + Sync + Clone + 'static, +impl + SubscriptionBuilder< + S, + Query, + Mutation, + Subscription, + DefaultOnConnCreateType, + DefaultOnConnInitType, + > { - async fn call(&self, req: Request) -> tide::Result { - self.inner.call(req).await + /// Create a GraphQL subscription builder. + pub fn new(schema: Schema) -> Self { + Self { + schema, + on_connection_create: default_on_connection_create, + on_connection_init: default_on_connection_init, + _mark: Default::default(), + } } } -impl Subscription -where - S: Send + Sync + Clone + 'static, +impl + SubscriptionBuilder { - /// Create a graphql subscription endpoint. - pub fn new(schema: Schema) -> Self - where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, - { - Self::new_with_initializer(schema, |_| { - futures_util::future::ready(Ok(Default::default())) - }) - } - - /// Create a graphql subscription endpoint. + /// Specify the callback function to be called when the connection is created. /// - /// Specifies that a function converts the init payload to data. - pub fn new_with_initializer( - schema: Schema, - initializer: F, - ) -> Self + /// You can get something from the incoming request to create [`Data`]. + pub fn on_connection_create( + self, + callback: OnCreate2, + ) -> SubscriptionBuilder where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Unpin + Send + Sync + Clone + 'static, - R: Future> + Send + 'static, + OnCreate2: Fn(&Request) -> Fut + Clone + Send + Sync + 'static, + Fut: Future> + Send + 'static, { - let endpoint = tide_websockets::WebSocket::::new(move |request, connection| { - let schema = schema.clone(); - let initializer = initializer.clone(); + SubscriptionBuilder { + schema: self.schema, + on_connection_create: callback, + on_connection_init: self.on_connection_init, + _mark: Default::default(), + } + } + + /// Specify a callback function to be called when the connection is initialized. + /// + /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`]. + pub fn on_connection_init( + self, + callback: OnInit2, + ) -> SubscriptionBuilder + where + OnInit2: FnOnce(serde_json::Value) -> Fut + Clone + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + SubscriptionBuilder { + schema: self.schema, + on_connection_create: self.on_connection_create, + on_connection_init: callback, + _mark: Default::default(), + } + } +} + +impl + SubscriptionBuilder +where + S: Send + Sync + Clone + 'static, + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + OnCreate: Fn(&Request) -> OnCreateFut + Send + Clone + Sync + 'static, + OnCreateFut: Future> + Send + 'static, + OnInit: FnOnce(serde_json::Value) -> OnInitFut + Clone + Send + Sync + 'static, + OnInitFut: Future> + Send + 'static, +{ + pub fn build(self) -> impl Endpoint { + tide_websockets::WebSocket::::new(move |request, connection| { + let schema = self.schema.clone(); + let on_connection_create = self.on_connection_create.clone(); + let on_connection_init = self.on_connection_init.clone(); + async move { + let data = on_connection_create(&request) + .await + .map_err(|_| tide::Error::from_str(StatusCode::BadRequest, "bad request"))?; + let protocol = match request .header("sec-websocket-protocol") .map(|value| value.as_str()) @@ -74,15 +133,17 @@ where }; let sink = connection.clone(); - let mut stream = AGWebSocket::with_data( + let mut stream = AGWebSocket::new( schema.clone(), connection .take_while(|msg| future::ready(msg.is_ok())) .map(Result::unwrap) .map(Message::into_data), - initializer, protocol, - ); + ) + .connection_data(data) + .on_connection_init(on_connection_init); + while let Some(data) = stream.next().await { match data { WsMessage::Text(text) => { @@ -100,9 +161,6 @@ where Ok(()) } }) - .with_protocols(&["graphql-transport-ws", "graphql-ws"]); - Self { - inner: Box::pin(endpoint), - } + .with_protocols(&["graphql-transport-ws", "graphql-ws"]) } }