2022-04-19 04:25:11 +00:00
use std ::{ io ::Error as IoError , str ::FromStr } ;
use async_graphql ::{
http ::{ WebSocketProtocols , WsMessage , ALL_WEBSOCKET_PROTOCOLS } ,
Data , ObjectType , Schema , SubscriptionType ,
} ;
use futures_util ::{
future ::{ self , Ready } ,
stream ::{ SplitSink , SplitStream } ,
Future , Sink , SinkExt , Stream , StreamExt ,
} ;
2021-12-16 04:56:11 +00:00
use poem ::{
2022-04-19 04:25:11 +00:00
http ,
http ::StatusCode ,
web ::websocket ::{ Message , WebSocket } ,
Endpoint , Error , FromRequest , IntoResponse , Request , RequestBody , Response , Result ,
2021-12-16 04:56:11 +00:00
} ;
2021-11-12 04:00:38 +00:00
/// A GraphQL protocol extractor.
///
/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header.
#[ derive(Debug, Copy, Clone, PartialEq, Eq) ]
2021-11-13 08:15:25 +00:00
pub struct GraphQLProtocol ( pub WebSocketProtocols ) ;
2021-11-12 04:00:38 +00:00
#[ poem::async_trait ]
impl < ' a > FromRequest < ' a > for GraphQLProtocol {
2021-12-16 04:56:11 +00:00
async fn from_request ( req : & ' a Request , _body : & mut RequestBody ) -> Result < Self > {
2021-11-12 04:00:38 +00:00
req . headers ( )
. get ( http ::header ::SEC_WEBSOCKET_PROTOCOL )
. and_then ( | value | value . to_str ( ) . ok ( ) )
. and_then ( | protocols | {
protocols
. split ( ',' )
. find_map ( | p | WebSocketProtocols ::from_str ( p . trim ( ) ) . ok ( ) )
} )
. map ( Self )
2021-12-17 00:11:09 +00:00
. ok_or_else ( | | Error ::from_status ( StatusCode ::BAD_REQUEST ) )
2021-11-12 04:00:38 +00:00
}
}
2021-08-23 15:16:31 +00:00
/// A GraphQL subscription endpoint.
///
/// # Example
///
/// ```
/// use async_graphql::{EmptyMutation, Object, Schema, Subscription};
2022-04-19 04:25:11 +00:00
/// use async_graphql_poem::GraphQLSubscription;
/// use futures_util::{stream, Stream};
/// use poem::{get, Route};
2021-08-23 15:16:31 +00:00
///
/// struct Query;
///
/// #[Object]
/// impl Query {
/// async fn value(&self) -> i32 {
/// 100
/// }
/// }
///
/// struct Subscription;
///
/// #[Subscription]
/// impl Subscription {
/// async fn values(&self) -> impl Stream<Item = i32> {
/// stream::iter(vec![1, 2, 3, 4, 5])
/// }
/// }
///
/// type MySchema = Schema<Query, EmptyMutation, Subscription>;
///
2021-09-01 00:15:27 +00:00
/// let schema = Schema::new(Query, EmptyMutation, Subscription);
2021-10-23 05:33:21 +00:00
/// let app = Route::new().at("/ws", get(GraphQLSubscription::new(schema)));
2021-08-23 15:16:31 +00:00
/// ```
2021-11-12 04:00:38 +00:00
pub struct GraphQLSubscription < Query , Mutation , Subscription > {
2021-08-23 15:16:31 +00:00
schema : Schema < Query , Mutation , Subscription > ,
2021-11-10 12:03:09 +00:00
}
2021-11-12 04:00:38 +00:00
impl < Query , Mutation , Subscription > GraphQLSubscription < Query , Mutation , Subscription > {
/// Create a GraphQL subscription endpoint.
pub fn new ( schema : Schema < Query , Mutation , Subscription > ) -> Self {
Self { schema }
}
}
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
#[ poem::async_trait ]
impl < Query , Mutation , Subscription > Endpoint for GraphQLSubscription < Query , Mutation , Subscription >
where
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
{
2021-12-16 04:56:11 +00:00
type Output = Response ;
2021-11-12 04:00:38 +00:00
2021-12-16 04:56:11 +00:00
async fn call ( & self , req : Request ) -> Result < Self ::Output > {
2021-11-12 04:00:38 +00:00
let ( req , mut body ) = req . split ( ) ;
let websocket = WebSocket ::from_request ( & req , & mut body ) . await ? ;
let protocol = GraphQLProtocol ::from_request ( & req , & mut body ) . await ? ;
let schema = self . schema . clone ( ) ;
let resp = websocket
. protocols ( ALL_WEBSOCKET_PROTOCOLS )
. on_upgrade ( move | stream | GraphQLWebSocket ::new ( stream , schema , protocol ) . serve ( ) )
. into_response ( ) ;
Ok ( resp )
}
2021-11-10 12:03:09 +00:00
}
2021-11-12 04:00:38 +00:00
/// Function-pointer type of the connection-init callback that is installed
/// when the caller does not provide one.
type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<async_graphql::Result<Data>>;
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
fn default_on_connection_init ( _ : serde_json ::Value ) -> Ready < async_graphql ::Result < Data > > {
2021-11-10 12:03:09 +00:00
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
2021-08-23 15:16:31 +00:00
}
2021-11-12 04:00:38 +00:00
/// A Websocket connection for GraphQL subscription.
2021-11-18 01:25:32 +00:00
pub struct GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnConnInit > {
sink : Sink ,
stream : Stream ,
2021-11-12 04:00:38 +00:00
schema : Schema < Query , Mutation , Subscription > ,
data : Data ,
on_connection_init : OnConnInit ,
protocol : GraphQLProtocol ,
}
2021-11-18 01:25:32 +00:00
impl < S , Query , Mutation , Subscription >
GraphQLWebSocket <
SplitSink < S , Message > ,
SplitStream < S > ,
Query ,
Mutation ,
Subscription ,
DefaultOnConnInitType ,
>
2021-11-12 04:00:38 +00:00
where
2021-11-18 01:25:32 +00:00
S : Stream < Item = Result < Message , IoError > > + Sink < Message > ,
2021-11-12 04:00:38 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
2021-08-23 15:16:31 +00:00
{
2021-11-12 04:00:38 +00:00
/// Create a [`GraphQLWebSocket`] object.
pub fn new (
2021-11-18 01:25:32 +00:00
stream : S ,
schema : Schema < Query , Mutation , Subscription > ,
protocol : GraphQLProtocol ,
) -> Self {
let ( sink , stream ) = stream . split ( ) ;
GraphQLWebSocket ::new_with_pair ( sink , stream , schema , protocol )
}
}
impl < Sink , Stream , Query , Mutation , Subscription >
GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , DefaultOnConnInitType >
where
Sink : futures_util ::sink ::Sink < Message > ,
Stream : futures_util ::stream ::Stream < Item = Result < Message , IoError > > ,
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair (
sink : Sink ,
stream : Stream ,
2021-11-12 04:00:38 +00:00
schema : Schema < Query , Mutation , Subscription > ,
protocol : GraphQLProtocol ,
) -> Self {
GraphQLWebSocket {
2021-11-18 01:25:32 +00:00
sink ,
2021-11-12 04:00:38 +00:00
stream ,
2021-11-18 01:25:32 +00:00
schema ,
2021-11-12 04:00:38 +00:00
data : Data ::default ( ) ,
2021-11-10 12:03:09 +00:00
on_connection_init : default_on_connection_init ,
2021-11-12 04:00:38 +00:00
protocol ,
2021-08-23 15:16:31 +00:00
}
}
}
2021-11-18 01:25:32 +00:00
impl < Sink , Stream , Query , Mutation , Subscription , OnConnInit , OnConnInitFut >
GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnConnInit >
2021-11-12 04:00:38 +00:00
where
2021-11-18 01:25:32 +00:00
Sink : futures_util ::sink ::Sink < Message > ,
Stream : futures_util ::stream ::Stream < Item = Result < Message , IoError > > ,
2021-11-12 04:00:38 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
OnConnInit : Fn ( serde_json ::Value ) -> OnConnInitFut + Send + Sync + 'static ,
OnConnInitFut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
2021-11-10 12:03:09 +00:00
{
2022-04-19 04:25:11 +00:00
/// Specify the initial subscription context data, usually you can get
/// something from the incoming request to create it.
2022-01-24 06:14:07 +00:00
#[ must_use ]
2021-11-12 04:00:38 +00:00
pub fn with_data ( self , data : Data ) -> Self {
Self { data , .. self }
2021-11-10 12:03:09 +00:00
}
2022-04-19 04:25:11 +00:00
/// Specify a callback function to be called when the connection is
/// initialized.
2021-11-10 12:03:09 +00:00
///
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
2022-04-19 04:25:11 +00:00
/// The data returned by this callback function will be merged with the data
/// specified by [`with_data`].
2021-11-12 04:00:38 +00:00
pub fn on_connection_init < OnConnInit2 , Fut > (
2021-08-23 15:16:31 +00:00
self ,
2021-11-12 04:00:38 +00:00
callback : OnConnInit2 ,
2021-11-18 01:25:32 +00:00
) -> GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnConnInit2 >
2021-08-23 15:16:31 +00:00
where
2021-11-12 04:00:38 +00:00
OnConnInit2 : Fn ( serde_json ::Value ) -> Fut + Send + Sync + 'static ,
Fut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
2021-08-23 15:16:31 +00:00
{
2021-11-12 04:00:38 +00:00
GraphQLWebSocket {
2021-11-18 01:25:32 +00:00
sink : self . sink ,
2021-11-12 04:00:38 +00:00
stream : self . stream ,
2021-11-18 01:25:32 +00:00
schema : self . schema ,
2021-11-12 04:00:38 +00:00
data : self . data ,
2021-11-10 12:03:09 +00:00
on_connection_init : callback ,
2021-11-12 04:00:38 +00:00
protocol : self . protocol ,
2021-08-23 15:16:31 +00:00
}
}
2021-11-10 12:03:09 +00:00
2021-11-12 04:00:38 +00:00
/// Processing subscription requests.
pub async fn serve ( self ) {
2021-11-18 01:25:32 +00:00
let stream = self
. stream
2021-11-12 04:00:38 +00:00
. take_while ( | res | future ::ready ( res . is_ok ( ) ) )
. map ( Result ::unwrap )
. filter_map ( | msg | {
if msg . is_text ( ) | | msg . is_binary ( ) {
future ::ready ( Some ( msg ) )
} else {
future ::ready ( None )
2021-08-23 15:16:31 +00:00
}
} )
2021-11-12 07:27:05 +00:00
. map ( Message ::into_bytes ) ;
2021-11-12 04:00:38 +00:00
2021-11-18 01:25:32 +00:00
let stream =
2021-11-12 04:00:38 +00:00
async_graphql ::http ::WebSocket ::new ( self . schema . clone ( ) , stream , self . protocol . 0 )
. connection_data ( self . data )
. on_connection_init ( self . on_connection_init )
. map ( | msg | match msg {
WsMessage ::Text ( text ) = > Message ::text ( text ) ,
WsMessage ::Close ( code , status ) = > Message ::close_with ( code , status ) ,
} ) ;
2021-11-18 01:25:32 +00:00
let sink = self . sink ;
futures_util ::pin_mut! ( stream , sink ) ;
2021-11-12 04:00:38 +00:00
while let Some ( item ) = stream . next ( ) . await {
let _ = sink . send ( item ) . await ;
}
2021-08-23 15:16:31 +00:00
}
}