Update async-graphql-poem

This commit is contained in:
Sunli 2021-11-10 20:03:09 +08:00
parent 0e9e087691
commit 0118a5dd7c
5 changed files with 139 additions and 74 deletions

View File

@ -14,7 +14,7 @@ categories = ["network-programming", "asynchronous"]
[dependencies] [dependencies]
async-graphql = { path = "../..", version = "=2.11.2" } async-graphql = { path = "../..", version = "=2.11.2" }
poem = { version = "1.0.19", features = ["websocket"] } poem = { version = "1.0.23", features = ["websocket"] }
futures-util = { version = "0.3.13", default-features = false } futures-util = { version = "0.3.13", default-features = false }
serde_json = "1.0.66" serde_json = "1.0.66"
tokio-util = { version = "0.6.7", features = ["compat"] } tokio-util = { version = "0.6.7", features = ["compat"] }

View File

@ -4,8 +4,9 @@ use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS
use async_graphql::{Data, ObjectType, Schema, SubscriptionType}; use async_graphql::{Data, ObjectType, Schema, SubscriptionType};
use futures_util::future::{self, Ready}; use futures_util::future::{self, Ready};
use futures_util::{Future, SinkExt, StreamExt}; use futures_util::{Future, SinkExt, StreamExt};
use poem::http::StatusCode;
use poem::web::websocket::{Message, WebSocket}; use poem::web::websocket::{Message, WebSocket};
use poem::{http, Endpoint, FromRequest, IntoResponse, Request, Response, Result}; use poem::{http, Endpoint, Error, FromRequest, IntoResponse, Request, Response, Result};
/// A GraphQL subscription endpoint. /// A GraphQL subscription endpoint.
/// ///
@ -40,9 +41,22 @@ use poem::{http, Endpoint, FromRequest, IntoResponse, Request, Response, Result}
/// let schema = Schema::new(Query, EmptyMutation, Subscription); /// let schema = Schema::new(Query, EmptyMutation, Subscription);
/// let app = Route::new().at("/ws", get(GraphQLSubscription::new(schema))); /// let app = Route::new().at("/ws", get(GraphQLSubscription::new(schema)));
/// ``` /// ```
pub struct GraphQLSubscription<Query, Mutation, Subscription, F> { pub struct GraphQLSubscription<Query, Mutation, Subscription, OnCreate, OnInit> {
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
initializer: F, on_connection_create: OnCreate,
on_connection_init: OnInit,
}
type DefaultOnConnCreateType = fn(&Request) -> Ready<Result<Data>>;
fn default_on_connection_create(_: &Request) -> Ready<Result<Data>> {
futures_util::future::ready(Ok(Data::default()))
}
type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
futures_util::future::ready(Ok(Data::default()))
} }
impl<Query, Mutation, Subscription> impl<Query, Mutation, Subscription>
@ -50,48 +64,79 @@ impl<Query, Mutation, Subscription>
Query, Query,
Mutation, Mutation,
Subscription, Subscription,
fn(serde_json::Value) -> Ready<async_graphql::Result<Data>>, DefaultOnConnCreateType,
DefaultOnConnInitType,
> >
{ {
/// Create a GraphQL subscription endpoint. /// Create a GraphQL subscription endpoint.
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self { pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self { Self {
schema, schema,
initializer: |_| futures_util::future::ready(Ok(Default::default())), on_connection_create: default_on_connection_create,
on_connection_init: default_on_connection_init,
} }
} }
} }
impl<Query, Mutation, Subscription, F> GraphQLSubscription<Query, Mutation, Subscription, F> { impl<Query, Mutation, Subscription, OnCreate, OnInit>
/// With a data initialization function. GraphQLSubscription<Query, Mutation, Subscription, OnCreate, OnInit>
pub fn with_initializer<F2, R>( {
/// Specify the callback function to be called when the connection is created.
///
/// You can get something from the incoming request to create [`Data`].
pub fn on_connection_create<OnCreate2, Fut>(
self, self,
initializer: F2, callback: OnCreate2,
) -> GraphQLSubscription<Query, Mutation, Subscription, F2> ) -> GraphQLSubscription<Query, Mutation, Subscription, OnCreate2, OnInit>
where where
F2: FnOnce(serde_json::Value) -> R + Clone + Send + Sync + 'static, OnCreate2: Fn(&Request) -> Fut + Send + Sync + 'static,
R: Future<Output = Result<Data>> + Send + 'static, Fut: Future<Output = Result<Data>> + Send + 'static,
{ {
GraphQLSubscription { GraphQLSubscription {
schema: self.schema, schema: self.schema,
initializer, on_connection_create: callback,
on_connection_init: self.on_connection_init,
}
}
/// Specify a callback function to be called when the connection is initialized.
///
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
pub fn on_connection_init<OnInit2, Fut>(
self,
callback: OnInit2,
) -> GraphQLSubscription<Query, Mutation, Subscription, OnCreate, OnInit2>
where
OnInit2: FnOnce(serde_json::Value) -> Fut + Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Data>> + Send + 'static,
{
GraphQLSubscription {
schema: self.schema,
on_connection_create: self.on_connection_create,
on_connection_init: callback,
} }
} }
} }
#[poem::async_trait] #[poem::async_trait]
impl<Query, Mutation, Subscription, F, R> Endpoint impl<Query, Mutation, Subscription, OnCreate, OnCreateFut, OnInit, OnInitFut> Endpoint
for GraphQLSubscription<Query, Mutation, Subscription, F> for GraphQLSubscription<Query, Mutation, Subscription, OnCreate, OnInit>
where where
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Clone + Send + Sync + 'static, OnCreate: Fn(&Request) -> OnCreateFut + Send + Sync + 'static,
R: Future<Output = async_graphql::Result<Data>> + Send + 'static, OnCreateFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
OnInit: FnOnce(serde_json::Value) -> OnInitFut + Clone + Send + Sync + 'static,
OnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{ {
type Output = Result<Response>; type Output = Result<Response>;
async fn call(&self, req: Request) -> Self::Output { async fn call(&self, req: Request) -> Self::Output {
let data = (self.on_connection_create)(&req)
.await
.map_err(|_| Error::new(StatusCode::BAD_REQUEST))?;
let (req, mut body) = req.split(); let (req, mut body) = req.split();
let websocket = WebSocket::from_request(&req, &mut body).await?; let websocket = WebSocket::from_request(&req, &mut body).await?;
let protocol = req let protocol = req
@ -105,7 +150,7 @@ where
}) })
.unwrap_or(WebSocketProtocols::SubscriptionsTransportWS); .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS);
let schema = self.schema.clone(); let schema = self.schema.clone();
let initializer = self.initializer.clone(); let on_connection_init = self.on_connection_init.clone();
let resp = websocket let resp = websocket
.protocols(ALL_WEBSOCKET_PROTOCOLS) .protocols(ALL_WEBSOCKET_PROTOCOLS)
@ -125,16 +170,13 @@ where
.map(Message::into_bytes) .map(Message::into_bytes)
.boxed(); .boxed();
let mut stream = async_graphql::http::WebSocket::with_data( let mut stream = async_graphql::http::WebSocket::new(schema, stream, protocol)
schema, .connection_data(data)
stream, .on_connection_init(on_connection_init)
initializer, .map(|msg| match msg {
protocol, WsMessage::Text(text) => Message::text(text),
) WsMessage::Close(code, status) => Message::close_with(code, status),
.map(|msg| match msg { });
WsMessage::Text(text) => Message::text(text),
WsMessage::Close(code, status) => Message::close_with(code, status),
});
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
let _ = sink.send(item).await; let _ = sink.send(item).await;

