Improve websockets
This commit is contained in:
parent
92de2e5816
commit
890b282dd9
|
@ -16,8 +16,8 @@ categories = ["network-programming", "asynchronous"]
|
|||
async-graphql = { path = "../..", version = "2.0.0-alpha.8" }
|
||||
actix-web = "3.0.0"
|
||||
actix-web-actors = "3.0.0"
|
||||
actix-http = "2.0.0"
|
||||
actix = "0.10.0"
|
||||
actix-rt = "1.1.0"
|
||||
futures = "0.3.0"
|
||||
bytes = "0.5.4"
|
||||
futures = "0.3.5"
|
||||
serde_json = "1.0.48"
|
||||
|
|
|
@ -1,11 +1,13 @@
|
|||
use actix::{
|
||||
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
|
||||
Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner,
|
||||
StreamHandler, WrapFuture, WrapStream,
|
||||
};
|
||||
use actix_http::ws;
|
||||
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
|
||||
use async_graphql::http::WebSocketStream;
|
||||
use async_graphql::http::WebSocket;
|
||||
use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType};
|
||||
use futures::stream::SplitSink;
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use futures::channel::mpsc;
|
||||
use futures::SinkExt;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
@ -13,10 +15,11 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
|
|||
|
||||
/// Actor for subscription via websocket
|
||||
pub struct WSSubscription<Query, Mutation, Subscription> {
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
hb: Instant,
|
||||
sink: Option<SplitSink<WebSocketStream, String>>,
|
||||
initializer: Option<Box<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
|
||||
schema: Option<Schema<Query, Mutation, Subscription>>,
|
||||
last_heartbeat: Instant,
|
||||
messages: Option<mpsc::UnboundedSender<Vec<u8>>>,
|
||||
initializer: Option<Box<dyn FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
|
||||
continuation: Vec<u8>,
|
||||
}
|
||||
|
||||
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
|
||||
|
@ -26,19 +29,20 @@ where
|
|||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
/// Create an actor for subscription connection via websocket.
|
||||
pub fn new(schema: &Schema<Query, Mutation, Subscription>) -> Self {
|
||||
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
|
||||
Self {
|
||||
schema: schema.clone(),
|
||||
hb: Instant::now(),
|
||||
sink: None,
|
||||
schema: Some(schema),
|
||||
last_heartbeat: Instant::now(),
|
||||
messages: None,
|
||||
initializer: None,
|
||||
continuation: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set a context data initialization function.
|
||||
pub fn initializer<F>(self, f: F) -> Self
|
||||
where
|
||||
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
|
||||
F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
initializer: Some(Box::new(f)),
|
||||
|
@ -46,9 +50,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
|
||||
fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
|
||||
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
|
||||
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
|
||||
if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
|
||||
ctx.stop();
|
||||
}
|
||||
ctx.ping(b"");
|
||||
|
@ -65,20 +69,19 @@ where
|
|||
type Context = WebsocketContext<Self>;
|
||||
|
||||
fn started(&mut self, ctx: &mut Self::Context) {
|
||||
self.hb(ctx);
|
||||
if let Some(initializer) = self.initializer.take() {
|
||||
let (sink, stream) = async_graphql::http::WebSocketStream::new_with_initializer(
|
||||
&self.schema,
|
||||
initializer,
|
||||
)
|
||||
.split();
|
||||
ctx.add_stream(stream);
|
||||
self.sink = Some(sink);
|
||||
} else {
|
||||
let (sink, stream) = async_graphql::http::WebSocketStream::new(&self.schema).split();
|
||||
ctx.add_stream(stream);
|
||||
self.sink = Some(sink);
|
||||
};
|
||||
self.send_heartbeats(ctx);
|
||||
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
|
||||
WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take())
|
||||
.into_actor(self)
|
||||
.map(|response, _act, ctx| {
|
||||
ctx.text(response);
|
||||
})
|
||||
.finish()
|
||||
.spawn(ctx);
|
||||
|
||||
self.messages = Some(tx);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -98,47 +101,49 @@ where
|
|||
Ok(msg) => msg,
|
||||
};
|
||||
|
||||
match msg {
|
||||
let message = match msg {
|
||||
Message::Ping(msg) => {
|
||||
self.hb = Instant::now();
|
||||
self.last_heartbeat = Instant::now();
|
||||
ctx.pong(&msg);
|
||||
None
|
||||
}
|
||||
Message::Pong(_) => {
|
||||
self.hb = Instant::now();
|
||||
self.last_heartbeat = Instant::now();
|
||||
None
|
||||
}
|
||||
Message::Text(s) => {
|
||||
if let Some(mut sink) = self.sink.take() {
|
||||
async move {
|
||||
let res = sink.send(s).await;
|
||||
res.map(|_| sink)
|
||||
}
|
||||
.into_actor(self)
|
||||
.then(|res, actor, ctx| {
|
||||
match res {
|
||||
Ok(sink) => actor.sink = Some(sink),
|
||||
Err(_) => ctx.stop(),
|
||||
}
|
||||
async {}.into_actor(actor)
|
||||
})
|
||||
.wait(ctx);
|
||||
Message::Continuation(item) => match item {
|
||||
ws::Item::FirstText(bytes) | ws::Item::FirstBinary(bytes) => {
|
||||
self.continuation = bytes.to_vec();
|
||||
None
|
||||
}
|
||||
}
|
||||
Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => {
|
||||
ws::Item::Continue(bytes) => {
|
||||
self.continuation.extend_from_slice(&bytes);
|
||||
None
|
||||
}
|
||||
ws::Item::Last(bytes) => {
|
||||
self.continuation.extend_from_slice(&bytes);
|
||||
Some(std::mem::take(&mut self.continuation))
|
||||
}
|
||||
},
|
||||
Message::Text(s) => Some(s.into_bytes()),
|
||||
Message::Binary(bytes) => Some(bytes.to_vec()),
|
||||
Message::Close(_) => {
|
||||
ctx.stop();
|
||||
None
|
||||
}
|
||||
Message::Nop => {}
|
||||
Message::Nop => None,
|
||||
};
|
||||
|
||||
if let Some(message) = message {
|
||||
let mut sender = self.messages.as_ref().unwrap().clone();
|
||||
|
||||
async move { sender.send(message).await }
|
||||
.into_actor(self)
|
||||
.map(|res, _actor, ctx| match res {
|
||||
Ok(()) => {}
|
||||
Err(_) => ctx.stop(),
|
||||
})
|
||||
.spawn(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Query, Mutation, Subscription> StreamHandler<String>
|
||||
for WSSubscription<Query, Mutation, Subscription>
|
||||
where
|
||||
Query: ObjectType + Send + Sync + 'static,
|
||||
Mutation: ObjectType + Send + Sync + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
fn handle(&mut self, data: String, ctx: &mut Self::Context) {
|
||||
ctx.text(data);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
//! Async-graphql integration with Wrap
|
||||
//! Async-graphql integration with Warp
|
||||
|
||||
#![warn(missing_docs)]
|
||||
#![allow(clippy::type_complexity)]
|
||||
|
@ -9,13 +9,11 @@ use async_graphql::http::MultipartOptions;
|
|||
use async_graphql::{
|
||||
resolver_utils::ObjectType, Data, FieldResult, Request, Schema, SubscriptionType,
|
||||
};
|
||||
use futures::io::ErrorKind;
|
||||
use futures::{select, TryStreamExt};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use futures::{future, StreamExt, TryStreamExt};
|
||||
use hyper::Method;
|
||||
use std::io;
|
||||
use std::io::{self, ErrorKind};
|
||||
use std::sync::Arc;
|
||||
use warp::filters::ws::Message;
|
||||
use warp::filters::ws;
|
||||
use warp::filters::BoxedFilter;
|
||||
use warp::reject::Reject;
|
||||
use warp::reply::Response;
|
||||
|
@ -169,81 +167,57 @@ where
|
|||
/// ```
|
||||
pub fn graphql_subscription<Query, Mutation, Subscription>(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
) -> BoxedFilter<(impl Reply,)>
|
||||
) -> impl Filter<Extract = (impl Reply,)> + Clone
|
||||
where
|
||||
Query: ObjectType + Sync + Send + 'static,
|
||||
Mutation: ObjectType + Sync + Send + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
graphql_subscription_with_initializer(schema, |_| Ok(Default::default()))
|
||||
graphql_subscription_with_data::<_, _, _, fn(serde_json::Value) -> FieldResult<Data>>(
|
||||
schema, None,
|
||||
)
|
||||
}
|
||||
|
||||
/// GraphQL subscription filter
|
||||
///
|
||||
/// Specifies that a function converts the init payload to data.
|
||||
pub fn graphql_subscription_with_initializer<Query, Mutation, Subscription, F>(
|
||||
pub fn graphql_subscription_with_data<Query, Mutation, Subscription, F>(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
initializer: F,
|
||||
) -> BoxedFilter<(impl Reply,)>
|
||||
initializer: Option<F>,
|
||||
) -> impl Filter<Extract = (impl Reply,)> + Clone
|
||||
where
|
||||
Query: ObjectType + Sync + Send + 'static,
|
||||
Mutation: ObjectType + Sync + Send + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
F: Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static,
|
||||
F: FnOnce(serde_json::Value) -> FieldResult<Data> + Send + Sync + Clone + 'static,
|
||||
{
|
||||
warp::any()
|
||||
.and(warp::ws())
|
||||
.and(warp::any().map(move || schema.clone()))
|
||||
.and(warp::any().map(move || initializer.clone()))
|
||||
.map(
|
||||
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: F| {
|
||||
|ws: ws::Ws, schema: Schema<Query, Mutation, Subscription>, initializer: Option<F>| {
|
||||
ws.on_upgrade(move |websocket| {
|
||||
let (mut tx, rx) = websocket.split();
|
||||
let (mut stx, srx) =
|
||||
async_graphql::http::WebSocketStream::new_with_initializer(
|
||||
&schema,
|
||||
initializer,
|
||||
)
|
||||
.split();
|
||||
|
||||
let mut rx = rx.fuse();
|
||||
let srx = srx.fuse();
|
||||
let (ws_sender, ws_receiver) = websocket.split();
|
||||
|
||||
async move {
|
||||
futures::pin_mut!(srx);
|
||||
|
||||
loop {
|
||||
select! {
|
||||
bytes = srx.next() => {
|
||||
if let Some(text) = bytes {
|
||||
if tx.send(Message::text(text)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
msg = rx.next() => {
|
||||
if let Some(Ok(msg)) = msg {
|
||||
if msg.is_text() {
|
||||
if let Ok(text) = String::from_utf8(msg.into_bytes()) {
|
||||
if stx.send(text).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = async_graphql::http::WebSocket::with_data(
|
||||
schema,
|
||||
ws_receiver
|
||||
.take_while(|msg| future::ready(msg.is_ok()))
|
||||
.map(Result::unwrap)
|
||||
.map(ws::Message::into_bytes),
|
||||
initializer,
|
||||
)
|
||||
.map(ws::Message::text)
|
||||
.map(Ok)
|
||||
.forward(ws_sender)
|
||||
.await;
|
||||
}
|
||||
})
|
||||
},
|
||||
)
|
||||
.map(|reply| warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws"))
|
||||
.boxed()
|
||||
}
|
||||
|
||||
/// GraphQL reply
|
||||
|
|
|
@ -10,7 +10,7 @@ pub use graphiql_source::graphiql_source;
|
|||
#[cfg(feature = "multipart")]
|
||||
pub use multipart::{receive_multipart, MultipartOptions};
|
||||
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
|
||||
pub use websocket::WebSocketStream;
|
||||
pub use websocket::WebSocket;
|
||||
|
||||
use crate::{ParseRequestError, Request};
|
||||
use futures::io::AsyncRead;
|
||||
|
|
|
@ -2,291 +2,207 @@
|
|||
|
||||
use crate::resolver_utils::ObjectType;
|
||||
use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType};
|
||||
use futures::channel::mpsc;
|
||||
use futures::task::{Context, Poll};
|
||||
use futures::{Future, Sink, SinkExt, Stream, StreamExt};
|
||||
use futures::Stream;
|
||||
use pin_project_lite::pin_project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct OperationMessage<'a, T> {
|
||||
#[serde(rename = "type")]
|
||||
ty: &'a str,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
payload: Option<T>,
|
||||
}
|
||||
|
||||
type SubscriptionStreams = HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>;
|
||||
|
||||
type HandleRequestBoxFut = Pin<Box<dyn Future<Output = FieldResult<WSContext>> + Send>>;
|
||||
|
||||
type InitializerFn = Arc<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>;
|
||||
|
||||
/// A wrapper around an underlying raw stream which implements the WebSocket protocol.
|
||||
///
|
||||
/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function
|
||||
/// to splits this object into separate Sink and Stream objects.
|
||||
pub struct WebSocketStream {
|
||||
tx: mpsc::UnboundedSender<String>,
|
||||
rx: Pin<Box<dyn Stream<Item = String> + Send>>,
|
||||
}
|
||||
|
||||
impl Sink<String> for WebSocketStream {
|
||||
type Error = mpsc::SendError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tx.poll_ready_unpin(cx)
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
|
||||
self.tx.start_send(item)
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tx.poll_flush_unpin(cx)
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
self.tx.poll_close_unpin(cx)
|
||||
pin_project! {
|
||||
/// A GraphQL connection over websocket.
|
||||
///
|
||||
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
|
||||
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
|
||||
data_initializer: Option<F>,
|
||||
data: Arc<Data>,
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
|
||||
#[pin]
|
||||
stream: S,
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WebSocketStream {
|
||||
type Item = String;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
self.rx.poll_next_unpin(cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl WebSocketStream {
|
||||
/// Create a websocket transport.
|
||||
pub fn new<Query, Mutation, Subscription>(
|
||||
schema: &Schema<Query, Mutation, Subscription>,
|
||||
) -> Self
|
||||
where
|
||||
Query: ObjectType + Send + Sync + 'static,
|
||||
Mutation: ObjectType + Send + Sync + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
Self::new_with_initializer(schema, |_| Ok(Default::default()))
|
||||
}
|
||||
|
||||
/// Create a websocket transport and specify a context initialization function.
|
||||
pub fn new_with_initializer<Query, Mutation, Subscription>(
|
||||
schema: &Schema<Query, Mutation, Subscription>,
|
||||
initializer: impl Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
|
||||
) -> Self
|
||||
where
|
||||
Query: ObjectType + Send + Sync + 'static,
|
||||
Mutation: ObjectType + Send + Sync + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
let (tx, rx) = mpsc::unbounded();
|
||||
WebSocketStream {
|
||||
tx,
|
||||
rx: SubscriptionStream {
|
||||
schema: schema.clone(),
|
||||
initializer: Arc::new(initializer),
|
||||
rx_bytes: rx,
|
||||
handle_request_fut: None,
|
||||
ctx: Some(WSContext {
|
||||
streams: Default::default(),
|
||||
send_buf: Default::default(),
|
||||
ctx_data: Arc::new(Data::default()),
|
||||
}),
|
||||
}
|
||||
.boxed(),
|
||||
impl<S, Query, Mutation, Subscription>
|
||||
WebSocket<S, fn(serde_json::Value) -> FieldResult<Data>, Query, Mutation, Subscription>
|
||||
{
|
||||
/// Create a new websocket.
|
||||
#[must_use]
|
||||
pub fn new(schema: Schema<Query, Mutation, Subscription>, stream: S) -> Self {
|
||||
Self {
|
||||
data_initializer: None,
|
||||
data: Arc::default(),
|
||||
schema,
|
||||
streams: HashMap::new(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct WSContext {
|
||||
streams: SubscriptionStreams,
|
||||
send_buf: VecDeque<String>,
|
||||
ctx_data: Arc<Data>,
|
||||
}
|
||||
|
||||
fn send_message<T: Serialize>(send_buf: &mut VecDeque<String>, msg: &T) {
|
||||
if let Ok(data) = serde_json::to_string(msg) {
|
||||
send_buf.push_back(data);
|
||||
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription> {
|
||||
/// Create a new websocket with a data initialization function.
|
||||
///
|
||||
/// This function, if present, will be called with the data sent by the client in the
|
||||
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
|
||||
/// From that point on the returned data will be accessible to all requests.
|
||||
#[must_use]
|
||||
pub fn with_data(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
stream: S,
|
||||
data_initializer: Option<F>,
|
||||
) -> Self {
|
||||
Self {
|
||||
data_initializer,
|
||||
data: Arc::default(),
|
||||
schema,
|
||||
streams: HashMap::new(),
|
||||
stream,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
struct SubscriptionStream<Query, Mutation, Subscription> {
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
initializer: InitializerFn,
|
||||
rx_bytes: mpsc::UnboundedReceiver<String>,
|
||||
handle_request_fut: Option<HandleRequestBoxFut>,
|
||||
ctx: Option<WSContext>,
|
||||
}
|
||||
|
||||
impl<'a, Query, Mutation, Subscription> Stream for SubscriptionStream<Query, Mutation, Subscription>
|
||||
impl<S, F, Query, Mutation, Subscription> Stream for WebSocket<S, F, Query, Mutation, Subscription>
|
||||
where
|
||||
S: Stream,
|
||||
S::Item: AsRef<[u8]>,
|
||||
F: FnOnce(serde_json::Value) -> FieldResult<Data>,
|
||||
Query: ObjectType + Send + Sync + 'static,
|
||||
Mutation: ObjectType + Send + Sync + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
type Item = String;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
let this = &mut *self;
|
||||
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
|
||||
let this = self.project();
|
||||
|
||||
loop {
|
||||
// receive bytes
|
||||
if let Some(ctx) = &mut this.ctx {
|
||||
if let Some(bytes) = ctx.send_buf.pop_front() {
|
||||
return Poll::Ready(Some(bytes));
|
||||
}
|
||||
}
|
||||
match this.stream.poll_next(cx) {
|
||||
Poll::Ready(message) => {
|
||||
let message = match message {
|
||||
Some(message) => message,
|
||||
None => return Poll::Ready(None),
|
||||
};
|
||||
|
||||
if let Some(handle_request_fut) = &mut this.handle_request_fut {
|
||||
match handle_request_fut.as_mut().poll(cx) {
|
||||
Poll::Ready(Ok(ctx)) => {
|
||||
this.ctx = Some(ctx);
|
||||
this.handle_request_fut = None;
|
||||
continue;
|
||||
let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
|
||||
Ok(message) => message,
|
||||
Err(e) => {
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(&ServerMessage::ConnectionError {
|
||||
payload: ConnectionError {
|
||||
message: e.to_string(),
|
||||
extensions: None,
|
||||
},
|
||||
})
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
Poll::Ready(Err(_)) => return Poll::Ready(None),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
} else {
|
||||
match Pin::new(&mut this.rx_bytes).poll_next(cx) {
|
||||
Poll::Ready(Some(data)) => {
|
||||
let ctx = this.ctx.take().unwrap();
|
||||
this.handle_request_fut = Some(Box::pin(handle_request(
|
||||
this.schema.clone(),
|
||||
this.initializer.clone(),
|
||||
ctx,
|
||||
data,
|
||||
)));
|
||||
continue;
|
||||
}
|
||||
Poll::Ready(None) => return Poll::Ready(None),
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// receive msg
|
||||
if let Some(ctx) = &mut this.ctx {
|
||||
let mut closed = Vec::new();
|
||||
|
||||
for (id, incoming_stream) in ctx.streams.iter_mut() {
|
||||
match incoming_stream.as_mut().poll_next(cx) {
|
||||
Poll::Ready(Some(res)) => {
|
||||
if let Some(err) = &res.error {
|
||||
closed.push(id.to_string());
|
||||
send_message(
|
||||
&mut ctx.send_buf,
|
||||
&OperationMessage {
|
||||
ty: "error",
|
||||
id: Some(id.to_string()),
|
||||
payload: Some(err),
|
||||
},
|
||||
);
|
||||
} else {
|
||||
send_message(
|
||||
&mut ctx.send_buf,
|
||||
&OperationMessage {
|
||||
ty: "data",
|
||||
id: Some(id.to_string()),
|
||||
payload: Some(&res),
|
||||
},
|
||||
);
|
||||
match message {
|
||||
ClientMessage::ConnectionInit { payload } => {
|
||||
if let Some(payload) = payload {
|
||||
if let Some(data_initializer) = this.data_initializer.take() {
|
||||
*this.data = Arc::new(match data_initializer(payload) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(
|
||||
&ServerMessage::ConnectionError {
|
||||
payload: ConnectionError {
|
||||
message: e.0,
|
||||
extensions: e.1,
|
||||
},
|
||||
},
|
||||
)
|
||||
.unwrap(),
|
||||
))
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
closed.push(id.to_string());
|
||||
send_message(
|
||||
&mut ctx.send_buf,
|
||||
&OperationMessage {
|
||||
ty: "complete",
|
||||
id: Some(id.to_string()),
|
||||
payload: Option::<serde_json::Value>::None,
|
||||
},
|
||||
);
|
||||
}
|
||||
Poll::Pending => {}
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
for id in closed {
|
||||
ctx.streams.remove(&id);
|
||||
}
|
||||
|
||||
if !ctx.send_buf.is_empty() {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return Poll::Pending;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_request<Query, Mutation, Subscription>(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
initializer: InitializerFn,
|
||||
mut ctx: WSContext,
|
||||
data: String,
|
||||
) -> FieldResult<WSContext>
|
||||
where
|
||||
Query: ObjectType + Send + Sync + 'static,
|
||||
Mutation: ObjectType + Send + Sync + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
match serde_json::from_str::<OperationMessage<serde_json::Value>>(&data) {
|
||||
Ok(msg) => match msg.ty {
|
||||
"connection_init" => {
|
||||
if let Some(payload) = msg.payload {
|
||||
ctx.ctx_data = Arc::new(initializer(payload)?);
|
||||
}
|
||||
send_message(
|
||||
&mut ctx.send_buf,
|
||||
&OperationMessage {
|
||||
ty: "connection_ack",
|
||||
id: None,
|
||||
payload: Option::<serde_json::Value>::None,
|
||||
},
|
||||
);
|
||||
}
|
||||
"start" => {
|
||||
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
|
||||
if let Ok(request) = serde_json::from_value::<Request>(payload) {
|
||||
let stream = schema
|
||||
.execute_stream_with_ctx_data(request, ctx.ctx_data.clone())
|
||||
.boxed();
|
||||
ctx.streams.insert(id, stream);
|
||||
}
|
||||
}
|
||||
}
|
||||
"stop" => {
|
||||
if let Some(id) = msg.id {
|
||||
if ctx.streams.remove(&id).is_some() {
|
||||
send_message(
|
||||
&mut ctx.send_buf,
|
||||
&OperationMessage {
|
||||
ty: "complete",
|
||||
id: Some(id),
|
||||
payload: Option::<serde_json::Value>::None,
|
||||
},
|
||||
ClientMessage::Start {
|
||||
id,
|
||||
payload: request,
|
||||
} => {
|
||||
this.streams.insert(
|
||||
id,
|
||||
Box::pin(
|
||||
this.schema
|
||||
.execute_stream_with_ctx_data(request, Arc::clone(this.data)),
|
||||
),
|
||||
);
|
||||
}
|
||||
ClientMessage::Stop { id } => {
|
||||
if this.streams.remove(id).is_some() {
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
|
||||
}
|
||||
}
|
||||
"connection_terminate" => return Err("connection_terminate".into()),
|
||||
_ => return Err("Unknown op".into()),
|
||||
},
|
||||
Err(err) => return Err(err.into()),
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
|
||||
Ok(ctx)
|
||||
for (id, stream) in &mut *this.streams {
|
||||
match Pin::new(stream).poll_next(cx) {
|
||||
Poll::Ready(Some(payload)) => {
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(&ServerMessage::Data {
|
||||
id,
|
||||
payload: Box::new(payload),
|
||||
})
|
||||
.unwrap(),
|
||||
));
|
||||
}
|
||||
Poll::Ready(None) => {
|
||||
let id = id.clone();
|
||||
this.streams.remove(&id);
|
||||
return Poll::Ready(Some(
|
||||
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
|
||||
));
|
||||
}
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ClientMessage<'a> {
|
||||
ConnectionInit { payload: Option<serde_json::Value> },
|
||||
Start { id: String, payload: Request },
|
||||
Stop { id: &'a str },
|
||||
ConnectionTerminate,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
enum ServerMessage<'a> {
|
||||
ConnectionError { payload: ConnectionError },
|
||||
ConnectionAck,
|
||||
Data { id: &'a str, payload: Box<Response> },
|
||||
// Not used by this library, as it's not necessary to send
|
||||
// Error {
|
||||
// id: &'a str,
|
||||
// payload: serde_json::Value,
|
||||
// },
|
||||
Complete { id: &'a str },
|
||||
// Not used by this library
|
||||
// #[serde(rename = "ka")]
|
||||
// KeepAlive
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ConnectionError {
|
||||
message: String,
|
||||
extensions: Option<serde_json::Value>,
|
||||
}
|
||||
|
|
|
@ -436,9 +436,9 @@ where
|
|||
|
||||
pub(crate) fn execute_stream_with_ctx_data(
|
||||
&self,
|
||||
request: impl Into<Request>,
|
||||
request: impl Into<Request> + Send,
|
||||
ctx_data: Arc<Data>,
|
||||
) -> impl Stream<Item = Response> {
|
||||
) -> impl Stream<Item = Response> + Send {
|
||||
let schema = self.clone();
|
||||
|
||||
async_stream::stream! {
|
||||
|
@ -495,7 +495,10 @@ where
|
|||
}
|
||||
|
||||
/// Execute an GraphQL subscription.
|
||||
pub fn execute_stream(&self, request: impl Into<Request>) -> impl Stream<Item = Response> {
|
||||
pub fn execute_stream(
|
||||
&self,
|
||||
request: impl Into<Request>,
|
||||
) -> impl Stream<Item = Response> + Send {
|
||||
let mut request = request.into();
|
||||
let ctx_data = std::mem::take(&mut request.data);
|
||||
self.execute_stream_with_ctx_data(request, Arc::new(ctx_data))
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use async_graphql::*;
|
||||
use futures::channel::mpsc;
|
||||
use futures::{SinkExt, Stream, StreamExt};
|
||||
|
||||
#[async_std::test]
|
||||
|
@ -18,56 +19,55 @@ pub async fn test_subscription_ws_transport() {
|
|||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
let mut stream = http::WebSocketStream::new(&schema);
|
||||
let (mut tx, rx) = mpsc::unbounded();
|
||||
let mut stream = http::WebSocket::new(schema, rx);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
|
||||
serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
}),
|
||||
);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i } },
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
|
||||
serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i } },
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "complete",
|
||||
"id": "1",
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
|
||||
serde_json::json!({
|
||||
"type": "complete",
|
||||
"id": "1",
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -93,56 +93,59 @@ pub async fn test_subscription_ws_transport_with_token() {
|
|||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
let mut stream = http::WebSocketStream::new_with_initializer(&schema, |value| {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Payload {
|
||||
token: String,
|
||||
}
|
||||
let (mut tx, rx) = mpsc::unbounded();
|
||||
let mut stream = http::WebSocket::with_data(
|
||||
schema,
|
||||
rx,
|
||||
Some(|value| {
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Payload {
|
||||
token: String,
|
||||
}
|
||||
|
||||
let payload: Payload = serde_json::from_value(value).unwrap();
|
||||
let mut data = Data::default();
|
||||
data.insert(Token(payload.token));
|
||||
Ok(data)
|
||||
});
|
||||
let payload: Payload = serde_json::from_value(value).unwrap();
|
||||
let mut data = Data::default();
|
||||
data.insert(Token(payload.token));
|
||||
Ok(data)
|
||||
}),
|
||||
);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
"payload": { "token": "123456" }
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
"payload": { "token": "123456" }
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
"type": "connection_ack",
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i } },
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i } },
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
@ -150,8 +153,8 @@ pub async fn test_subscription_ws_transport_with_token() {
|
|||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "complete",
|
||||
"id": "1",
|
||||
"type": "complete",
|
||||
"id": "1",
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
@ -189,45 +192,44 @@ pub async fn test_subscription_ws_transport_error() {
|
|||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
let mut stream = http::WebSocketStream::new(&schema);
|
||||
let (mut tx, rx) = mpsc::unbounded();
|
||||
let mut stream = http::WebSocket::new(schema, rx);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init"
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init"
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
"type": "connection_ack",
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { events { value } }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { events { value } }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0i32..5 {
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "events": { "value": i } } },
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "events": { "value": i } } },
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
@ -235,13 +237,15 @@ pub async fn test_subscription_ws_transport_error() {
|
|||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "error",
|
||||
"id": "1",
|
||||
"payload": [{
|
||||
"message": "TestError",
|
||||
"locations": [{"line": 1, "column": 25}],
|
||||
"path": ["events", "value"],
|
||||
}],
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"errors": [{
|
||||
"message": "TestError",
|
||||
"locations": [{"line": 1, "column": 25}],
|
||||
"path": ["events", "value"],
|
||||
}],
|
||||
},
|
||||
})),
|
||||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
@ -259,17 +263,17 @@ pub async fn test_query_over_websocket() {
|
|||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
|
||||
let mut stream = http::WebSocketStream::new(&schema);
|
||||
let (mut tx, rx) = mpsc::unbounded();
|
||||
let mut stream = http::WebSocket::new(schema, rx);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
|
@ -278,19 +282,18 @@ pub async fn test_query_over_websocket() {
|
|||
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
||||
stream
|
||||
.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "query { value }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
tx.send(
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "query { value }"
|
||||
},
|
||||
}))
|
||||
.unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
|
|
Loading…
Reference in New Issue
Block a user