diff --git a/src/http/mod.rs b/src/http/mod.rs index 0659a524..d82fce3c 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -12,7 +12,12 @@ use crate::{BatchRequest, ParseRequestError, Request}; pub use graphiql_source::graphiql_source; pub use multipart::MultipartOptions; pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; -pub use websocket::{Protocols as WebSocketProtocols, WebSocket, WsMessage}; +pub use websocket::{ + Protocols as WebSocketProtocols, + WebSocket, + WsMessage, + ClientMessage, +}; /// Receive a GraphQL request from a content type and body. pub async fn receive_body( diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 35ac64e9..3e31e926 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -1,12 +1,12 @@ //! WebSocket transport for subscription -use std::collections::HashMap; +use std::{collections::HashMap}; use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use futures_util::future::{BoxFuture, Ready}; +use futures_util::{StreamExt, future::{BoxFuture, Ready}}; use futures_util::stream::Stream; use futures_util::FutureExt; use pin_project_lite::pin_project; @@ -70,8 +70,14 @@ pin_project! { } } +type MessageMapStream = + futures_util::stream::Map::Item) -> serde_json::Result>; + impl - WebSocket Ready>, Query, Mutation, Subscription> + WebSocket, fn(serde_json::Value) -> Ready>, Query, Mutation, Subscription> +where + S: Stream, + S::Item: AsRef<[u8]>, { /// Create a new websocket. #[must_use] @@ -89,7 +95,12 @@ impl } } -impl WebSocket { +impl + WebSocket, F, Query, Mutation, Subscription> +where + S: Stream, + S::Item: AsRef<[u8]>, +{ /// Create a new websocket with a data initialization function. /// /// This function, if present, will be called with the data sent by the client in the @@ -101,8 +112,42 @@ impl WebSocket Self { - Self { + ) -> Self + { + // let stream = stream.map(|message| serde_json::from_slice(message.as_ref())); + let stream = stream.map( + ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result, + ); + + Self::with_message_stream( + schema, + stream, + data_initializer, + protocol, + ) + } +} + +impl + WebSocket +where + S: Stream>, +{ + /// Create a new websocket with a data initialization function from a stream of ClientMessage + /// structs. + /// + /// This function, if present, will be called with the data sent by the client in the + /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init). + /// From that point on the returned data will be accessible to all requests. + #[must_use] + pub fn with_message_stream( + schema: Schema, + stream: S, + data_initializer: F, + protocol: Protocols, + ) -> Self + { + WebSocket { data_initializer: Some(data_initializer), init_fut: None, data: None, @@ -117,8 +162,7 @@ impl WebSocket Stream for WebSocket where - S: Stream, - S::Item: AsRef<[u8]>, + S: Stream>, F: FnOnce(serde_json::Value) -> R + Send + 'static, R: Future> + Send + 'static, Query: ObjectType + 'static, @@ -137,9 +181,11 @@ where None => return Poll::Ready(None), }; - let message: ClientMessage = match serde_json::from_slice(message.as_ref()) { + let message: ClientMessage = match message { Ok(message) => message, - Err(err) => return Poll::Ready(Some(WsMessage::Close(1002, err.to_string()))), + Err(err) => { + return Poll::Ready(Some(WsMessage::Close(1002, err.to_string()))) + } }; match message { @@ -189,9 +235,9 @@ where } } ClientMessage::Stop { id } => { - if this.streams.remove(id).is_some() { + if this.streams.remove(&id).is_some() { return Poll::Ready(Some(WsMessage::Text( - serde_json::to_string(&ServerMessage::Complete { id }).unwrap(), + serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(), ))); } } @@ -294,24 +340,44 @@ impl std::str::FromStr for Protocols { } } +/// A websocket message received from the client #[derive(Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] -enum ClientMessage<'a> { +pub enum ClientMessage { + /// A new connection ConnectionInit { + /// Optional init payload from the client payload: Option, }, + /// The start of a Websocket subscription #[serde(alias = "subscribe")] Start { + /// Message ID id: String, + /// The GraphQL Request - this can be modified by protocol implementors to add files + /// uploads. payload: Request, }, + /// The end of a Websocket subscription #[serde(alias = "complete")] Stop { - id: &'a str, + /// Message ID + id: String, }, + /// Connection terminated by the client ConnectionTerminate, } +impl ClientMessage { + /// Creates a ClientMessage from an array of bytes + pub fn from_bytes(message: T) -> serde_json::Result + where + T: AsRef<[u8]>, + { + serde_json::from_slice(message.as_ref()) + } +} + #[derive(Serialize)] #[serde(tag = "type", rename_all = "snake_case")] enum ServerMessage<'a> {