//! WebSocket transport for subscriptions.
use std ::{
collections ::HashMap ,
future ::Future ,
pin ::Pin ,
sync ::Arc ,
task ::{ Context , Poll } ,
} ;
2020-09-17 18:22:54 +00:00
2022-04-19 04:25:11 +00:00
use futures_util ::{
future ::{ BoxFuture , Ready } ,
stream ::Stream ,
FutureExt , StreamExt ,
} ;
2020-10-15 06:38:10 +00:00
use pin_project_lite ::pin_project ;
use serde ::{ Deserialize , Serialize } ;
use crate ::{ Data , Error , ObjectType , Request , Response , Result , Schema , SubscriptionType } ;
/// Every WebSocket subprotocol understood by this module, spelled as it
/// appears in the `Sec-WebSocket-Protocol` header.
pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = [
    // Header value of `Protocols::GraphQLWS`.
    "graphql-transport-ws",
    // Header value of `Protocols::SubscriptionsTransportWS`.
    "graphql-ws",
];
/// An enum representing the various forms of a WebSocket message.
///
/// `PartialEq`/`Eq` are derived so that callers (and tests) can compare
/// produced messages directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WsMessage {
    /// A text WebSocket message.
    Text(String),
    /// A close message, carrying the close code and the close reason.
    Close(u16, String),
}

impl WsMessage {
    /// Returns the contained [`WsMessage::Text`] value, consuming `self`.
    ///
    /// Because this function may panic, its use is generally discouraged.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Text`].
    pub fn unwrap_text(self) -> String {
        match self {
            Self::Text(text) => text,
            Self::Close(_, _) => panic!("Not a text message"),
        }
    }

