2021-01-16 04:47:32 +00:00
use std ::future ::Future ;
2021-04-05 05:20:02 +00:00
use std ::str ::FromStr ;
2021-01-16 04:47:32 +00:00
2021-04-05 05:20:02 +00:00
use async_graphql ::http ::{ WebSocketProtocols , WsMessage } ;
2020-10-01 01:07:19 +00:00
use async_graphql ::{ Data , ObjectType , Result , Schema , SubscriptionType } ;
2021-11-12 07:27:05 +00:00
use futures_util ::future ::Ready ;
2021-11-18 01:25:32 +00:00
use futures_util ::stream ::{ SplitSink , SplitStream } ;
use futures_util ::{ future , Sink , Stream , StreamExt } ;
2020-09-21 06:57:33 +00:00
use warp ::filters ::ws ;
2021-11-18 01:25:32 +00:00
use warp ::ws ::Message ;
use warp ::{ Error , Filter , Rejection , Reply } ;
2020-09-21 06:57:33 +00:00
/// GraphQL subscription filter
///
/// # Examples
///
/// ```no_run
/// use async_graphql::*;
/// use async_graphql_warp::*;
/// use warp::Filter;
2020-10-16 06:49:22 +00:00
/// use futures_util::stream::{Stream, StreamExt};
2020-09-21 06:57:33 +00:00
/// use std::time::Duration;
///
/// struct QueryRoot;
///
/// #[Object]
2020-10-20 03:49:31 +00:00
/// impl QueryRoot {
/// async fn value(&self) -> i32 {
2020-10-22 02:11:47 +00:00
/// // A GraphQL Object type must define one or more fields.
2020-10-20 03:49:31 +00:00
/// 100
/// }
/// }
2020-09-21 06:57:33 +00:00
///
/// struct SubscriptionRoot;
///
/// #[Subscription]
/// impl SubscriptionRoot {
/// async fn tick(&self) -> impl Stream<Item = String> {
2021-01-20 02:37:31 +00:00
/// async_stream::stream! {
/// let mut interval = tokio::time::interval(Duration::from_secs(1));
/// loop {
/// let n = interval.tick().await;
/// yield format!("{}", n.elapsed().as_secs_f32());
/// }
/// }
2020-09-21 06:57:33 +00:00
/// }
/// }
///
2021-11-20 03:16:48 +00:00
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
/// let filter = async_graphql_warp::graphql_subscription(schema)
/// .or(warp::any().map(|| "Hello, World!"));
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// # });
2020-09-21 06:57:33 +00:00
/// ```
pub fn graphql_subscription < Query , Mutation , Subscription > (
schema : Schema < Query , Mutation , Subscription > ,
) -> impl Filter < Extract = ( impl Reply , ) , Error = Rejection > + Clone
where
Query : ObjectType + Sync + Send + 'static ,
Mutation : ObjectType + Sync + Send + 'static ,
Subscription : SubscriptionType + Send + Sync + 'static ,
2021-04-04 06:43:23 +00:00
{
2020-12-03 21:16:07 +00:00
warp ::ws ( )
2021-04-05 05:20:02 +00:00
. and ( graphql_protocol ( ) )
. map ( move | ws : ws ::Ws , protocol | {
2020-12-03 21:16:07 +00:00
let schema = schema . clone ( ) ;
2020-12-02 13:13:00 +00:00
2021-11-12 07:27:05 +00:00
let reply = ws . on_upgrade ( move | socket | {
GraphQLWebSocket ::new ( socket , schema , protocol )
. on_connection_init ( default_on_connection_init )
. serve ( )
2020-12-03 21:16:07 +00:00
} ) ;
warp ::reply ::with_header (
reply ,
" Sec-WebSocket-Protocol " ,
protocol . sec_websocket_protocol ( ) ,
)
} )
2020-09-21 06:57:33 +00:00
}
2021-04-05 05:20:02 +00:00
/// Create a `Filter` that parse [WebSocketProtocols] from `sec-websocket-protocol` header.
pub fn graphql_protocol ( ) -> impl Filter < Extract = ( WebSocketProtocols , ) , Error = Rejection > + Clone
{
warp ::header ::optional ::< String > ( " sec-websocket-protocol " ) . map ( | protocols : Option < String > | {
protocols
. and_then ( | protocols | {
protocols
. split ( ',' )
. find_map ( | p | WebSocketProtocols ::from_str ( p . trim ( ) ) . ok ( ) )
} )
. unwrap_or ( WebSocketProtocols ::SubscriptionsTransportWS )
} )
}
2021-11-12 07:27:05 +00:00
type DefaultOnConnInitType = fn ( serde_json ::Value ) -> Ready < async_graphql ::Result < Data > > ;
fn default_on_connection_init ( _ : serde_json ::Value ) -> Ready < async_graphql ::Result < Data > > {
futures_util ::future ::ready ( Ok ( Data ::default ( ) ) )
}
/// A Websocket connection for GraphQL subscription.
2021-04-05 05:20:02 +00:00
///
/// # Examples
///
/// ```no_run
/// use async_graphql::*;
/// use async_graphql_warp::*;
2021-11-13 08:48:09 +00:00
/// use warp::{Filter, ws};
2021-04-05 05:20:02 +00:00
/// use futures_util::stream::{Stream, StreamExt};
/// use std::time::Duration;
///
/// struct QueryRoot;
///
/// #[Object]
/// impl QueryRoot {
/// async fn value(&self) -> i32 {
/// // A GraphQL Object type must define one or more fields.
/// 100
/// }
/// }
///
/// struct SubscriptionRoot;
///
/// #[Subscription]
/// impl SubscriptionRoot {
/// async fn tick(&self) -> impl Stream<Item = String> {
/// async_stream::stream! {
/// let mut interval = tokio::time::interval(Duration::from_secs(1));
/// loop {
/// let n = interval.tick().await;
/// yield format!("{}", n.elapsed().as_secs_f32());
/// }
/// }
/// }
/// }
///
2021-11-20 03:16:48 +00:00
/// # tokio::runtime::Runtime::new().unwrap().block_on(async {
/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
2021-11-12 07:27:05 +00:00
///
2021-11-20 03:16:48 +00:00
/// let filter = warp::ws()
/// .and(graphql_protocol())
/// .map(move |ws: ws::Ws, protocol| {
/// let schema = schema.clone();
2021-11-12 07:27:05 +00:00
///
2021-11-20 03:16:48 +00:00
/// let reply = ws.on_upgrade(move |socket| {
/// GraphQLWebSocket::new(socket, schema, protocol).serve()
2021-04-05 05:20:02 +00:00
/// });
2021-11-12 07:27:05 +00:00
///
2021-11-20 03:16:48 +00:00
/// warp::reply::with_header(
/// reply,
/// "Sec-WebSocket-Protocol",
/// protocol.sec_websocket_protocol(),
/// )
/// });
///
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// # });
2021-04-05 05:20:02 +00:00
/// ```
2021-11-18 01:25:32 +00:00
pub struct GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnInit > {
sink : Sink ,
stream : Stream ,
2021-04-05 05:20:02 +00:00
protocol : WebSocketProtocols ,
schema : Schema < Query , Mutation , Subscription > ,
2021-11-12 07:27:05 +00:00
data : Data ,
on_init : OnInit ,
}
2021-11-18 01:25:32 +00:00
impl < S , Query , Mutation , Subscription >
GraphQLWebSocket <
SplitSink < S , Message > ,
SplitStream < S > ,
Query ,
Mutation ,
Subscription ,
DefaultOnConnInitType ,
>
2021-11-12 07:27:05 +00:00
where
2021-11-18 01:25:32 +00:00
S : Stream < Item = Result < Message , Error > > + Sink < Message > ,
2021-04-05 05:20:02 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
{
2021-11-12 07:27:05 +00:00
/// Create a [`GraphQLWebSocket`] object.
pub fn new (
2021-11-18 01:25:32 +00:00
socket : S ,
schema : Schema < Query , Mutation , Subscription > ,
protocol : WebSocketProtocols ,
) -> Self {
let ( sink , stream ) = socket . split ( ) ;
GraphQLWebSocket ::new_with_pair ( sink , stream , schema , protocol )
}
}
impl < Sink , Stream , Query , Mutation , Subscription >
GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , DefaultOnConnInitType >
where
Sink : futures_util ::sink ::Sink < Message > ,
Stream : futures_util ::stream ::Stream < Item = Result < Message , Error > > ,
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair (
sink : Sink ,
stream : Stream ,
2021-11-12 07:27:05 +00:00
schema : Schema < Query , Mutation , Subscription > ,
protocol : WebSocketProtocols ,
) -> Self {
GraphQLWebSocket {
2021-11-18 01:25:32 +00:00
sink ,
stream ,
2021-11-12 07:27:05 +00:00
protocol ,
schema ,
data : Data ::default ( ) ,
on_init : default_on_connection_init ,
}
}
2021-04-05 05:20:02 +00:00
}
2021-11-18 01:25:32 +00:00
impl < Sink , Stream , Query , Mutation , Subscription , OnConnInit , OnConnInitFut >
GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnConnInit >
2021-11-12 07:27:05 +00:00
where
2021-11-18 01:25:32 +00:00
Sink : futures_util ::sink ::Sink < Message > ,
Stream : futures_util ::stream ::Stream < Item = Result < Message , Error > > ,
2021-04-05 05:20:02 +00:00
Query : ObjectType + 'static ,
Mutation : ObjectType + 'static ,
Subscription : SubscriptionType + 'static ,
2021-11-12 07:27:05 +00:00
OnConnInit : Fn ( serde_json ::Value ) -> OnConnInitFut + Send + Sync + 'static ,
OnConnInitFut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
2021-04-05 05:20:02 +00:00
{
2021-11-12 07:27:05 +00:00
/// Specify the initial subscription context data, usually you can get something from the
/// incoming request to create it.
pub fn with_data ( self , data : Data ) -> Self {
Self { data , .. self }
}
/// Specify a callback function to be called when the connection is initialized.
///
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
/// The data returned by this callback function will be merged with the data specified by [`with_data`].
pub fn on_connection_init < OnConnInit2 , Fut > (
self ,
callback : OnConnInit2 ,
2021-11-18 01:25:32 +00:00
) -> GraphQLWebSocket < Sink , Stream , Query , Mutation , Subscription , OnConnInit2 >
2021-11-12 07:27:05 +00:00
where
OnConnInit2 : Fn ( serde_json ::Value ) -> Fut + Send + Sync + 'static ,
Fut : Future < Output = async_graphql ::Result < Data > > + Send + 'static ,
{
GraphQLWebSocket {
2021-11-18 01:25:32 +00:00
sink : self . sink ,
stream : self . stream ,
2021-11-12 07:27:05 +00:00
schema : self . schema ,
data : self . data ,
on_init : callback ,
protocol : self . protocol ,
}
}
/// Processing subscription requests.
pub async fn serve ( self ) {
2021-11-18 01:25:32 +00:00
let stream = self
. stream
2021-04-05 05:20:02 +00:00
. take_while ( | msg | future ::ready ( msg . is_ok ( ) ) )
. map ( Result ::unwrap )
. filter ( | msg | future ::ready ( msg . is_text ( ) | | msg . is_binary ( ) ) )
2021-11-12 07:27:05 +00:00
. map ( ws ::Message ::into_bytes ) ;
let _ = async_graphql ::http ::WebSocket ::new ( self . schema . clone ( ) , stream , self . protocol )
. connection_data ( self . data )
. on_connection_init ( self . on_init )
. map ( | msg | match msg {
WsMessage ::Text ( text ) = > ws ::Message ::text ( text ) ,
WsMessage ::Close ( code , status ) = > ws ::Message ::close_with ( code , status ) ,
} )
. map ( Ok )
2021-11-18 01:25:32 +00:00
. forward ( self . sink )
2021-11-12 07:27:05 +00:00
. await ;
}
2021-04-05 05:20:02 +00:00
}