Improve websockets

This commit is contained in:
Koxiaet 2020-09-17 19:22:54 +01:00
parent 92de2e5816
commit 890b282dd9
7 changed files with 409 additions and 508 deletions

View File

@ -16,8 +16,8 @@ categories = ["network-programming", "asynchronous"]
async-graphql = { path = "../..", version = "2.0.0-alpha.8" }
actix-web = "3.0.0"
actix-web-actors = "3.0.0"
actix-http = "2.0.0"
actix = "0.10.0"
actix-rt = "1.1.0"
futures = "0.3.0"
bytes = "0.5.4"
futures = "0.3.5"
serde_json = "1.0.48"

View File

@ -1,11 +1,13 @@
use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner,
StreamHandler, WrapFuture, WrapStream,
};
use actix_http::ws;
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::WebSocketStream;
use async_graphql::http::WebSocket;
use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType};
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use futures::channel::mpsc;
use futures::SinkExt;
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -13,10 +15,11 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Actor for subscription via websocket
pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
sink: Option<SplitSink<WebSocketStream, String>>,
initializer: Option<Box<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
schema: Option<Schema<Query, Mutation, Subscription>>,
last_heartbeat: Instant,
messages: Option<mpsc::UnboundedSender<Vec<u8>>>,
initializer: Option<Box<dyn FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
continuation: Vec<u8>,
}
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
@ -26,19 +29,20 @@ where
Subscription: SubscriptionType + Send + Sync + 'static,
{
/// Create an actor for subscription connection via websocket.
pub fn new(schema: &Schema<Query, Mutation, Subscription>) -> Self {
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self {
schema: schema.clone(),
hb: Instant::now(),
sink: None,
schema: Some(schema),
last_heartbeat: Instant::now(),
messages: None,
initializer: None,
continuation: Vec::new(),
}
}
/// Set a context data initialization function.
pub fn initializer<F>(self, f: F) -> Self
where
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
{
Self {
initializer: Some(Box::new(f)),
@ -46,9 +50,9 @@ where
}
}
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
ctx.stop();
}
ctx.ping(b"");
@ -65,20 +69,19 @@ where
type Context = WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
if let Some(initializer) = self.initializer.take() {
let (sink, stream) = async_graphql::http::WebSocketStream::new_with_initializer(
&self.schema,
initializer,
)
.split();
ctx.add_stream(stream);
self.sink = Some(sink);
} else {
let (sink, stream) = async_graphql::http::WebSocketStream::new(&self.schema).split();
ctx.add_stream(stream);
self.sink = Some(sink);
};
self.send_heartbeats(ctx);
let (tx, rx) = mpsc::unbounded();
WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take())
.into_actor(self)
.map(|response, _act, ctx| {
ctx.text(response);
})
.finish()
.spawn(ctx);
self.messages = Some(tx);
}
}
@ -98,47 +101,49 @@ where
Ok(msg) => msg,
};
match msg {
let message = match msg {
Message::Ping(msg) => {
self.hb = Instant::now();
self.last_heartbeat = Instant::now();
ctx.pong(&msg);
None
}
Message::Pong(_) => {
self.hb = Instant::now();
self.last_heartbeat = Instant::now();
None
}
Message::Text(s) => {
if let Some(mut sink) = self.sink.take() {
async move {
let res = sink.send(s).await;
res.map(|_| sink)
}
.into_actor(self)
.then(|res, actor, ctx| {
match res {
Ok(sink) => actor.sink = Some(sink),
Err(_) => ctx.stop(),
}
async {}.into_actor(actor)
})
.wait(ctx);
Message::Continuation(item) => match item {
ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => {
self.continuation = bytes.to_vec();
None
}
}
Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => {
ws::Item::Continue(bytes) => {
self.continuation.extend_from_slice(&bytes);
None
}
ws::Item::Last(bytes) => {
self.continuation.extend_from_slice(&bytes);
Some(std::mem::take(&mut self.continuation))
}
},
Message::Text(s) => Some(s.into_bytes()),
Message::Binary(bytes) => Some(bytes.to_vec()),
Message::Close(_) => {
ctx.stop();
None
}
Message::Nop => {}
Message::Nop => None,
};
if let Some(message) = message {
let mut sender = self.messages.as_ref().unwrap().clone();
async move { sender.send(message).await }
.into_actor(self)
.map(|res, _actor, ctx| match res {
Ok(()) => {}
Err(_) => ctx.stop(),
})
.spawn(ctx)
}
}
}
impl<Query, Mutation, Subscription> StreamHandler<String>
for WSSubscription<Query, Mutation, Subscription>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
fn handle(&mut self, data: String, ctx: &mut Self::Context) {
ctx.text(data);
}
}

View File

