http::websocket::create function now returns the WebSocketStream type.

This commit is contained in:
Sunli 2020-09-15 11:12:19 +08:00
parent 4f0ef35d39
commit 86a6a34aac
4 changed files with 191 additions and 137 deletions

View File

@ -2,9 +2,10 @@ use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::websocket::WebSocketStream;
use async_graphql::{resolver_utils::ObjectType, Data, FieldResult, Schema, SubscriptionType};
use futures::channel::mpsc;
use futures::SinkExt;
use futures::stream::SplitSink;
use futures::{SinkExt, StreamExt};
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -14,7 +15,7 @@ const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
sink: Option<mpsc::UnboundedSender<Vec<u8>>>,
sink: Option<SplitSink<WebSocketStream, String>>,
initializer: Option<Box<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>>,
}
@ -67,11 +68,12 @@ where
self.hb(ctx);
if let Some(initializer) = self.initializer.take() {
let (sink, stream) =
async_graphql::http::websocket::create_with_initializer(&self.schema, initializer);
async_graphql::http::websocket::create_with_initializer(&self.schema, initializer)
.split();
ctx.add_stream(stream);
self.sink = Some(sink);
} else {
let (sink, stream) = async_graphql::http::websocket::create(&self.schema);
let (sink, stream) = async_graphql::http::websocket::create(&self.schema).split();
ctx.add_stream(stream);
self.sink = Some(sink);
};
@ -103,11 +105,20 @@ where
self.hb = Instant::now();
}
Message::Text(s) => {
if let Some(mut sink) = self.sink.clone() {
async move { sink.send(s.into()).await }
.into_actor(self)
.then(|_, actor, _| async {}.into_actor(actor))
.wait(ctx);
if let Some(mut sink) = self.sink.take() {
async move {
let res = sink.send(s).await;
res.map(|_| sink)
}
.into_actor(self)
.then(|res, actor, ctx| {
match res {
Ok(sink) => actor.sink = Some(sink),
Err(_) => ctx.stop(),
}
async {}.into_actor(actor)
})
.wait(ctx);
}
}
Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => {
@ -118,16 +129,14 @@ where
}
}
impl<Query, Mutation, Subscription> StreamHandler<Vec<u8>>
impl<Query, Mutation, Subscription> StreamHandler<String>
for WSSubscription<Query, Mutation, Subscription>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
fn handle(&mut self, data: Vec<u8>, ctx: &mut Self::Context) {
if let Ok(text) = String::from_utf8(data) {
ctx.text(text);
}
fn handle(&mut self, data: String, ctx: &mut Self::Context) {
ctx.text(data);
}
}

View File