    /// Returns the code and reason contained in a [`WsMessage::Close`],
    /// consuming `self`.
    ///
    /// Because this function may panic, its use is generally discouraged.
    ///
    /// # Panics
    ///
    /// Panics if `self` is not a [`WsMessage::Close`].
    pub fn unwrap_close(self) -> (u16, String) {
        match self {
            Self::Close(code, msg) => (code, msg),
            Self::Text(_) => panic!("Not a close message"),
        }
    }
}
pin_project! {
/// A GraphQL connection over websocket.
///
2021-11-10 12:03:09 +00:00
/// # References
///
/// - [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
/// - [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
pub struct WebSocket < S , Query , Mutation , Subscription , OnInit > {
on_connection_init : Option < OnInit > ,
2021-01-16 04:47:32 +00:00
init_fut : Option < BoxFuture < 'static , Result < Data > > > ,
2021-11-09 09:01:51 +00:00
connection_data : Option < Data > ,
2021-03-24 09:24:12 +00:00
data : Option < Arc < Data > > ,
2020-09-17 18:22:54 +00:00
schema : Schema < Query , Mutation , Subscription > ,
streams : HashMap < String , Pin < Box < dyn Stream < Item = Response > + Send > > > ,
#[ pin ]
stream : S ,
2020-12-03 21:16:07 +00:00
protocol : Protocols ,
2020-09-15 03:12:19 +00:00
}
}
2021-05-06 20:59:48 +00:00
type MessageMapStream < S > =
futures_util ::stream ::Map < S , fn ( < S as Stream > ::Item ) -> serde_json ::Result < ClientMessage > > ;
2021-11-10 12:03:09 +00:00
type DefaultOnConnInitType = fn ( serde_json ::Value ) -> Ready < Result < Data > > ;
fn default_on_connection_init ( _ : serde_json ::Value ) -> Ready < Result < Data > > {
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
}
impl < S , Query , Mutation , Subscription >
WebSocket < S , Query , Mutation , Subscription , DefaultOnConnInitType >
2021-05-06 20:59:48 +00:00
where
2021-11-09 09:01:51 +00:00
S : Stream < Item = serde_json ::Result < ClientMessage > > ,
2020-09-17 18:22:54 +00:00
{
2021-11-09 09:01:51 +00:00
/// Create a new websocket from [`ClientMessage`] stream.
pub fn from_message_stream (
2020-12-03 21:16:07 +00:00
schema : Schema < Query , Mutation , Subscription > ,
stream : S ,
protocol : Protocols ,
) -> Self {
2021-11-09 09:01:51 +00:00
WebSocket {
2021-11-10 12:03:09 +00:00
on_connection_init : Some ( default_on_connection_init ) ,
2021-11-09 09:01:51 +00:00
init_fut : None ,
connection_data : None ,
data : None ,
2020-09-17 18:22:54 +00:00
schema ,
2021-11-09 09:01:51 +00:00
streams : HashMap ::new ( ) ,
2020-09-17 18:22:54 +00:00
stream ,
2020-12-03 21:16:07 +00:00
protocol ,
2021-11-09 09:01:51 +00:00
}
2020-09-15 03:12:19 +00:00
}
2021-11-10 12:03:09 +00:00
}
2020-09-11 07:54:56 +00:00
2021-11-10 12:03:09 +00:00
impl < S , Query , Mutation , Subscription >
WebSocket < MessageMapStream < S > , Query , Mutation , Subscription , DefaultOnConnInitType >
where
S : Stream ,
S ::Item : AsRef < [ u8 ] > ,
{
/// Create a new websocket from bytes stream.
pub fn new (
schema : Schema < Query , Mutation , Subscription > ,
stream : S ,
protocol : Protocols ,
) -> Self {
let stream = stream
. map ( ClientMessage ::from_bytes as fn ( S ::Item ) -> serde_json ::Result < ClientMessage > ) ;
WebSocket ::from_message_stream ( schema , stream , protocol )
2021-11-09 09:01:51 +00:00
}
2021-11-10 12:03:09 +00:00
}
2021-05-06 20:59:48 +00:00
2021-11-10 12:03:09 +00:00
impl < S , Query , Mutation , Subscription , OnInit > WebSocket < S , Query , Mutation , Subscription , OnInit >
where
S : Stream < Item = serde_json ::Result < ClientMessage > > ,
{
2021-11-09 09:01:51 +00:00
/// Specify a connection data.
///
/// This data usually comes from HTTP requests.
2022-04-19 04:25:11 +00:00
/// When the `GQL_CONNECTION_INIT` message is received, this data will be
/// merged with the data returned by the closure specified by
/// `with_initializer` into the final subscription context data.
2022-01-11 01:35:09 +00:00
#[ must_use ]
2021-11-09 09:01:51 +00:00
pub fn connection_data ( mut self , data : Data ) -> Self {
self . connection_data = Some ( data ) ;
self
2021-05-06 20:59:48 +00:00
}
2021-11-10 12:03:09 +00:00
/// Specify a connection initialize callback function.
///
2022-04-19 04:25:11 +00:00
/// This function if present, will be called with the data sent by the
/// client in the [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
2021-11-10 12:03:09 +00:00
/// From that point on the returned data will be accessible to all requests.
2022-01-11 01:35:09 +00:00
#[ must_use ]
2021-11-10 12:03:09 +00:00
pub fn on_connection_init < F , R > (
self ,
callback : F ,
) -> WebSocket < S , Query , Mutation , Subscription , F >
2021-11-09 09:01:51 +00:00
where
2021-11-10 12:03:09 +00:00
F : FnOnce ( serde_json ::Value ) -> R + Send + 'static ,
R : Future < Output = Result < Data > > + Send + 'static ,
2021-11-09 09:01:51 +00:00
{
2021-11-10 12:03:09 +00:00
WebSocket {
on_connection_init : Some ( callback ) ,
init_fut : self . init_fut ,
connection_data : self . connection_data ,
data : self . data ,
schema : self . schema ,
streams : self . streams ,
stream : self . stream ,
protocol : self . protocol ,
}
2020-09-11 07:54:56 +00:00
}
}
2021-11-10 12:03:09 +00:00
impl < S , Query , Mutation , Subscription , OnInit , InitFut > Stream
for WebSocket < S , Query , Mutation , Subscription , OnInit >
2020-09-11 07:54:56 +00:00
where
2021-05-06 20:59:48 +00:00
S : Stream < Item = serde_json ::Result < ClientMessage > > ,
2021-01-14 04:41:59 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
2021-11-10 12:03:09 +00:00
OnInit : FnOnce ( serde_json ::Value ) -> InitFut + Send + 'static ,
InitFut : Future < Output = Result < Data > > + Send + 'static ,
2020-09-11 07:54:56 +00:00
{
2021-02-06 03:02:41 +00:00
type Item = WsMessage ;
2020-09-11 07:54:56 +00:00
2020-09-17 18:22:54 +00:00
fn poll_next ( self : Pin < & mut Self > , cx : & mut Context ) -> Poll < Option < Self ::Item > > {
2020-11-06 14:20:12 +00:00
let mut this = self . project ( ) ;
2020-09-17 18:22:54 +00:00
2021-02-06 01:24:00 +00:00
if this . init_fut . is_none ( ) {
while let Poll ::Ready ( message ) = Pin ::new ( & mut this . stream ) . poll_next ( cx ) {
let message = match message {
Some ( message ) = > message ,
None = > return Poll ::Ready ( None ) ,
} ;
2020-09-17 18:22:54 +00:00
2021-05-06 20:59:48 +00:00
let message : ClientMessage = match message {
2021-02-06 01:24:00 +00:00
Ok ( message ) = > message ,
2021-05-08 06:06:46 +00:00
Err ( err ) = > return Poll ::Ready ( Some ( WsMessage ::Close ( 1002 , err . to_string ( ) ) ) ) ,
2021-02-06 01:24:00 +00:00
} ;
2020-09-17 18:22:54 +00:00
2021-02-06 01:24:00 +00:00
match message {
ClientMessage ::ConnectionInit { payload } = > {
2021-11-10 12:03:09 +00:00
if let Some ( on_connection_init ) = this . on_connection_init . take ( ) {
2021-02-06 01:24:00 +00:00
* this . init_fut = Some ( Box ::pin ( async move {
2021-11-10 12:03:09 +00:00
on_connection_init ( payload . unwrap_or_default ( ) ) . await
2021-02-06 01:24:00 +00:00
} ) ) ;
break ;
2021-02-06 03:02:41 +00:00
} else {
match this . protocol {
Protocols ::SubscriptionsTransportWS = > {
return Poll ::Ready ( Some ( WsMessage ::Text (
serde_json ::to_string ( & ServerMessage ::ConnectionError {
payload : Error ::new (
" Too many initialisation requests. " ,
) ,
} )
. unwrap ( ) ,
) ) ) ;
}
Protocols ::GraphQLWS = > {
return Poll ::Ready ( Some ( WsMessage ::Close (
4429 ,
" Too many initialisation requests. " . to_string ( ) ,
) ) ) ;
}
}
2021-02-06 01:24:00 +00:00
}
2020-11-07 00:24:17 +00:00
}
2021-02-06 01:24:00 +00:00
ClientMessage ::Start {
2020-11-07 00:24:17 +00:00
id ,
2021-02-06 01:24:00 +00:00
payload : request ,
} = > {
2021-03-24 09:24:12 +00:00
if let Some ( data ) = this . data . clone ( ) {
this . streams . insert (
id ,
2021-03-25 08:33:11 +00:00
Box ::pin (
this . schema . execute_stream_with_session_data ( request , data ) ,
) ,
2021-03-24 09:24:12 +00:00
) ;
} else {
return Poll ::Ready ( Some ( WsMessage ::Close (
1011 ,
" The handshake is not completed. " . to_string ( ) ,
) ) ) ;
}
2021-02-06 01:24:00 +00:00
}
ClientMessage ::Stop { id } = > {
2021-05-06 20:59:48 +00:00
if this . streams . remove ( & id ) . is_some ( ) {
2021-02-06 03:02:41 +00:00
return Poll ::Ready ( Some ( WsMessage ::Text (
2021-05-08 06:06:46 +00:00
serde_json ::to_string ( & ServerMessage ::Complete { id : & id } )
. unwrap ( ) ,
2021-02-06 03:02:41 +00:00
) ) ) ;
2021-02-06 01:24:00 +00:00
}
2020-09-11 07:54:56 +00:00
}
2021-02-06 01:24:00 +00:00
// Note: in the revised `graphql-ws` spec, there is no equivalent to the
// `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is
// handled by disconnecting the websocket
ClientMessage ::ConnectionTerminate = > return Poll ::Ready ( None ) ,
2021-09-16 12:22:04 +00:00
// Pong must be sent in response from the receiving party as soon as possible.
ClientMessage ::Ping { .. } = > {
return Poll ::Ready ( Some ( WsMessage ::Text (
serde_json ::to_string ( & ServerMessage ::Pong { payload : None } ) . unwrap ( ) ,
) ) ) ;
}
ClientMessage ::Pong { .. } = > {
// Do nothing...
}
2020-09-11 15:43:26 +00:00
}
2020-09-17 18:22:54 +00:00
}
}
2020-09-11 07:54:56 +00:00
2021-01-16 04:47:32 +00:00
if let Some ( init_fut ) = this . init_fut {
if let Poll ::Ready ( res ) = init_fut . poll_unpin ( cx ) {
* this . init_fut = None ;
return match res {
Ok ( data ) = > {
2021-11-09 09:01:51 +00:00
let mut ctx_data = this . connection_data . take ( ) . unwrap_or_default ( ) ;
ctx_data . merge ( data ) ;
* this . data = Some ( Arc ::new ( ctx_data ) ) ;
2021-02-06 03:02:41 +00:00
Poll ::Ready ( Some ( WsMessage ::Text (
2021-01-16 04:47:32 +00:00
serde_json ::to_string ( & ServerMessage ::ConnectionAck ) . unwrap ( ) ,
2021-02-06 03:02:41 +00:00
) ) )
2021-01-16 04:47:32 +00:00
}
2021-02-06 03:02:41 +00:00
Err ( err ) = > match this . protocol {
Protocols ::SubscriptionsTransportWS = > Poll ::Ready ( Some ( WsMessage ::Text (
serde_json ::to_string ( & ServerMessage ::ConnectionError {
payload : Error ::new ( err . message ) ,
} )
2021-01-16 04:47:32 +00:00
. unwrap ( ) ,
2021-02-06 03:02:41 +00:00
) ) ) ,
Protocols ::GraphQLWS = > {
Poll ::Ready ( Some ( WsMessage ::Close ( 1002 , err . message ) ) )
}
} ,
2021-01-16 04:47:32 +00:00
} ;
}
}
2020-09-17 18:22:54 +00:00
for ( id , stream ) in & mut * this . streams {
match Pin ::new ( stream ) . poll_next ( cx ) {
Poll ::Ready ( Some ( payload ) ) = > {
2021-02-06 03:02:41 +00:00
return Poll ::Ready ( Some ( WsMessage ::Text (
2020-12-03 21:16:07 +00:00
serde_json ::to_string ( & this . protocol . next_message ( id , payload ) ) . unwrap ( ) ,
2021-02-06 03:02:41 +00:00
) ) ) ;
2020-09-11 07:54:56 +00:00
}
2020-09-17 18:22:54 +00:00
Poll ::Ready ( None ) = > {
let id = id . clone ( ) ;
this . streams . remove ( & id ) ;
2021-02-06 03:02:41 +00:00
return Poll ::Ready ( Some ( WsMessage ::Text (
2020-09-17 18:22:54 +00:00
serde_json ::to_string ( & ServerMessage ::Complete { id : & id } ) . unwrap ( ) ,
2021-02-06 03:02:41 +00:00
) ) ) ;
2020-09-11 08:41:56 +00:00
}
2020-09-17 18:22:54 +00:00
Poll ::Pending = > { }
2020-09-11 07:54:56 +00:00
}
}
2020-09-17 18:22:54 +00:00
Poll ::Pending
2020-09-11 07:54:56 +00:00
}
}
2020-12-03 21:16:07 +00:00
/// Specification of which GraphQL Over WebSockets protocol is being utilized
2021-08-01 09:44:28 +00:00
#[ derive(Debug, Copy, Clone, Eq, PartialEq, Hash) ]
2020-12-03 21:16:07 +00:00
pub enum Protocols {
/// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
SubscriptionsTransportWS ,
/// [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
GraphQLWS ,
}
impl Protocols {
/// Returns the `Sec-WebSocket-Protocol` header value for the protocol
2021-08-01 09:44:28 +00:00
pub fn sec_websocket_protocol ( & self ) -> & 'static str {
2020-12-03 21:16:07 +00:00
match self {
Protocols ::SubscriptionsTransportWS = > " graphql-ws " ,
Protocols ::GraphQLWS = > " graphql-transport-ws " ,
}
}
#[ inline ]
fn next_message < ' s > ( & self , id : & ' s str , payload : Response ) -> ServerMessage < ' s > {
match self {
2021-03-09 04:07:12 +00:00
Protocols ::SubscriptionsTransportWS = > ServerMessage ::Data { id , payload } ,
Protocols ::GraphQLWS = > ServerMessage ::Next { id , payload } ,
2020-12-03 21:16:07 +00:00
}
}
}
impl std ::str ::FromStr for Protocols {
type Err = Error ;
fn from_str ( protocol : & str ) -> Result < Self , Self ::Err > {
if protocol . eq_ignore_ascii_case ( " graphql-ws " ) {
Ok ( Protocols ::SubscriptionsTransportWS )
} else if protocol . eq_ignore_ascii_case ( " graphql-transport-ws " ) {
Ok ( Protocols ::GraphQLWS )
} else {
Err ( Error ::new ( format! (
" Unsupported Sec-WebSocket-Protocol: {} " ,
protocol
) ) )
}
}
}
2021-05-06 20:59:48 +00:00
/// A websocket message received from the client
2020-09-17 18:22:54 +00:00
#[ derive(Deserialize) ]
#[ serde(tag = " type " , rename_all = " snake_case " ) ]
2022-04-19 04:25:11 +00:00
#[ allow(clippy::large_enum_variant) ] // Request is at fault
2021-05-06 20:59:48 +00:00
pub enum ClientMessage {
/// A new connection
2020-12-02 13:41:12 +00:00
ConnectionInit {
2021-05-06 20:59:48 +00:00
/// Optional init payload from the client
2020-12-02 13:41:12 +00:00
payload : Option < serde_json ::Value > ,
} ,
2021-05-06 20:59:48 +00:00
/// The start of a Websocket subscription
2020-12-03 21:16:07 +00:00
#[ serde(alias = " subscribe " ) ]
2020-12-02 13:41:12 +00:00
Start {
2021-05-06 20:59:48 +00:00
/// Message ID
2020-12-02 13:41:12 +00:00
id : String ,
2022-04-19 04:25:11 +00:00
/// The GraphQL Request - this can be modified by protocol implementors
/// to add files uploads.
2020-12-02 13:41:12 +00:00
payload : Request ,
} ,
2021-05-06 20:59:48 +00:00
/// The end of a Websocket subscription
2020-12-03 21:16:07 +00:00
#[ serde(alias = " complete " ) ]
2020-12-02 13:41:12 +00:00
Stop {
2021-05-06 20:59:48 +00:00
/// Message ID
id : String ,
2020-12-02 13:41:12 +00:00
} ,
2021-05-06 20:59:48 +00:00
/// Connection terminated by the client
2020-09-17 18:22:54 +00:00
ConnectionTerminate ,
2022-04-19 04:25:11 +00:00
/// Useful for detecting failed connections, displaying latency metrics or
/// other types of network probing.
2021-09-16 12:22:04 +00:00
///
/// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#ping
Ping {
/// Additional details about the ping.
payload : Option < serde_json ::Value > ,
} ,
/// The response to the Ping message.
///
/// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong
Pong {
/// Additional details about the pong.
payload : Option < serde_json ::Value > ,
} ,
2020-09-17 18:22:54 +00:00
}
2021-05-06 20:59:48 +00:00
impl ClientMessage {
/// Creates a ClientMessage from an array of bytes
pub fn from_bytes < T > ( message : T ) -> serde_json ::Result < Self >
where
T : AsRef < [ u8 ] > ,
{
serde_json ::from_slice ( message . as_ref ( ) )
}
}
2020-09-17 18:22:54 +00:00
#[ derive(Serialize) ]
#[ serde(tag = " type " , rename_all = " snake_case " ) ]
enum ServerMessage < ' a > {
2020-12-02 13:41:12 +00:00
ConnectionError {
payload : Error ,
} ,
2020-09-17 18:22:54 +00:00
ConnectionAck ,
2020-12-03 21:16:07 +00:00
/// subscriptions-transport-ws protocol next payload
2020-12-02 13:41:12 +00:00
Data {
id : & ' a str ,
2021-03-09 04:07:12 +00:00
payload : Response ,
2020-12-02 13:41:12 +00:00
} ,
2020-12-03 21:16:07 +00:00
/// graphql-ws protocol next payload
Next {
id : & ' a str ,
2021-03-09 04:07:12 +00:00
payload : Response ,
2020-12-03 21:16:07 +00:00
} ,
2020-09-17 18:22:54 +00:00
// Not used by this library, as it's not necessary to send
// Error {
// id: &'a str,
// payload: serde_json::Value,
// },
2020-12-02 13:41:12 +00:00
Complete {
id : & ' a str ,
} ,
2021-09-16 12:22:04 +00:00
/// The response to the Ping message.
///
/// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md#pong
Pong {
#[ serde(skip_serializing_if = " Option::is_none " ) ]
payload : Option < serde_json ::Value > ,
} ,
2020-09-17 18:22:54 +00:00
// Not used by this library
// #[serde(rename = "ka")]
// KeepAlive
}