diff --git a/examples b/examples index 90347775..b36b5c44 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit 9034777541c542d1f53d5b8adcc988b8e53a6ed8 +Subproject commit b36b5c44543b7323cb199ed229ea247e83b85d18 diff --git a/integrations/tide/Cargo.toml b/integrations/tide/Cargo.toml index 48daed53..318cdc61 100644 --- a/integrations/tide/Cargo.toml +++ b/integrations/tide/Cargo.toml @@ -13,8 +13,12 @@ categories = ["network-programming", "asynchronous"] [dependencies] async-graphql = { path = "../..", version = "=2.5.1" } +async-trait = "0.1.42" +futures-util = "0.3.12" +serde_json = "1.0.62" tide = { version = "0.16.0", default-features = false, features = ["h1-server"] } +tide-websockets = "0.2.0" [dev-dependencies] # Surf lacks multipart support diff --git a/integrations/tide/src/lib.rs b/integrations/tide/src/lib.rs index 6194d4a6..3f1dd462 100644 --- a/integrations/tide/src/lib.rs +++ b/integrations/tide/src/lib.rs @@ -11,6 +11,8 @@ #![allow(clippy::needless_doctest_main)] #![forbid(unsafe_code)] +mod subscription; + use async_graphql::http::MultipartOptions; use async_graphql::{ObjectType, ParseRequestError, Schema, SubscriptionType}; use tide::utils::async_trait; @@ -22,6 +24,8 @@ use tide::{ Body, Request, Response, StatusCode, }; +pub use subscription::Subscription; + /// Create a new GraphQL endpoint with the schema. /// /// Default multipart options are used and batch operations are supported. diff --git a/integrations/tide/src/subscription.rs b/integrations/tide/src/subscription.rs new file mode 100644 index 00000000..d6874b13 --- /dev/null +++ b/integrations/tide/src/subscription.rs @@ -0,0 +1,99 @@ +use std::future::Future; +use std::pin::Pin; +use std::str::FromStr; + +use async_graphql::http::{WebSocket as AGWebSocket, WebSocketProtocols}; +use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; +use futures_util::{future, StreamExt}; +use tide::{Endpoint, Request, Response}; +use tide_websockets::Message; + +/// GraphQL subscription endpoint. +pub struct Subscription { + inner: Pin>>, +} + +#[async_trait::async_trait] +impl Endpoint for Subscription +where + S: Send + Sync + Clone + 'static, +{ + async fn call(&self, req: Request) -> tide::Result { + self.inner.call(req).await + } +} + +impl Subscription +where + S: Send + Sync + Clone + 'static, +{ + /// Create a graphql subscription endpoint. + pub fn new(schema: Schema) -> Self + where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + { + Self::new_with_initializer(schema, |_| { + futures_util::future::ready(Ok(Default::default())) + }) + } + + /// Create a graphql subscription endpoint. + /// + /// Specifies that a function converts the init payload to data. + pub fn new_with_initializer( + schema: Schema, + initializer: F, + ) -> Self + where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + F: FnOnce(serde_json::Value) -> R + Unpin + Send + Sync + Clone + 'static, + R: Future> + Send + 'static, + { + let endpoint = tide_websockets::WebSocket::::new(move |request, connection| { + let schema = schema.clone(); + let initializer = initializer.clone(); + async move { + let protocol = match request + .header("sec-websocket-protocol") + .map(|value| value.as_str()) + .and_then(|protocols| { + protocols + .split(',') + .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) + }) { + Some(protocol) => protocol, + None => { + // default to the prior standard + WebSocketProtocols::SubscriptionsTransportWS + } + }; + + let sink = connection.clone(); + let mut stream = AGWebSocket::with_data( + schema.clone(), + connection + .take_while(|msg| future::ready(msg.is_ok())) + .map(Result::unwrap) + .map(Message::into_data), + initializer, + protocol, + ); + while let Some(data) = stream.next().await { + if sink.send_string(data).await.is_err() { + break; + } + } + + Ok(()) + } + }) + .with_protocols(&["graphql-transport-ws", "graphql-ws"]); + Self { + inner: Box::pin(endpoint), + } + } +}