Supports async websocket initializer. #393

This commit is contained in:
Sunli 2021-01-16 12:47:32 +08:00
parent 457fb4223a
commit b98e35995d
6 changed files with 81 additions and 46 deletions

@ -1 +1 @@
Subproject commit de1c7e9acb840d3d24ccce5f546527672b989ba7
Subproject commit 71fab89afb282e24840bf199b92b0d6ec624c5ca

View File

@ -1,3 +1,4 @@
use std::future::Future;
use std::str::FromStr;
use std::time::{Duration, Instant};
@ -12,22 +13,24 @@ use actix_web::{HttpRequest, HttpResponse};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::{WebSocket, WebSocketProtocols};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::future::Ready;
use futures_util::stream::Stream;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Actor for subscription via websocket
pub struct WSSubscription<Query, Mutation, Subscription> {
pub struct WSSubscription<Query, Mutation, Subscription, F> {
schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols,
last_heartbeat: Instant,
messages: Option<async_channel::Sender<Vec<u8>>>,
initializer: Option<Box<dyn FnOnce(serde_json::Value) -> Result<Data> + Send + Sync>>,
initializer: Option<F>,
continuation: Vec<u8>,
}
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
impl<Query, Mutation, Subscription>
WSSubscription<Query, Mutation, Subscription, fn(serde_json::Value) -> Ready<Result<Data>>>
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
@ -42,11 +45,22 @@ where
where
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
Self::start_with_initializer(schema, request, stream, |_| Ok(Default::default()))
Self::start_with_initializer(schema, request, stream, |_| {
futures_util::future::ready(Ok(Default::default()))
})
}
}
impl<Query, Mutation, Subscription, F, R> WSSubscription<Query, Mutation, Subscription, F>
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
/// Start an actor for subscription connection via websocket with an initialization function.
pub fn start_with_initializer<T, F>(
pub fn start_with_initializer<T>(
schema: Schema<Query, Mutation, Subscription>,
request: &HttpRequest,
stream: T,
@ -54,7 +68,8 @@ where
) -> Result<HttpResponse, Error>
where
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
F: FnOnce(serde_json::Value) -> Result<Data> + Send + Sync + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
let protocol = match request
.headers()
@ -78,7 +93,7 @@ where
protocol,
last_heartbeat: Instant::now(),
messages: None,
initializer: Some(Box::new(initializer)),
initializer: Some(initializer),
continuation: Vec::new(),
},
&["graphql-transport-ws", "graphql-ws"],
@ -97,11 +112,13 @@ where
}
}
impl<Query, Mutation, Subscription> Actor for WSSubscription<Query, Mutation, Subscription>
impl<Query, Mutation, Subscription, F, R> Actor for WSSubscription<Query, Mutation, Subscription, F>
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
type Context = WebsocketContext<Self>;
@ -113,7 +130,7 @@ where
WebSocket::with_data(
self.schema.clone(),
rx,
self.initializer.take(),
self.initializer.take().unwrap(),
self.protocol,
)
.into_actor(self)
@ -127,12 +144,14 @@ where
}
}
impl<Query, Mutation, Subscription> StreamHandler<Result<Message, ProtocolError>>
for WSSubscription<Query, Mutation, Subscription>
impl<Query, Mutation, Subscription, F, R> StreamHandler<Result<Message, ProtocolError>>
for WSSubscription<Query, Mutation, Subscription, F>
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
let msg = match msg {

View File

@ -1,3 +1,5 @@
use std::future::Future;
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::{future, StreamExt};
use warp::filters::ws;
@ -48,21 +50,22 @@ where
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
graphql_subscription_with_data::<_, _, _, fn(serde_json::Value) -> Result<Data>>(schema, None)
graphql_subscription_with_data(schema, |_| async { Ok(Default::default()) })
}
/// GraphQL subscription filter
///
/// Specifies that a function converts the init payload to data.
pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F>(
pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F, R>(
schema: Schema<Query, Mutation, Subscription>,
initializer: Option<F>,
initializer: F,
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
F: FnOnce(serde_json::Value) -> Result<Data> + Send + Sync + Clone + 'static,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Clone + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
use async_graphql::http::WebSocketProtocols;
use std::str::FromStr;

View File

@ -1,11 +1,14 @@
//! WebSocket transport for subscription
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::future::{BoxFuture, Ready};
use futures_util::stream::Stream;
use futures_util::FutureExt;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
@ -17,6 +20,7 @@ pin_project! {
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
data_initializer: Option<F>,
init_fut: Option<BoxFuture<'static, Result<Data>>>,
data: Arc<Data>,
schema: Schema<Query, Mutation, Subscription>,
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
@ -27,7 +31,7 @@ pin_project! {
}
impl<S, Query, Mutation, Subscription>
WebSocket<S, fn(serde_json::Value) -> Result<Data>, Query, Mutation, Subscription>
WebSocket<S, fn(serde_json::Value) -> Ready<Result<Data>>, Query, Mutation, Subscription>
{
/// Create a new websocket.
#[must_use]
@ -37,7 +41,8 @@ impl<S, Query, Mutation, Subscription>
protocol: Protocols,
) -> Self {
Self {
data_initializer: None,
data_initializer: Some(|_| futures_util::future::ready(Ok(Default::default()))),
init_fut: None,
data: Arc::default(),
schema,
streams: HashMap::new(),
@ -57,11 +62,12 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
pub fn with_data(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: Option<F>,
data_initializer: F,
protocol: Protocols,
) -> Self {
Self {
data_initializer,
data_initializer: Some(data_initializer),
init_fut: None,
data: Arc::default(),
schema,
streams: HashMap::new(),
@ -71,11 +77,13 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
}
}
impl<S, F, Query, Mutation, Subscription> Stream for WebSocket<S, F, Query, Mutation, Subscription>
impl<S, F, R, Query, Mutation, Subscription> Stream
for WebSocket<S, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
F: FnOnce(serde_json::Value) -> Result<Data>,
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
@ -105,24 +113,11 @@ where
match message {
ClientMessage::ConnectionInit { payload } => {
if let Some(payload) = payload {
if let Some(data_initializer) = this.data_initializer.take() {
*this.data = Arc::new(match data_initializer(payload) {
Ok(data) => data,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: e,
})
.unwrap(),
))
}
});
}
if let Some(data_initializer) = this.data_initializer.take() {
*this.init_fut = Some(Box::pin(async move {
data_initializer(payload.unwrap_or_default()).await
}));
}
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
));
}
ClientMessage::Start {
id,
@ -150,6 +145,24 @@ where
}
}
if let Some(init_fut) = this.init_fut {
if let Poll::Ready(res) = init_fut.poll_unpin(cx) {
*this.init_fut = None;
return match res {
Ok(data) => {
*this.data = Arc::new(data);
Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
))
}
Err(err) => Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError { payload: err })
.unwrap(),
)),
};
}
}
for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {

View File

@ -107,7 +107,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut stream = http::WebSocket::with_data(
schema,
rx,
Some(|value| {
|value| async {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
@ -117,7 +117,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}),
},
WebSocketProtocols::GraphQLWS,
);

View File

@ -107,7 +107,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut stream = http::WebSocket::with_data(
schema,
rx,
Some(|value| {
|value| async {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
@ -117,7 +117,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}),
},
WebSocketProtocols::SubscriptionsTransportWS,
);