Improve websockets

This commit is contained in:
Koxiaet 2020-09-17 19:22:54 +01:00
parent 92de2e5816
commit 890b282dd9
7 changed files with 409 additions and 508 deletions

View File

@ -16,8 +16,8 @@ categories = ["network-programming", "asynchronous"]
async-graphql = { path = "../..", version = "2.0.0-alpha.8" } async-graphql = { path = "../..", version = "2.0.0-alpha.8" }
actix-web = "3.0.0" actix-web = "3.0.0"
actix-web-actors = "3.0.0" actix-web-actors = "3.0.0"
actix-http = "2.0.0"
actix = "0.10.0" actix = "0.10.0"
actix-rt = "1.1.0" actix-rt = "1.1.0"
futures = "0.3.0" futures = "0.3.5"
bytes = "0.5.4"
serde_json = "1.0.48" serde_json = "1.0.48"

View File

@ -1,11 +1,13 @@
use actix::{ use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner,
StreamHandler, WrapFuture, WrapStream,
}; };
use actix_http::ws;
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::WebSocketStream; use async_graphql::http::WebSocket;
use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType}; use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType};
use futures::stream::SplitSink; use futures::channel::mpsc;
use futures::{SinkExt, StreamExt}; use futures::SinkExt;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -13,10 +15,11 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Actor for subscription via websocket /// Actor for subscription via websocket
pub struct WSSubscription<Query, Mutation, Subscription> { pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>, schema: Option<Schema<Query, Mutation, Subscription>>,
hb: Instant, last_heartbeat: Instant,
sink: Option<SplitSink<WebSocketStream, String>>, messages: Option<mpsc::UnboundedSender<Vec<u8>>>,
initializer: Option<Box<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>>, initializer: Option<Box<dyn FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
continuation: Vec<u8>,
} }
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription> impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
@ -26,19 +29,20 @@ where
Subscription: SubscriptionType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static,
{ {
/// Create an actor for subscription connection via websocket. /// Create an actor for subscription connection via websocket.
pub fn new(schema: &Schema<Query, Mutation, Subscription>) -> Self { pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self { Self {
schema: schema.clone(), schema: Some(schema),
hb: Instant::now(), last_heartbeat: Instant::now(),
sink: None, messages: None,
initializer: None, initializer: None,
continuation: Vec::new(),
} }
} }
/// Set a context data initialization function. /// Set a context data initialization function.
pub fn initializer<F>(self, f: F) -> Self pub fn initializer<F>(self, f: F) -> Self
where where
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static, F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
{ {
Self { Self {
initializer: Some(Box::new(f)), initializer: Some(Box::new(f)),
@ -46,9 +50,9 @@ where
} }
} }
fn hb(&self, ctx: &mut WebsocketContext<Self>) { fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT { if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
ctx.stop(); ctx.stop();
} }
ctx.ping(b""); ctx.ping(b"");
@ -65,20 +69,19 @@ where
type Context = WebsocketContext<Self>; type Context = WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) { fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx); self.send_heartbeats(ctx);
if let Some(initializer) = self.initializer.take() {
let (sink, stream) = async_graphql::http::WebSocketStream::new_with_initializer( let (tx, rx) = mpsc::unbounded();
&self.schema,
initializer, WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take())
) .into_actor(self)
.split(); .map(|response, _act, ctx| {
ctx.add_stream(stream); ctx.text(response);
self.sink = Some(sink); })
} else { .finish()
let (sink, stream) = async_graphql::http::WebSocketStream::new(&self.schema).split(); .spawn(ctx);
ctx.add_stream(stream);
self.sink = Some(sink); self.messages = Some(tx);
};
} }
} }
@ -98,47 +101,49 @@ where
Ok(msg) => msg, Ok(msg) => msg,
}; };
match msg { let message = match msg {
Message::Ping(msg) => { Message::Ping(msg) => {
self.hb = Instant::now(); self.last_heartbeat = Instant::now();
ctx.pong(&msg); ctx.pong(&msg);
None
} }
Message::Pong(_) => { Message::Pong(_) => {
self.hb = Instant::now(); self.last_heartbeat = Instant::now();
None
} }
Message::Text(s) => { Message::Continuation(item) => match item {
if let Some(mut sink) = self.sink.take() { ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => {
async move { self.continuation = bytes.to_vec();
let res = sink.send(s).await; None
res.map(|_| sink)
}
.into_actor(self)
.then(|res, actor, ctx| {
match res {
Ok(sink) => actor.sink = Some(sink),
Err(_) => ctx.stop(),
}
async {}.into_actor(actor)
})
.wait(ctx);
} }
} ws::Item::Continue(bytes) => {
Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => { self.continuation.extend_from_slice(&bytes);
None
}
ws::Item::Last(bytes) => {
self.continuation.extend_from_slice(&bytes);
Some(std::mem::take(&mut self.continuation))
}
},
Message::Text(s) => Some(s.into_bytes()),
Message::Binary(bytes) => Some(bytes.to_vec()),
Message::Close(_) => {
ctx.stop(); ctx.stop();
None
} }
Message::Nop => {} Message::Nop => None,
};
if let Some(message) = message {
let mut sender = self.messages.as_ref().unwrap().clone();
async move { sender.send(message).await }
.into_actor(self)
.map(|res, _actor, ctx| match res {
Ok(()) => {}
Err(_) => ctx.stop(),
})
.spawn(ctx)
} }
} }
} }
impl<Query, Mutation, Subscription> StreamHandler<String>
for WSSubscription<Query, Mutation, Subscription>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
fn handle(&mut self, data: String, ctx: &mut Self::Context) {
ctx.text(data);
}
}

