2021-08-23 15:16:31 +00:00
use std ::str ::FromStr ;
use async_graphql ::http ::{ WebSocketProtocols , WsMessage , ALL_WEBSOCKET_PROTOCOLS } ;
use async_graphql ::{ Data , ObjectType , Schema , SubscriptionType } ;
use futures_util ::future ::{ self , Ready } ;
use futures_util ::{ Future , SinkExt , StreamExt } ;
2021-11-10 12:03:09 +00:00
use poem ::http ::StatusCode ;
2021-11-12 04:00:38 +00:00
use poem ::web ::websocket ::{ Message , WebSocket , WebSocketStream } ;
2021-11-12 04:58:32 +00:00
use poem ::{ http , Endpoint , FromRequest , IntoResponse , Request , RequestBody , Response , Result } ;
2021-11-12 04:00:38 +00:00
/// A GraphQL protocol extractor.
///
/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header.
#[ derive(Debug, Copy, Clone, PartialEq, Eq) ]
pub struct GraphQLProtocol ( WebSocketProtocols ) ;
#[ poem::async_trait ]
impl < ' a > FromRequest < ' a > for GraphQLProtocol {
2021-11-12 04:58:32 +00:00
type Error = StatusCode ;
2021-11-12 04:00:38 +00:00
async fn from_request ( req : & ' a Request , _body : & mut RequestBody ) -> Result < Self , Self ::Error > {
req . headers ( )
. get ( http ::header ::SEC_WEBSOCKET_PROTOCOL )
. and_then ( | value | value . to_str ( ) . ok ( ) )
. and_then ( | protocols | {
protocols
. split ( ',' )
. find_map ( | p | WebSocketProtocols ::from_str ( p . trim ( ) ) . ok ( ) )
} )
. map ( Self )
2021-11-12 04:58:32 +00:00
. ok_or_else ( | | StatusCode ::BAD_REQUEST )
2021-11-12 04:00:38 +00:00
}
}
2021-08-23 15:16:31 +00:00
/// A GraphQL subscription endpoint.
///
/// # Example
///
/// ```
2021-10-23 05:33:21 +00:00
/// use poem::{Route, get};
2021-08-23 15:16:31 +00:00
/// use async_graphql_poem::GraphQLSubscription;
/// use async_graphql::{EmptyMutation, Object, Schema, Subscription};
/// use futures_util::{Stream, stream};
///
/// struct Query;
///
/// #[Object]
/// impl Query {
/// async fn value(&self) -> i32 {
/// 100
/// }
/// }
///
/// struct Subscription;
///
/// #[Subscription]
/// impl Subscription {
/// async fn values(&self) -> impl Stream<Item = i32> {
/// stream::iter(vec![1, 2, 3, 4, 5])
/// }
/// }
///
/// type MySchema = Schema<Query, EmptyMutation, Subscription>;
///
2021-09-01 00:15:27 +00:00
/// let schema = Schema::new(Query, EmptyMutation, Subscription);
2021-10-23 05:33:21 +00:00
/// let app = Route::new().at("/ws", get(GraphQLSubscription::new(schema)));
2021-08-23 15:16:31 +00:00
/// ```
2021-11-12 04:00:38 +00:00
pub struct GraphQLSubscription < Query , Mutation , Subscription > {
2021-08-23 15:16:31 +00:00
schema : Schema < Query , Mutation , Subscription > ,
2021-11-10 12:03:09 +00:00
}
2021-11-12 04:00:38 +00:00
impl < Query , Mutation , Subscription > GraphQLSubscription < Query , Mutation , Subscription > {
/// Create a GraphQL subscription endpoint.
pub fn new ( schema : Schema < Query , Mutation , Subscription > ) -> Self {
Self { schema }
}
}
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
#[ poem::async_trait ]
impl < Query , Mutation , Subscription > Endpoint for GraphQLSubscription < Query , Mutation , Subscription >
where
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
{
type Output = Result < Response > ;
async fn call ( & self , req : Request ) -> Self ::Output {
let ( req , mut body ) = req . split ( ) ;
let websocket = WebSocket ::from_request ( & req , & mut body ) . await ? ;
let protocol = GraphQLProtocol ::from_request ( & req , & mut body ) . await ? ;
let schema = self . schema . clone ( ) ;
let resp = websocket
. protocols ( ALL_WEBSOCKET_PROTOCOLS )
. on_upgrade ( move | stream | GraphQLWebSocket ::new ( stream , schema , protocol ) . serve ( ) )
. into_response ( ) ;
Ok ( resp )
}
2021-11-10 12:03:09 +00:00
}
2021-11-12 04:00:38 +00:00
type DefaultOnConnInitType = fn ( serde_json ::Value ) -> Ready < async_graphql ::Result < Data > > ;
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
fn default_on_connection_init ( _ : serde_json ::Value ) -> Ready < async_graphql ::Result < Data > > {
2021-11-10 12:03:09 +00:00
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
2021-08-23 15:16:31 +00:00
}
2021-11-12 04:00:38 +00:00
/// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket < Query , Mutation , Subscription , OnConnInit > {
schema : Schema < Query , Mutation , Subscription > ,
stream : WebSocketStream ,
data : Data ,
on_connection_init : OnConnInit ,
protocol : GraphQLProtocol ,
}
2021-08-23 15:16:31 +00:00
impl < Query , Mutation , Subscription >
2021-11-12 04:00:38 +00:00
GraphQLWebSocket < Query , Mutation , Subscription , DefaultOnConnInitType >
where
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
2021-08-23 15:16:31 +00:00
{
2021-11-12 04:00:38 +00:00
/// Create a [`GraphQLWebSocket`] object.
pub fn new (
stream : WebSocketStream ,
schema : Schema < Query , Mutation , Subscription > ,
protocol : GraphQLProtocol ,
) -> Self {
GraphQLWebSocket {
2021-08-23 15:16:31 +00:00
schema ,
2021-11-12 04:00:38 +00:00
stream ,
data : Data ::default ( ) ,
2021-11-10 12:03:09 +00:00
on_connection_init : default_on_connection_init ,
2021-11-12 04:00:38 +00:00
protocol ,
2021-08-23 15:16:31 +00:00
}
}
}
2021-11-12 04:00:38 +00:00
impl < Query , Mutation , Subscription , OnConnInit , OnConnInitFut >
GraphQLWebSocket < Query , Mutation , Subscription , OnConnInit >
where
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
OnConnInit : Fn ( serde_json ::Value ) -> OnConnInitFut + Send + Sync + 'static ,
OnConnInitFut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
2021-11-10 12:03:09 +00:00
{
2021-11-12 04:00:38 +00:00
/// Specify the initial subscription context data, usually you can get something from the
/// incoming request to create it.
pub fn with_data ( self , data : Data ) -> Self {
Self { data , .. self }
2021-11-10 12:03:09 +00:00
}
/// Specify a callback function to be called when the connection is initialized.
///
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
2021-11-12 04:00:38 +00:00
/// The data returned by this callback function will be merged with the data specified by [`with_data`].
pub fn on_connection_init < OnConnInit2 , Fut > (
2021-08-23 15:16:31 +00:00
self ,
2021-11-12 04:00:38 +00:00
callback : OnConnInit2 ,
) -> GraphQLWebSocket < Query , Mutation , Subscription , OnConnInit2 >
2021-08-23 15:16:31 +00:00
where
2021-11-12 04:00:38 +00:00
OnConnInit2 : Fn ( serde_json ::Value ) -> Fut + Send + Sync + 'static ,
Fut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
2021-08-23 15:16:31 +00:00
{
2021-11-12 04:00:38 +00:00
GraphQLWebSocket {
2021-08-23 15:16:31 +00:00
schema : self . schema ,
2021-11-12 04:00:38 +00:00
stream : self . stream ,
data : self . data ,
2021-11-10 12:03:09 +00:00
on_connection_init : callback ,
2021-11-12 04:00:38 +00:00
protocol : self . protocol ,
2021-08-23 15:16:31 +00:00
}
}
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
/// Processing subscription requests.
pub async fn serve ( self ) {
let ( mut sink , stream ) = self . stream . split ( ) ;
let stream = stream
. take_while ( | res | future ::ready ( res . is_ok ( ) ) )
. map ( Result ::unwrap )
. filter_map ( | msg | {
if msg . is_text ( ) | | msg . is_binary ( ) {
future ::ready ( Some ( msg ) )
} else {
future ::ready ( None )
2021-08-23 15:16:31 +00:00
}
} )
2021-11-12 07:27:05 +00:00
. map ( Message ::into_bytes ) ;
2021-11-12 04:00:38 +00:00
let mut stream =
async_graphql ::http ::WebSocket ::new ( self . schema . clone ( ) , stream , self . protocol . 0 )
. connection_data ( self . data )
. on_connection_init ( self . on_connection_init )
. map ( | msg | match msg {
WsMessage ::Text ( text ) = > Message ::text ( text ) ,
WsMessage ::Close ( code , status ) = > Message ::close_with ( code , status ) ,
} ) ;
while let Some ( item ) = stream . next ( ) . await {
let _ = sink . send ( item ) . await ;
}
2021-08-23 15:16:31 +00:00
}
}