@ -202,7 +202,8 @@ where
let (mut stx, srx) = async_graphql::http::websocket::create_with_initializer(
&schema,
initializer,
);
)
.split();
let mut rx = rx.fuse();
let srx = srx.fuse();
@ -213,11 +214,9 @@ where
loop {
select! {
bytes = srx.next() => {
if let Some(bytes) = bytes {
if let Ok(text) = String::from_utf8(bytes) {
if tx.send(Message::text(text)).await.is_err() {
return;
}
if let Some(text) = bytes {
if tx.send(Message::text(text)).await.is_err() {
return;
}
} else {
return;
@ -226,8 +225,10 @@ where
msg = rx.next() => {
if let Some(Ok(msg)) = msg {
if msg.is_text() {
if stx.send(msg.into_bytes()).await.is_err() {
return;
if let Ok(text) = String::from_utf8(msg.into_bytes()) {
if stx.send(text).await.is_err() {
return;
}
}
}
} else {

View File

@ -4,7 +4,7 @@ use crate::resolver_utils::ObjectType;
use crate::{Data, FieldResult, Request, Response, Schema, SubscriptionType};
use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::{Future, Stream, StreamExt};
use futures::{Future, Sink, SinkExt, Stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
@ -29,10 +29,47 @@ type HandleRequestBoxFut<'a> =
type InitializerFn = Arc<dyn Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync>;
/// A wrapper around an underlying raw stream which implements the WebSocket protocol.
///
/// Only Text messages can be transmitted. You can use `futures::stream::StreamExt::split` function
/// to splits this object into separate Sink and Stream objects.
pub struct WebSocketStream {
tx: mpsc::UnboundedSender<String>,
rx: Pin<Box<dyn Stream<Item = String> + Send>>,
}
impl Sink<String> for WebSocketStream {
type Error = mpsc::SendError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_ready_unpin(cx)
}
fn start_send(mut self: Pin<&mut Self>, item: String) -> Result<(), Self::Error> {
self.tx.start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_flush_unpin(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.tx.poll_close_unpin(cx)
}
}
impl Stream for WebSocketStream {
type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_next_unpin(cx)
}
}
/// Create a websocket transport.
pub fn create<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
) -> (mpsc::UnboundedSender<Vec<u8>>, impl Stream<Item = Vec<u8>>)
) -> WebSocketStream
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
@ -45,14 +82,14 @@ where
pub fn create_with_initializer<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
initializer: impl Fn(serde_json::Value) -> FieldResult<Data> + Send + Sync + 'static,
) -> (mpsc::UnboundedSender<Vec<u8>>, impl Stream<Item = Vec<u8>>)
) -> WebSocketStream
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
let schema = schema.clone();
let (tx_bytes, rx_bytes) = mpsc::unbounded();
let (tx, rx) = mpsc::unbounded();
let stream = async_stream::stream! {
let mut streams = Default::default();
let mut send_buf = Default::default();
@ -60,7 +97,7 @@ where
let mut inner_stream = SubscriptionStream {
schema: &schema,
initializer: Arc::new(initializer),
rx_bytes,
rx_bytes: rx,
handle_request_fut: None,
ctx: Some(WSContext {
streams: &mut streams,
@ -72,17 +109,20 @@ where
yield data;
}
};
(tx_bytes, stream)
WebSocketStream {
tx,
rx: Box::pin(stream),
}
}
struct WSContext<'a> {
streams: &'a mut SubscriptionStreams,
send_buf: &'a mut VecDeque<Vec<u8>>,
send_buf: &'a mut VecDeque<String>,
ctx_data: &'a mut Arc<Data>,
}
fn send_message<T: Serialize>(send_buf: &mut VecDeque<Vec<u8>>, msg: &T) {
if let Ok(data) = serde_json::to_vec(msg) {
fn send_message<T: Serialize>(send_buf: &mut VecDeque<String>, msg: &T) {
if let Ok(data) = serde_json::to_string(msg) {
send_buf.push_back(data);
}
}
@ -92,7 +132,7 @@ fn send_message<T: Serialize>(send_buf: &mut VecDeque<Vec<u8>>, msg: &T) {
struct SubscriptionStream<'a, Query, Mutation, Subscription> {
schema: &'a Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
rx_bytes: mpsc::UnboundedReceiver<Vec<u8>>,
rx_bytes: mpsc::UnboundedReceiver<String>,
handle_request_fut: Option<HandleRequestBoxFut<'a>>,
ctx: Option<WSContext<'a>>,
}
@ -104,7 +144,7 @@ where
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
type Item = Vec<u8>;
type Item = String;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let this = &mut *self;
@ -205,14 +245,14 @@ async fn handle_request<'a, Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
initializer: InitializerFn,
ctx: WSContext<'a>,
data: Vec<u8>,
data: String,
) -> FieldResult<WSContext<'a>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
match serde_json::from_slice::<OperationMessage<serde_json::Value>>(&data) {
match serde_json::from_str::<OperationMessage<serde_json::Value>>(&data) {
Ok(msg) => match msg.ty {
"connection_init" => {
if let Some(payload) = msg.payload {

View File

@ -18,37 +18,38 @@ pub async fn test_subscription_ws_transport() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = http::websocket::create(&schema);
futures::pin_mut!(stream);
let mut stream = http::websocket::create(&schema);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
@ -57,7 +58,7 @@ pub async fn test_subscription_ws_transport() {
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -66,7 +67,7 @@ pub async fn test_subscription_ws_transport() {
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -92,7 +93,7 @@ pub async fn test_subscription_ws_transport_with_token() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = http::websocket::create_with_initializer(&schema, |value| {
let mut stream = http::websocket::create_with_initializer(&schema, |value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
@ -103,37 +104,38 @@ pub async fn test_subscription_ws_transport_with_token() {
data.insert(Token(payload.token));
Ok(data)
});
futures::pin_mut!(stream);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
@ -142,7 +144,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -151,7 +153,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -187,37 +189,38 @@ pub async fn test_subscription_ws_transport_error() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = http::websocket::create(&schema);
futures::pin_mut!(stream);
let mut stream = http::websocket::create(&schema);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
@ -226,7 +229,7 @@ pub async fn test_subscription_ws_transport_error() {
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -240,7 +243,7 @@ pub async fn test_subscription_ws_transport_error() {
"path": ["events", "value"],
}],
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
@ -256,37 +259,38 @@ pub async fn test_query_over_websocket() {
}
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let (mut sink, stream) = http::websocket::create(&schema);
futures::pin_mut!(stream);
let mut stream = http::websocket::create(&schema);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
stream
.send(
serde_json::to_string(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
@ -294,7 +298,7 @@ pub async fn test_query_over_websocket() {
"id": "1",
"payload": { "data": { "value": 999 } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
assert_eq!(
@ -302,6 +306,6 @@ pub async fn test_query_over_websocket() {
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}