View File

@ -1,4 +1,4 @@
//! Async-graphql integration with Wrap //! Async-graphql integration with Warp
#![warn(missing_docs)] #![warn(missing_docs)]
#![allow(clippy::type_complexity)] #![allow(clippy::type_complexity)]
@ -9,13 +9,11 @@ use async_graphql::http::MultipartOptions;
use async_graphql::{ use async_graphql::{
resolver_utils::ObjectType, Data, FieldResult, Request, Schema, SubscriptionType, resolver_utils::ObjectType, Data, FieldResult, Request, Schema, SubscriptionType,
}; };
use futures::io::ErrorKind; use futures::{future, StreamExt, TryStreamExt};
use futures::{select, TryStreamExt};
use futures::{SinkExt, StreamExt};
use hyper::Method; use hyper::Method;
use std::io; use std::io::{self, ErrorKind};
use std::sync::Arc; use std::sync::Arc;
use warp::filters::ws::Message; use warp::filters::ws;
use warp::filters::BoxedFilter; use warp::filters::BoxedFilter;
use warp::reject::Reject; use warp::reject::Reject;
use warp::reply::Response; use warp::reply::Response;
@ -169,81 +167,57 @@ where
/// ``` /// ```
pub fn graphql_subscription<Query, Mutation, Subscription>( pub fn graphql_subscription<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
) -> BoxedFilter<(impl Reply,)> ) -> impl Filter<Extract = (impl Reply,)> + Clone
where where
Query: ObjectType + Sync + Send + 'static, Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static,
{ {
graphql_subscription_with_initializer(schema, |_| Ok(Default::default())) graphql_subscription_with_data::<_, _, _, fn(serde_json::Value) -> FieldResult<Data>>(
schema, None,
)
} }
/// GraphQL subscription filter /// GraphQL subscription filter
/// ///
/// Specifies that a function converts the init payload to data. /// Specifies that a function converts the init payload to data.
pub fn graphql_subscription_with_initializer<Query, Mutation, Subscription, F>( pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F>(
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
initializer: F, initializer: Option<F>,
) -> BoxedFilter<(impl Reply,)> ) -> impl Filter<Extract = (impl Reply,)> + Clone
where where
Query: ObjectType + Sync + Send + 'static, Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static, Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static,
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static, F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static,
{ {
warp::any() warp::any()
.and(warp::ws()) .and(warp::ws())
.and(warp::any().map(move || schema.clone())) .and(warp::any().map(move || schema.clone()))
.and(warp::any().map(move || initializer.clone())) .and(warp::any().map(move || initializer.clone()))
.map( .map(
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: F| { |ws: ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: Option<F>| {
ws.on_upgrade(move |websocket| { ws.on_upgrade(move |websocket| {
let (mut tx, rx) = websocket.split(); let (ws_sender, ws_receiver) = websocket.split();
let (mut stx, srx) =
async_graphql::http::WebSocketStream::new_with_initializer(
&schema,
initializer,
)
.split();
let mut rx = rx.fuse();
let srx = srx.fuse();
async move { async move {
futures::pin_mut!(srx); let _ = async_graphql::http::WebSocket::with_data(
schema,
loop { ws_receiver
select! { .take_while(|msg| future::ready(msg.is_ok()))
bytes = srx.next() => { .map(Result::unwrap)
if let Some(text) = bytes { .map(ws::Message::into_bytes),
if tx.send(Message::text(text)).await.is_err() { initializer,
return; )
} .map(ws::Message::text)
} else { .map(Ok)
return; .forward(ws_sender)
} .await;
}
msg = rx.next() => {
if let Some(Ok(msg)) = msg {
if msg.is_text() {
if let Ok(text) = String::from_utf8(msg.into_bytes()) {
if stx.send(text).await.is_err() {
return;
}
}
}
} else {
return;
}
}
}
}
} }
}) })
}, },
) )
.map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")) .map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws"))
.boxed()
} }
/// GraphQL reply /// GraphQL reply