@ -1,4 +1,4 @@
//! Async-graphql integration with Wrap
//! Async-graphql integration with Warp
#![warn(missing_docs)]
#![allow(clippy::type_complexity)]
@ -9,13 +9,11 @@ use async_graphql::http::MultipartOptions;
use async_graphql::{
resolver_utils::ObjectType, Data, FieldResult, Request, Schema, SubscriptionType,
};
use futures::io::ErrorKind;
use futures::{select, TryStreamExt};
use futures::{SinkExt, StreamExt};
use futures::{future, StreamExt, TryStreamExt};
use hyper::Method;
use std::io;
use std::io::{self, ErrorKind};
use std::sync::Arc;
use warp::filters::ws::Message;
use warp::filters::ws;
use warp::filters::BoxedFilter;
use warp::reject::Reject;
use warp::reply::Response;
@ -169,81 +167,57 @@ where
/// ```
pub fn graphql_subscription<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
) -> BoxedFilter<(impl Reply,)>
) -> impl Filter<Extract = (impl Reply,)> + Clone
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
graphql_subscription_with_initializer(schema, |_| Ok(Default::default()))
graphql_subscription_with_data::<_, _, _, fn(serde_json::Value) -> FieldResult<Data>>(
schema, None,
)
}
/// GraphQL subscription filter
///
/// Specifies that a function converts the init payload to data.
pub fn graphql_subscription_with_initializer<Query, Mutation, Subscription, F>(
pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F>(
schema: Schema<Query, Mutation, Subscription>,
initializer: F,
) -> BoxedFilter<(impl Reply,)>
initializer: Option<F>,
) -> impl Filter<Extract = (impl Reply,)> + Clone
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static,
F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static,
{
warp::any()
.and(warp::ws())
.and(warp::any().map(move || schema.clone()))
.and(warp::any().map(move || initializer.clone()))
.map(
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: F| {
|ws: ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: Option<F>| {
ws.on_upgrade(move |websocket| {
let (mut tx, rx) = websocket.split();
let (mut stx, srx) =
async_graphql::http::WebSocketStream::new_with_initializer(
&schema,
initializer,
)
.split();
let mut rx = rx.fuse();
let srx = srx.fuse();
let (ws_sender, ws_receiver) = websocket.split();
async move {
futures::pin_mut!(srx);
loop {
select! {
bytes = srx.next() => {
if let Some(text) = bytes {
if tx.send(Message::text(text)).await.is_err() {
return;
}
} else {
return;
}
}
msg = rx.next() => {
if let Some(Ok(msg)) = msg {
if msg.is_text() {
if let Ok(text) = String::from_utf8(msg.into_bytes()) {
if stx.send(text).await.is_err() {
return;
}
}
}
} else {
return;
}
}
}
}
let _ = async_graphql::http::WebSocket::with_data(
schema,
ws_receiver
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.map(ws::Message::into_bytes),
initializer,
)
.map(ws::Message::text)
.map(Ok)
.forward(ws_sender)
.await;
}
})
},
)
.map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws"))
.boxed()
}
/// GraphQL reply

View File

@ -10,7 +10,7 @@ pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")]
pub use multipart::{receive_multipart, MultipartOptions};
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::WebSocketStream;
pub use websocket::WebSocket;
use crate::{ParseRequestError, Request};
use futures::io::AsyncRead;

View File

@ -2,291 +2,207 @@
use crate::resolver_utils::ObjectType;
use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType};
use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::{Future, Sink, SinkExt, Stream, StreamExt};
use futures::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
#[derive(Serialize, Deserialize)]
struct OperationMessage<'a, T> {
#[serde(rename = "type")]
ty: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
payload: Option<T>,
}
type SubscriptionStreams = HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>;
type HandleRequestBoxFut = Pin<Box<dyn Future<Output = FieldResult<WSContext>> + Send>>;
type InitializerFn = Arc<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>;
/// A wrapper around an underlying raw stream which implements the WebSocket protocol.
///
/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function
/// to splits this object into separate Sink and Stream objects.
pub struct WebSocketStream {
tx: mpsc::UnboundedSender<String>,
rx: Pin<Box<dyn Stream<Item = String> + Send>>,
}
impl Sink<String> for WebSocketStream {
type Error = mpsc::SendError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready_unpin(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.tx.start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_flush_unpin(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_close_unpin(cx)
pin_project! {
/// A GraphQL connection over websocket.
///
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
data_initializer: Option<F>,
data: Arc<Data>,
schema: Schema<Query, Mutation, Subscription>,
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
#[pin]
stream: S,
}
}
impl Stream for WebSocketStream {
type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_next_unpin(cx)
}
}
impl WebSocketStream {
/// Create a websocket transport.
pub fn new<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
) -> Self
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
Self::new_with_initializer(schema, |_| Ok(Default::default()))
}
/// Create a websocket transport and specify a context initialization function.
pub fn new_with_initializer<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
initializer: impl Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
) -> Self
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
let (tx, rx) = mpsc::unbounded();
WebSocketStream {
tx,
rx: SubscriptionStream {
schema: schema.clone(),
initializer: Arc::new(initializer),
rx_bytes: rx,
handle_request_fut: None,
ctx: Some(WSContext {
streams: Default::default(),
send_buf: Default::default(),
ctx_data: Arc::new(Data::default()),
}),
}
.boxed(),
impl<S, Query, Mutation, Subscription>
WebSocket<S, fn(serde_json::Value) -> FieldResult<Data>, Query, Mutation, Subscription>
{
/// Create a new websocket.
#[must_use]
pub fn new(schema: Schema<Query, Mutation, Subscription>, stream: S) -> Self {
Self {
data_initializer: None,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
}
}
}
struct WSContext {
streams: SubscriptionStreams,
send_buf: VecDeque<String>,
ctx_data: Arc<Data>,
}
fn send_message<T: Serialize>(send_buf: &mut VecDeque<String>, msg: &T) {
if let Ok(data) = serde_json::to_string(msg) {
send_buf.push_back(data);
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription> {
/// Create a new websocket with a data initialization function.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
#[must_use]
pub fn with_data(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: Option<F>,
) -> Self {
Self {
data_initializer,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
}
}
}
#[allow(missing_docs)]
#[allow(clippy::type_complexity)]
struct SubscriptionStream<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
rx_bytes: mpsc::UnboundedReceiver<String>,
handle_request_fut: Option<HandleRequestBoxFut>,
ctx: Option<WSContext>,
}
impl<'a, Query, Mutation, Subscription> Stream for SubscriptionStream<Query, Mutation, Subscription>
impl<S, F, Query, Mutation, Subscription> Stream for WebSocket<S, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
F: FnOnce(serde_json::Value) -> FieldResult<Data>,
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.project();
loop {
// receive bytes
if let Some(ctx) = &mut this.ctx {
if let Some(bytes) = ctx.send_buf.pop_front() {
return Poll::Ready(Some(bytes));
}
}
match this.stream.poll_next(cx) {
Poll::Ready(message) => {
let message = match message {
Some(message) => message,
None => return Poll::Ready(None),
};
if let Some(handle_request_fut) = &mut this.handle_request_fut {
match handle_request_fut.as_mut().poll(cx) {
Poll::Ready(Ok(ctx)) => {
this.ctx = Some(ctx);
this.handle_request_fut = None;
continue;
let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
Ok(message) => message,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: ConnectionError {
message: e.to_string(),
extensions: None,
},
})
.unwrap(),
))
}
Poll::Ready(Err(_)) => return Poll::Ready(None),
Poll::Pending => {}
}
} else {
match Pin::new(&mut this.rx_bytes).poll_next(cx) {
Poll::Ready(Some(data)) => {
let ctx = this.ctx.take().unwrap();
this.handle_request_fut = Some(Box::pin(handle_request(
this.schema.clone(),
this.initializer.clone(),
ctx,
data,
)));
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
}
};
// receive msg
if let Some(ctx) = &mut this.ctx {
let mut closed = Vec::new();
for (id, incoming_stream) in ctx.streams.iter_mut() {
match incoming_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(res)) => {
if let Some(err) = &res.error {
closed.push(id.to_string());
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "error",
id: Some(id.to_string()),
payload: Some(err),
},
);
} else {
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "data",
id: Some(id.to_string()),
payload: Some(&res),
},
);
match message {
ClientMessage::ConnectionInit { payload } => {
if let Some(payload) = payload {
if let Some(data_initializer) = this.data_initializer.take() {
*this.data = Arc::new(match data_initializer(payload) {
Ok(data) => data,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(
&ServerMessage::ConnectionError {
payload: ConnectionError {
message: e.0,
extensions: e.1,
},
},
)
.unwrap(),
))
}
});
}
}
Poll::Ready(None) => {
closed.push(id.to_string());
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "complete",
id: Some(id.to_string()),
payload: Option::<serde_json::Value>::None,
},
);
}
Poll::Pending => {}
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
));
}
}
for id in closed {
ctx.streams.remove(&id);
}
if !ctx.send_buf.is_empty() {
continue;
}
}
return Poll::Pending;
}
}
}
async fn handle_request<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
mut ctx: WSContext,
data: String,
) -> FieldResult<WSContext>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
match serde_json::from_str::<OperationMessage<serde_json::Value>>(&data) {
Ok(msg) => match msg.ty {
"connection_init" => {
if let Some(payload) = msg.payload {
ctx.ctx_data = Arc::new(initializer(payload)?);
}
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "connection_ack",
id: None,
payload: Option::<serde_json::Value>::None,
},
);
}
"start" => {
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
if let Ok(request) = serde_json::from_value::<Request>(payload) {
let stream = schema
.execute_stream_with_ctx_data(request, ctx.ctx_data.clone())
.boxed();
ctx.streams.insert(id, stream);
}
}
}
"stop" => {
if let Some(id) = msg.id {
if ctx.streams.remove(&id).is_some() {
send_message(
&mut ctx.send_buf,
&OperationMessage {
ty: "complete",
id: Some(id),
payload: Option::<serde_json::Value>::None,
},
ClientMessage::Start {
id,
payload: request,
} => {
this.streams.insert(
id,
Box::pin(
this.schema
.execute_stream_with_ctx_data(request, Arc::clone(this.data)),
),
);
}
ClientMessage::Stop { id } => {
if this.streams.remove(id).is_some() {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
));
}
}
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
}
}
"connection_terminate" => return Err("connection_terminate".into()),
_ => return Err("Unknown op".into()),
},
Err(err) => return Err(err.into()),
}
Poll::Pending => {}
}
Ok(ctx)
for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Data {
id,
payload: Box::new(payload),
})
.unwrap(),
));
}
Poll::Ready(None) => {
let id = id.clone();
this.streams.remove(&id);
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
));
}
Poll::Pending => {}
}
}
Poll::Pending
}
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage<'a> {
ConnectionInit { payload: Option<serde_json::Value> },
Start { id: String, payload: Request },
Stop { id: &'a str },
ConnectionTerminate,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
ConnectionError { payload: ConnectionError },
ConnectionAck,
Data { id: &'a str, payload: Box<Response> },
// Not used by this library, as it's not necessary to send
// Error {
// id: &'a str,
// payload: serde_json::Value,
// },
Complete { id: &'a str },
// Not used by this library
// #[serde(rename = "ka")]
// KeepAlive
}
#[derive(Serialize)]
struct ConnectionError {
message: String,
extensions: Option<serde_json::Value>,
}

View File

@ -436,9 +436,9 @@ where
pub(crate) fn execute_stream_with_ctx_data(
&self,
request: impl Into<Request>,
request: impl Into<Request> + Send,
ctx_data: Arc<Data>,
) -> impl Stream<Item = Response> {
) -> impl Stream<Item = Response> + Send {
let schema = self.clone();
async_stream::stream! {
@ -495,7 +495,10 @@ where
}
/// Execute an GraphQL subscription.
pub fn execute_stream(&self, request: impl Into<Request>) -> impl Stream<Item = Response> {
pub fn execute_stream(
&self,
request: impl Into<Request>,
) -> impl Stream<Item = Response> + Send {
let mut request = request.into();
let ctx_data = std::mem::take(&mut request.data);
self.execute_stream_with_ctx_data(request, Arc::new(ctx_data))

View File

@ -1,4 +1,5 @@
use async_graphql::*;
use futures::channel::mpsc;
use futures::{SinkExt, Stream, StreamExt};
#[async_std::test]
@ -18,56 +19,55 @@ pub async fn test_subscription_ws_transport() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new(&schema);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "connection_ack",
}),
);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
}),
);
}
assert_eq!(
Some(serde_json::json!({
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "complete",
"id": "1",
}),
);
}
@ -93,56 +93,59 @@ pub async fn test_subscription_ws_transport_with_token() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new_with_initializer(&schema, |value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
}
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::with_data(
schema,
rx,
Some(|value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
});
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}),
);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
@ -150,8 +153,8 @@ pub async fn test_subscription_ws_transport_with_token() {
assert_eq!(
Some(serde_json::json!({
"type": "complete",
"id": "1",
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
@ -189,45 +192,44 @@ pub async fn test_subscription_ws_transport_error() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = http::WebSocketStream::new(&schema);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "events": { "value": i } } },
"type": "data",
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
@ -235,13 +237,15 @@ pub async fn test_subscription_ws_transport_error() {
assert_eq!(
Some(serde_json::json!({
"type": "error",
"id": "1",
"payload": [{
"message": "TestError",
"locations": [{"line": 1, "column": 25}],
"path": ["events", "value"],
}],
"type": "data",
"id": "1",
"payload": {
"errors": [{
"message": "TestError",
"locations": [{"line": 1, "column": 25}],
"path": ["events", "value"],
}],
},
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
@ -259,17 +263,17 @@ pub async fn test_query_over_websocket() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let mut stream = http::WebSocketStream::new(&schema);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
@ -278,19 +282,18 @@ pub async fn test_query_over_websocket() {
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
tx.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({