2021-02-06 01:21:35 +00:00
use std ::future ::Future ;
2021-11-10 13:00:02 +00:00
use std ::marker ::PhantomData ;
2021-02-06 01:21:35 +00:00
use std ::str ::FromStr ;
2021-02-06 03:02:41 +00:00
use async_graphql ::http ::{ WebSocket as AGWebSocket , WebSocketProtocols , WsMessage } ;
2021-02-06 01:21:35 +00:00
use async_graphql ::{ Data , ObjectType , Result , Schema , SubscriptionType } ;
2021-11-10 13:00:02 +00:00
use futures_util ::future ::Ready ;
2021-02-06 01:21:35 +00:00
use futures_util ::{ future , StreamExt } ;
2021-11-11 23:45:49 +00:00
use tide ::{ Endpoint , Request , StatusCode } ;
2021-02-06 01:21:35 +00:00
use tide_websockets ::Message ;
2021-11-10 13:00:02 +00:00
type DefaultOnConnCreateType < S > = fn ( & Request < S > ) -> Ready < Result < Data > > ;
fn default_on_connection_create < S > ( _ : & Request < S > ) -> Ready < Result < Data > > {
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
}
type DefaultOnConnInitType = fn ( serde_json ::Value ) -> Ready < Result < Data > > ;
fn default_on_connection_init ( _ : serde_json ::Value ) -> Ready < Result < Data > > {
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
}
/// GraphQL subscription builder.
2021-09-17 00:39:28 +00:00
#[ cfg_attr(docsrs, doc(cfg(feature = " websocket " ))) ]
2021-11-11 23:45:49 +00:00
pub struct GraphQLSubscriptionBuilder < TideState , Query , Mutation , Subscription , OnCreate , OnInit > {
2021-11-10 13:00:02 +00:00
schema : Schema < Query , Mutation , Subscription > ,
on_connection_create : OnCreate ,
on_connection_init : OnInit ,
2021-11-11 23:45:49 +00:00
_mark : PhantomData < TideState > ,
2021-02-06 01:21:35 +00:00
}
2021-11-11 23:45:49 +00:00
impl < TideState , Query , Mutation , Subscription >
GraphQLSubscriptionBuilder <
TideState ,
2021-11-10 13:00:02 +00:00
Query ,
Mutation ,
Subscription ,
2021-11-11 23:45:49 +00:00
DefaultOnConnCreateType < TideState > ,
2021-11-10 13:00:02 +00:00
DefaultOnConnInitType ,
>
2021-02-06 01:21:35 +00:00
{
2021-11-10 13:00:02 +00:00
/// Create a GraphQL subscription builder.
pub fn new ( schema : Schema < Query , Mutation , Subscription > ) -> Self {
Self {
schema ,
on_connection_create : default_on_connection_create ,
on_connection_init : default_on_connection_init ,
_mark : Default ::default ( ) ,
}
2021-02-06 01:21:35 +00:00
}
}
2021-11-10 13:00:02 +00:00
impl < S , Query , Mutation , Subscription , OnCreate , OnInit >
2021-11-11 23:45:49 +00:00
GraphQLSubscriptionBuilder < S , Query , Mutation , Subscription , OnCreate , OnInit >
2021-02-06 01:21:35 +00:00
{
2021-11-10 13:00:02 +00:00
/// Specify the callback function to be called when the connection is created.
///
/// You can get something from the incoming request to create [`Data`].
pub fn on_connection_create < OnCreate2 , Fut > (
self ,
callback : OnCreate2 ,
2021-11-11 23:45:49 +00:00
) -> GraphQLSubscriptionBuilder < S , Query , Mutation , Subscription , OnCreate2 , OnInit >
2021-02-06 01:21:35 +00:00
where
2021-11-10 13:00:02 +00:00
OnCreate2 : Fn ( & Request < S > ) -> Fut + Clone + Send + Sync + 'static ,
Fut : Future < Output = Result < Data > > + Send + 'static ,
2021-02-06 01:21:35 +00:00
{
2021-11-11 23:45:49 +00:00
GraphQLSubscriptionBuilder {
2021-11-10 13:00:02 +00:00
schema : self . schema ,
on_connection_create : callback ,
on_connection_init : self . on_connection_init ,
_mark : Default ::default ( ) ,
}
2021-02-06 01:21:35 +00:00
}
2021-11-10 13:00:02 +00:00
/// Specify a callback function to be called when the connection is initialized.
2021-02-06 01:21:35 +00:00
///
2021-11-10 13:00:02 +00:00
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
pub fn on_connection_init < OnInit2 , Fut > (
self ,
callback : OnInit2 ,
2021-11-11 23:45:49 +00:00
) -> GraphQLSubscriptionBuilder < S , Query , Mutation , Subscription , OnCreate , OnInit2 >
2021-02-06 01:21:35 +00:00
where
2021-11-10 13:00:02 +00:00
OnInit2 : FnOnce ( serde_json ::Value ) -> Fut + Clone + Send + Sync + 'static ,
Fut : Future < Output = Result < Data > > + Send + 'static ,
2021-02-06 01:21:35 +00:00
{
2021-11-11 23:45:49 +00:00
GraphQLSubscriptionBuilder {
2021-11-10 13:00:02 +00:00
schema : self . schema ,
on_connection_create : self . on_connection_create ,
on_connection_init : callback ,
_mark : Default ::default ( ) ,
}
}
}
2021-11-11 23:45:49 +00:00
impl < TideState , Query , Mutation , Subscription , OnCreate , OnCreateFut , OnInit , OnInitFut >
GraphQLSubscriptionBuilder < TideState , Query , Mutation , Subscription , OnCreate , OnInit >
2021-11-10 13:00:02 +00:00
where
2021-11-11 23:45:49 +00:00
TideState : Send + Sync + Clone + 'static ,
2021-11-10 13:00:02 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
2021-11-11 23:45:49 +00:00
OnCreate : Fn ( & Request < TideState > ) -> OnCreateFut + Send + Clone + Sync + 'static ,
2021-11-10 13:00:02 +00:00
OnCreateFut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
OnInit : FnOnce ( serde_json ::Value ) -> OnInitFut + Clone + Send + Sync + 'static ,
OnInitFut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
{
2021-11-11 23:45:49 +00:00
/// Create an endpoint for graphql subscription.
pub fn build ( self ) -> impl Endpoint < TideState > {
tide_websockets ::WebSocket ::< TideState , _ > ::new ( move | request , connection | {
2021-11-10 13:00:02 +00:00
let schema = self . schema . clone ( ) ;
let on_connection_create = self . on_connection_create . clone ( ) ;
let on_connection_init = self . on_connection_init . clone ( ) ;
2021-02-06 01:21:35 +00:00
async move {
2021-11-10 13:00:02 +00:00
let data = on_connection_create ( & request )
. await
. map_err ( | _ | tide ::Error ::from_str ( StatusCode ::BadRequest , " bad request " ) ) ? ;
2021-02-06 01:21:35 +00:00
let protocol = match request
. header ( " sec-websocket-protocol " )
. map ( | value | value . as_str ( ) )
. and_then ( | protocols | {
protocols
. split ( ',' )
. find_map ( | p | WebSocketProtocols ::from_str ( p . trim ( ) ) . ok ( ) )
} ) {
Some ( protocol ) = > protocol ,
None = > {
// default to the prior standard
WebSocketProtocols ::SubscriptionsTransportWS
}
} ;
let sink = connection . clone ( ) ;
2021-11-10 13:00:02 +00:00
let mut stream = AGWebSocket ::new (
2021-02-06 01:21:35 +00:00
schema . clone ( ) ,
connection
. take_while ( | msg | future ::ready ( msg . is_ok ( ) ) )
. map ( Result ::unwrap )
. map ( Message ::into_data ) ,
protocol ,
2021-11-10 13:00:02 +00:00
)
. connection_data ( data )
. on_connection_init ( on_connection_init ) ;
2021-02-06 01:21:35 +00:00
while let Some ( data ) = stream . next ( ) . await {
2021-02-06 03:02:41 +00:00
match data {
WsMessage ::Text ( text ) = > {
if sink . send_string ( text ) . await . is_err ( ) {
break ;
}
}
WsMessage ::Close ( _code , _msg ) = > {
// TODO: Send close frame
break ;
}
2021-02-06 01:21:35 +00:00
}
}
Ok ( ( ) )
}
} )
2021-11-10 13:00:02 +00:00
. with_protocols ( & [ " graphql-transport-ws " , " graphql-ws " ] )
2021-02-06 01:21:35 +00:00
}
}