View File

@ -10,7 +10,7 @@ pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")] #[cfg(feature = "multipart")]
pub use multipart::{receive_multipart, MultipartOptions}; pub use multipart::{receive_multipart, MultipartOptions};
pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::WebSocketStream; pub use websocket::WebSocket;
use crate::{ParseRequestError, Request}; use crate::{ParseRequestError, Request};
use futures::io::AsyncRead; use futures::io::AsyncRead;

View File

@ -2,291 +2,207 @@
use crate::resolver_utils::ObjectType; use crate::resolver_utils::ObjectType;
use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType}; use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType};
use futures::channel::mpsc; use futures::Stream;
use futures::task::{Context, Poll}; use pin_project_lite::pin_project;
use futures::{Future, Sink, SinkExt, Stream, StreamExt};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque}; use std::collections::HashMap;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::{Context, Poll};
#[derive(Serialize, Deserialize)] pin_project! {
struct OperationMessage<'a, T> { /// A GraphQL connection over websocket.
#[serde(rename = "type")] ///
ty: &'a str, /// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
#[serde(skip_serializing_if = "Option::is_none")] data_initializer: Option<F>,
id: Option<String>, data: Arc<Data>,
schema: Schema<Query, Mutation, Subscription>,
#[serde(skip_serializing_if = "Option::is_none")] streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
payload: Option<T>, #[pin]
} stream: S,
type SubscriptionStreams = HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>;
type HandleRequestBoxFut = Pin<Box<dyn Future<Output = FieldResult<WSContext>> + Send>>;
type InitializerFn = Arc<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>;
/// A wrapper around an underlying raw stream which implements the WebSocket protocol.
///
/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function
/// to splits this object into separate Sink and Stream objects.
pub struct WebSocketStream {
tx: mpsc::UnboundedSender<String>,
rx: Pin<Box<dyn Stream<Item = String> + Send>>,
}
impl Sink<String> for WebSocketStream {
type Error = mpsc::SendError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready_unpin(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.tx.start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_flush_unpin(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_close_unpin(cx)
} }
} }
impl Stream for WebSocketStream { impl<S, Query, Mutation, Subscription>
type Item = String; WebSocket<S, fn(serde_json::Value) -> FieldResult<Data>, Query, Mutation, Subscription>
{
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { /// Create a new websocket.
self.rx.poll_next_unpin(cx) #[must_use]
} pub fn new(schema: Schema<Query, Mutation, Subscription>, stream: S) -> Self {
} Self {
data_initializer: None,
impl WebSocketStream { data: Arc::default(),
/// Create a websocket transport. schema,
pub fn new<Query, Mutation, Subscription>( streams: HashMap::new(),
schema: &Schema<Query, Mutation, Subscription>, stream,
) -> Self
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
Self::new_with_initializer(schema, |_| Ok(Default::default()))
}
/// Create a websocket transport and specify a context initialization function.
pub fn new_with_initializer<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
initializer: impl Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
) -> Self
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
let (tx, rx) = mpsc::unbounded();
WebSocketStream {
tx,
rx: SubscriptionStream {
schema: schema.clone(),
initializer: Arc::new(initializer),
rx_bytes: rx,
handle_request_fut: None,
ctx: Some(WSContext {
streams: Default::default(),
send_buf: Default::default(),
ctx_data: Arc::new(Data::default()),
}),
}
.boxed(),
} }
} }
} }
struct WSContext { impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription> {
streams: SubscriptionStreams, /// Create a new websocket with a data initialization function.
send_buf: VecDeque<String>, ///
ctx_data: Arc<Data>, /// This function, if present, will be called with the data sent by the client in the
} /// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
fn send_message<T: Serialize>(send_buf: &mut VecDeque<String>, msg: &T) { #[must_use]
if let Ok(data) = serde_json::to_string(msg) { pub fn with_data(
send_buf.push_back(data); schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: Option<F>,
) -> Self {
Self {
data_initializer,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
}
} }
} }
#[allow(missing_docs)] impl<S, F, Query, Mutation, Subscription> Stream for WebSocket<S, F, Query, Mutation, Subscription>
#[allow(clippy::type_complexity)]
struct SubscriptionStream<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
rx_bytes: mpsc::UnboundedReceiver<String>,
handle_request_fut: Option<HandleRequestBoxFut>,
ctx: Option<WSContext>,
}
impl<'a, Query, Mutation, Subscription> Stream for SubscriptionStream<Query, Mutation, Subscription>
where where
S: Stream,
S::Item: AsRef<[u8]>,
F: FnOnce(serde_json::Value) -> FieldResult<Data>,
Query: ObjectType + Send + Sync + 'static, Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static,
{ {
type Item = String; type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = &mut *self; let this = self.project();
loop { match this.stream.poll_next(cx) {
// receive bytes Poll::Ready(message) => {
if let Some(ctx) = &mut this.ctx { let message = match message {
if let Some(bytes) = ctx.send_buf.pop_front() { Some(message) => message,
return Poll::Ready(Some(bytes)); None => return Poll::Ready(None),
} };
}
if let Some(handle_request_fut) = &mut this.handle_request_fut { let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
match handle_request_fut.as_mut().poll(cx) { Ok(message) => message,
Poll::Ready(Ok(ctx)) => { Err(e) => {
this.ctx = Some(ctx); return Poll::Ready(Some(
this.handle_request_fut = None; serde_json::to_string(&ServerMessage::ConnectionError {
continue; payload: ConnectionError {
message: e.to_string(),
extensions: None,
},
})
.unwrap(),
))
} }
Poll::Ready(Err(_)) => return Poll::Ready(None), };
Poll::Pending => {}
}
} else {
match Pin::new(&mut this.rx_bytes).poll_next(cx) {
Poll::Ready(Some(data)) => {
let ctx = this.ctx.take().unwrap();
this.handle_request_fut = Some(Box::pin(handle_request(
this.schema.clone(),
this.initializer.clone(),
ctx,
data,
)));
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
}
// receive msg match message {
if let Some(ctx) = &mut this.ctx { ClientMessage::ConnectionInit { payload } => {
let mut closed = Vec::new(); if let Some(payload) = payload {
if let Some(data_initializer) = this.data_initializer.take() {
for (id, incoming_stream) in ctx.streams.iter_mut() { *this.data = Arc::new(match data_initializer(payload) {
match incoming_stream.as_mut().poll_next(cx) { Ok(data) => data,
Poll::Ready(Some(res)) => { Err(e) => {
if let Some(err) = &res.error { return Poll::Ready(Some(
closed.push(id.to_string()); serde_json::to_string(
send_message( &ServerMessage::ConnectionError {
&mut ctx.send_buf, payload: ConnectionError {
&OperationMessage { message: e.0,
ty: "error", extensions: e.1,
id: Some(id.to_string()), },
payload: Some(err), },
}, )
); .unwrap(),
} else { ))
send_message( }
&mut ctx.send_buf, });
&OperationMessage {
ty: "data",
id: Some(id.to_string()),
payload: Some(&res),
},
);
} }
} }
Poll::Ready(None) => { return Poll::Ready(Some(
closed.push(id.to_string()); serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
send_message( ));
&mut ctx.send_buf,
&OperationMessage {
ty: "complete",
id: Some(id.to_string()),
payload: Option::<serde_json::Value>::None,
},
);
}
Poll::Pending => {}
} }
} ClientMessage::Start {
id,
for id in closed { payload: request,
ctx.streams.remove(&id); } => {
} this.streams.insert(
id,
if !ctx.send_buf.is_empty() { Box::pin(
continue; this.schema
} .execute_stream_with_ctx_data(request, Arc::clone(this.data)),
} ),
return Poll::Pending;
}
}
}
async fn handle_request<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
mut ctx: WSContext,
data: String,
) -> FieldResult<WSContext>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
match serde_json::from_str::<OperationMessage<serde_json::Value>>(&data) {
Ok(msg) => match msg.ty {
"connection_init" => {
if let Some(payload) = msg.payload {
ctx.ctx_data = Arc::new(initializer(payload)?);
}
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "connection_ack",
id: None,
payload: Option::<serde_json::Value>::None,
},
);
}
"start" => {
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
if let Ok(request) = serde_json::from_value::<Request>(payload) {
let stream = schema
.execute_stream_with_ctx_data(request, ctx.ctx_data.clone())
.boxed();
ctx.streams.insert(id, stream);
}
}
}
"stop" => {
if let Some(id) = msg.id {
if ctx.streams.remove(&id).is_some() {
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "complete",
id: Some(id),
payload: Option::<serde_json::Value>::None,
},
); );
} }
ClientMessage::Stop { id } => {
if this.streams.remove(id).is_some() {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
));
}
}
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
} }
} }
"connection_terminate" => return Err("connection_terminate".into()), Poll::Pending => {}
_ => return Err("Unknown op".into()), }
},
Err(err) => return Err(err.into()),
}
Ok(ctx) for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Data {
id,
payload: Box::new(payload),
})
.unwrap(),
));
}
Poll::Ready(None) => {
let id = id.clone();
this.streams.remove(&id);
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
));
}
Poll::Pending => {}
}
}
Poll::Pending
}
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage<'a> {
ConnectionInit { payload: Option<serde_json::Value> },
Start { id: String, payload: Request },
Stop { id: &'a str },
ConnectionTerminate,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
ConnectionError { payload: ConnectionError },
ConnectionAck,
Data { id: &'a str, payload: Box<Response> },
// Not used by this library, as it's not necessary to send
// Error {
// id: &'a str,
// payload: serde_json::Value,
// },
Complete { id: &'a str },
// Not used by this library
// #[serde(rename = "ka")]
// KeepAlive
}
#[derive(Serialize)]
struct ConnectionError {
message: String,
extensions: Option<serde_json::Value>,
} }

View File

@ -436,9 +436,9 @@ where
pub(crate) fn execute_stream_with_ctx_data( pub(crate) fn execute_stream_with_ctx_data(
&self, &self,
request: impl Into<Request>, request: impl Into<Request> + Send,
ctx_data: Arc<Data>, ctx_data: Arc<Data>,
) -> impl Stream<Item = Response> { ) -> impl Stream<Item = Response> + Send {
let schema = self.clone(); let schema = self.clone();
async_stream::stream! { async_stream::stream! {
@ -495,7 +495,10 @@ where
} }
/// Execute an GraphQL subscription. /// Execute an GraphQL subscription.
pub fn execute_stream(&self, request: impl Into<Request>) -> impl Stream<Item = Response> { pub fn execute_stream(
&self,
request: impl Into<Request>,
) -> impl Stream<Item = Response> + Send {
let mut request = request.into(); let mut request = request.into();
let ctx_data = std::mem::take(&mut request.data); let ctx_data = std::mem::take(&mut request.data);
self.execute_stream_with_ctx_data(request, Arc::new(ctx_data)) self.execute_stream_with_ctx_data(request, Arc::new(ctx_data))

View File

@ -1,4 +1,5 @@
use async_graphql::*; use async_graphql::*;
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt}; use futures::{SinkExt, Stream, StreamExt};
#[async_std::test] #[async_std::test]
@ -18,56 +19,55 @@ pub async fn test_subscription_ws_transport() {
} }
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new(&schema); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "connection_init",
"type": "connection_init", }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
assert_eq!( assert_eq!(
Some(serde_json::json!({ serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
"type": "connection_ack", serde_json::json!({
})), "type": "connection_ack",
serde_json::from_str(&stream.next().await.unwrap()).unwrap() }),
); );
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "start",
"type": "start", "id": "1",
"id": "1", "payload": {
"payload": { "query": "subscription { values }"
"query": "subscription { values }" },
}, }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
for i in 0..10 { for i in 0..10 {
assert_eq!( assert_eq!(
Some(serde_json::json!({ serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
"type": "data", serde_json::json!({
"id": "1", "type": "data",
"payload": { "data": { "values": i } }, "id": "1",
})), "payload": { "data": { "values": i } },
serde_json::from_str(&stream.next().await.unwrap()).unwrap() }),
); );
} }
assert_eq!( assert_eq!(
Some(serde_json::json!({ serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
"type": "complete", serde_json::json!({
"id": "1", "type": "complete",
})), "id": "1",
serde_json::from_str(&stream.next().await.unwrap()).unwrap() }),
); );
} }
@ -93,56 +93,59 @@ pub async fn test_subscription_ws_transport_with_token() {
} }
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new_with_initializer(&schema, |value| { let (mut tx, rx) = mpsc::unbounded();
#[derive(serde::Deserialize)] let mut stream = http::WebSocket::with_data(
struct Payload { schema,
token: String, rx,
} Some(|value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap(); let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default(); let mut data = Data::default();
data.insert(Token(payload.token)); data.insert(Token(payload.token));
Ok(data) Ok(data)
}); }),
);
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "connection_init",
"type": "connection_init", "payload": { "token": "123456" }
"payload": { "token": "123456" } }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "connection_ack", "type": "connection_ack",
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "start",
"type": "start", "id": "1",
"id": "1", "payload": {
"payload": { "query": "subscription { values }"
"query": "subscription { values }" },
}, }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
for i in 0..10 { for i in 0..10 {
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "data", "type": "data",
"id": "1", "id": "1",
"payload": { "data": { "values": i } }, "payload": { "data": { "values": i } },
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
@ -150,8 +153,8 @@ pub async fn test_subscription_ws_transport_with_token() {
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "complete", "type": "complete",
"id": "1", "id": "1",
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
@ -189,45 +192,44 @@ pub async fn test_subscription_ws_transport_error() {
} }
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new(&schema); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "connection_init"
"type": "connection_init" }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "connection_ack", "type": "connection_ack",
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "start",
"type": "start", "id": "1",
"id": "1", "payload": {
"payload": { "query": "subscription { events { value } }"
"query": "subscription { events { value } }" },
}, }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
for i in 0i32..5 { for i in 0i32..5 {
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "data", "type": "data",
"id": "1", "id": "1",
"payload": { "data": { "events": { "value": i } } }, "payload": { "data": { "events": { "value": i } } },
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
@ -235,13 +237,15 @@ pub async fn test_subscription_ws_transport_error() {
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
"type": "error", "type": "data",
"id": "1", "id": "1",
"payload": [{ "payload": {
"message": "TestError", "errors": [{
"locations": [{"line": 1, "column": 25}], "message": "TestError",
"path": ["events", "value"], "locations": [{"line": 1, "column": 25}],
}], "path": ["events", "value"],
}],
},
})), })),
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
@ -259,17 +263,17 @@ pub async fn test_query_over_websocket() {
} }
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription); let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let mut stream = http::WebSocketStream::new(&schema); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "connection_init",
"type": "connection_init", }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({
@ -278,19 +282,18 @@ pub async fn test_query_over_websocket() {
serde_json::from_str(&stream.next().await.unwrap()).unwrap() serde_json::from_str(&stream.next().await.unwrap()).unwrap()
); );
stream tx.send(
.send( serde_json::to_string(&serde_json::json!({
serde_json::to_string(&serde_json::json!({ "type": "start",
"type": "start", "id": "1",
"id": "1", "payload": {
"payload": { "query": "query { value }"
"query": "query { value }" },
}, }))
})) .unwrap(),
.unwrap(), )
) .await
.await .unwrap();
.unwrap();
assert_eq!( assert_eq!(
Some(serde_json::json!({ Some(serde_json::json!({