View File

@ -6,6 +6,7 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll}; use std::task::{Context, Poll};
use futures_util::future::Ready;
use futures_util::stream::Stream; use futures_util::stream::Stream;
use futures_util::FutureExt; use futures_util::FutureExt;
use futures_util::{future::BoxFuture, StreamExt}; use futures_util::{future::BoxFuture, StreamExt};
@ -57,16 +58,15 @@ impl WsMessage {
} }
} }
type BoxInitializer =
Box<(dyn FnOnce(serde_json::Value) -> BoxFuture<'static, Result<Data>> + Send + 'static)>;
pin_project! { pin_project! {
/// A GraphQL connection over websocket. /// A GraphQL connection over websocket.
/// ///
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). /// # References
/// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). ///
pub struct WebSocket<S, Query, Mutation, Subscription> { /// - [subscriptions-transport-ws](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md)
data_initializer: Option<BoxInitializer>, /// - [graphql-ws](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md)
pub struct WebSocket<S, Query, Mutation, Subscription, OnInit> {
on_connection_init: Option<OnInit>,
init_fut: Option<BoxFuture<'static, Result<Data>>>, init_fut: Option<BoxFuture<'static, Result<Data>>>,
connection_data: Option<Data>, connection_data: Option<Data>,
data: Option<Arc<Data>>, data: Option<Arc<Data>>,
@ -81,7 +81,14 @@ pin_project! {
type MessageMapStream<S> = type MessageMapStream<S> =
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>; futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
impl<S, Query, Mutation, Subscription> WebSocket<S, Query, Mutation, Subscription> type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<Result<Data>>;
fn default_on_connection_init(_: serde_json::Value) -> Ready<Result<Data>> {
futures_util::future::ready(Ok(Data::default()))
}
impl<S, Query, Mutation, Subscription>
WebSocket<S, Query, Mutation, Subscription, DefaultOnConnInitType>
where where
S: Stream<Item = serde_json::Result<ClientMessage>>, S: Stream<Item = serde_json::Result<ClientMessage>>,
{ {
@ -92,7 +99,7 @@ where
protocol: Protocols, protocol: Protocols,
) -> Self { ) -> Self {
WebSocket { WebSocket {
data_initializer: Some(Box::new(|_| Box::pin(async move { Ok(Data::default()) }))), on_connection_init: Some(default_on_connection_init),
init_fut: None, init_fut: None,
connection_data: None, connection_data: None,
data: None, data: None,
@ -102,33 +109,10 @@ where
protocol, protocol,
} }
} }
/// Specify a connection initializer.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
pub fn with_initializer<F, R>(mut self, initializer: F) -> Self
where
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
self.data_initializer = Some(Box::new(move |value| Box::pin(initializer(value))));
self
}
/// Specify a connection data.
///
/// This data usually comes from HTTP requests.
/// When the `GQL_CONNECTION_INIT` message is received, this data will be merged with the data
/// returned by the closure specified by `with_initializer` into the final subscription context data.
pub fn connection_data(mut self, data: Data) -> Self {
self.connection_data = Some(data);
self
}
} }
impl<S, Query, Mutation, Subscription> WebSocket<MessageMapStream<S>, Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription>
WebSocket<MessageMapStream<S>, Query, Mutation, Subscription, DefaultOnConnInitType>
where where
S: Stream, S: Stream,
S::Item: AsRef<[u8]>, S::Item: AsRef<[u8]>,
@ -138,23 +122,62 @@ where
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
stream: S, stream: S,
protocol: Protocols, protocol: Protocols,
) -> Self ) -> Self {
where
S: Stream,
S::Item: AsRef<[u8]>,
{
let stream = stream let stream = stream
.map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>); .map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
WebSocket::from_message_stream(schema, stream, protocol) WebSocket::from_message_stream(schema, stream, protocol)
} }
} }
impl<S, Query, Mutation, Subscription> Stream for WebSocket<S, Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription, OnInit> WebSocket<S, Query, Mutation, Subscription, OnInit>
where
S: Stream<Item = serde_json::Result<ClientMessage>>,
{
/// Specify a connection data.
///
/// This data usually comes from HTTP requests.
/// When the `GQL_CONNECTION_INIT` message is received, this data will be merged with the data
/// returned by the closure specified by `with_initializer` into the final subscription context data.
pub fn connection_data(mut self, data: Data) -> Self {
self.connection_data = Some(data);
self
}
/// Specify a connection initialize callback function.
///
/// This function if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
pub fn on_connection_init<F, R>(
self,
callback: F,
) -> WebSocket<S, Query, Mutation, Subscription, F>
where
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
WebSocket {
on_connection_init: Some(callback),
init_fut: self.init_fut,
connection_data: self.connection_data,
data: self.data,
schema: self.schema,
streams: self.streams,
stream: self.stream,
protocol: self.protocol,
}
}
}
impl<S, Query, Mutation, Subscription, OnInit, InitFut> Stream
for WebSocket<S, Query, Mutation, Subscription, OnInit>
where where
S: Stream<Item = serde_json::Result<ClientMessage>>, S: Stream<Item = serde_json::Result<ClientMessage>>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
OnInit: FnOnce(serde_json::Value) -> InitFut + Send + 'static,
InitFut: Future<Output = Result<Data>> + Send + 'static,
{ {
type Item = WsMessage; type Item = WsMessage;
@ -175,9 +198,9 @@ where
match message { match message {
ClientMessage::ConnectionInit { payload } => { ClientMessage::ConnectionInit { payload } => {
if let Some(data_initializer) = this.data_initializer.take() { if let Some(on_connection_init) = this.on_connection_init.take() {
*this.init_fut = Some(Box::pin(async move { *this.init_fut = Some(Box::pin(async move {
data_initializer(payload.unwrap_or_default()).await on_connection_init(payload.unwrap_or_default()).await
})); }));
break; break;
} else { } else {

View File

@ -112,7 +112,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema.clone(), rx, WebSocketProtocols::GraphQLWS) let mut stream = http::WebSocket::new(schema.clone(), rx, WebSocketProtocols::GraphQLWS)
.with_initializer(|value| async { .on_connection_init(|value| async {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Payload { struct Payload {
token: String, token: String,
@ -347,7 +347,7 @@ pub async fn test_subscription_init_error() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS) let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS)
.with_initializer(|_| async move { Err("Error!".into()) }); .on_connection_init(|_| async move { Err("Error!".into()) });
tx.send( tx.send(
serde_json::to_string(&value!({ serde_json::to_string(&value!({

View File

@ -108,7 +108,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS) let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS)
.with_initializer(|value| async { .on_connection_init(|value| async {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Payload { struct Payload {
token: String, token: String,