Fixed GraphQL over WebSocket Protocol does not support ConnectionError events. #406

This commit is contained in:
Sunli 2021-02-06 11:02:41 +08:00
parent d193a138e3
commit 2d7627cd39
8 changed files with 313 additions and 59 deletions

17
CHANGELOG.md Normal file
View File

@ -0,0 +1,17 @@
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [2.5.2] - 2021-02-06
### Added
- Add subscription support for tide with tide-websockets.
### Fixed
- Fixed the bug that can accept subscription requests during the initialization of WebSocket.
- GraphQL over WebSocket Protocol
- Fixed GraphQL over WebSocket Protocol does not support ConnectionError events. [#406](https://github.com/async-graphql/async-graphql/issues/406)

View File

@ -10,8 +10,8 @@ use actix_http::error::PayloadError;
use actix_http::{ws, Error};
use actix_web::web::Bytes;
use actix_web::{HttpRequest, HttpResponse};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::{WebSocket, WebSocketProtocols};
use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext};
use async_graphql::http::{WebSocket, WebSocketProtocols, WsMessage};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::future::Ready;
use futures_util::stream::Stream;
@ -134,8 +134,12 @@ where
self.protocol,
)
.into_actor(self)
.map(|response, _act, ctx| {
ctx.text(response);
.map(|response, _act, ctx| match response {
WsMessage::Text(text) => ctx.text(text),
WsMessage::Close(code, msg) => ctx.close(Some(CloseReason {
code: code.into(),
description: Some(msg),
})),
})
.finish()
.spawn(ctx);

View File

@ -2,7 +2,7 @@ use std::future::Future;
use std::pin::Pin;
use std::str::FromStr;
use async_graphql::http::{WebSocket as AGWebSocket, WebSocketProtocols};
use async_graphql::http::{WebSocket as AGWebSocket, WebSocketProtocols, WsMessage};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::{future, StreamExt};
use tide::{Endpoint, Request, Response};
@ -83,8 +83,16 @@ where
protocol,
);
while let Some(data) = stream.next().await {
if sink.send_string(data).await.is_err() {
break;
match data {
WsMessage::Text(text) => {
if sink.send_string(text).await.is_err() {
break;
}
}
WsMessage::Close(_code, _msg) => {
// TODO: Send close frame
break;
}
}
}

View File

@ -1,5 +1,6 @@
use std::future::Future;
use async_graphql::http::WsMessage;
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::{future, StreamExt};
use warp::filters::ws;
@ -103,7 +104,10 @@ where
initializer,
protocol,
)
.map(ws::Message::text)
.map(|msg| match msg {
WsMessage::Text(text) => ws::Message::text(text),
WsMessage::Close(code, status) => ws::Message::close_with(code, status),
})
.map(Ok)
.forward(ws_sender)
.await;

View File

@ -14,7 +14,7 @@ pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")]
pub use multipart::MultipartOptions;
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::{Protocols as WebSocketProtocols, WebSocket};
pub use websocket::{Protocols as WebSocketProtocols, WebSocket, WsMessage};
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]

View File

@ -14,6 +14,46 @@ use serde::{Deserialize, Serialize};
use crate::{Data, Error, ObjectType, Request, Response, Result, Schema, SubscriptionType};
/// An enum representing the various forms of a WebSocket message.
#[derive(Clone, Debug)]
pub enum WsMessage {
/// A text WebSocket message
Text(String),
/// A close message with the close frame.
Close(u16, String),
}
impl WsMessage {
/// Returns the contained [`Text`] value, consuming the `self` value.
///
/// Because this function may panic, its use is generally discouraged.
///
/// # Panics
///
/// Panics if the self value not equals [`Text`].
pub fn unwrap_text(self) -> String {
match self {
Self::Text(text) => text,
_ => panic!("Not a text message"),
}
}
/// Returns the contained [`Close`] value, consuming the `self` value.
///
/// Because this function may panic, its use is generally discouraged.
///
/// # Panics
///
/// Panics if the self value not equals [`Close`].
pub fn unwrap_close(self) -> (u16, String) {
match self {
Self::Close(code, msg) => (code, msg),
_ => panic!("Not a close message"),
}
}
}
pin_project! {
/// A GraphQL connection over websocket.
///
@ -40,15 +80,12 @@ impl<S, Query, Mutation, Subscription>
stream: S,
protocol: Protocols,
) -> Self {
Self {
data_initializer: Some(|_| futures_util::future::ready(Ok(Default::default()))),
init_fut: None,
data: Arc::default(),
Self::with_data(
schema,
streams: HashMap::new(),
stream,
|_| futures_util::future::ready(Ok(Default::default())),
protocol,
}
)
}
}
@ -88,7 +125,7 @@ where
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
type Item = String;
type Item = WsMessage;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let mut this = self.project();
@ -102,14 +139,7 @@ where
let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
Ok(message) => message,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(e.to_string()),
})
.unwrap(),
))
}
Err(err) => return Poll::Ready(Some(WsMessage::Close(1002, err.to_string()))),
};
match message {
@ -119,6 +149,25 @@ where
data_initializer(payload.unwrap_or_default()).await
}));
break;
} else {
match this.protocol {
Protocols::SubscriptionsTransportWS => {
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(
"Too many initialisation requests.",
),
})
.unwrap(),
)));
}
Protocols::GraphQLWS => {
return Poll::Ready(Some(WsMessage::Close(
4429,
"Too many initialisation requests.".to_string(),
)));
}
}
}
}
ClientMessage::Start {
@ -135,9 +184,9 @@ where
}
ClientMessage::Stop { id } => {
if this.streams.remove(id).is_some() {
return Poll::Ready(Some(
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
));
)));
}
}
// Note: in the revised `graphql-ws` spec, there is no equivalent to the
@ -154,14 +203,21 @@ where
return match res {
Ok(data) => {
*this.data = Arc::new(data);
Poll::Ready(Some(
Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
))
)))
}
Err(err) => Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError { payload: err })
Err(err) => match this.protocol {
Protocols::SubscriptionsTransportWS => Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(err.message),
})
.unwrap(),
)),
))),
Protocols::GraphQLWS => {
Poll::Ready(Some(WsMessage::Close(1002, err.message)))
}
},
};
}
}
@ -169,16 +225,16 @@ where
for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
));
)));
}
Poll::Ready(None) => {
let id = id.clone();
this.streams.remove(&id);
return Poll::Ready(Some(
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
));
)));
}
Poll::Pending => {}
}

View File

@ -38,7 +38,8 @@ pub async fn test_subscription_ws_transport() {
.unwrap();
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "connection_ack",
}),
@ -59,7 +60,8 @@ pub async fn test_subscription_ws_transport() {
for i in 0..10 {
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "next",
"id": "1",
@ -69,7 +71,8 @@ pub async fn test_subscription_ws_transport() {
}
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "complete",
"id": "1",
@ -135,7 +138,7 @@ pub async fn test_subscription_ws_transport_with_token() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -158,7 +161,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -167,7 +170,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -223,7 +226,7 @@ pub async fn test_subscription_ws_transport_error() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -246,7 +249,7 @@ pub async fn test_subscription_ws_transport_error() {
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -263,7 +266,106 @@ pub async fn test_subscription_ws_transport_error() {
}],
},
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
#[async_std::test]
pub async fn test_subscription_init_error() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = i32> {
futures_util::stream::once(async move { 10 })
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::with_data(
schema,
rx,
|_| async move { Err("Error!".into()) },
WebSocketProtocols::GraphQLWS,
);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
(1002, "Error!".to_string()),
dbg!(stream.next().await.unwrap()).unwrap_close()
);
}
#[async_std::test]
pub async fn test_subscription_too_many_initialisation_requests_error() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = i32> {
futures_util::stream::once(async move { 10 })
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
(4429, "Too many initialisation requests.".to_string()),
stream.next().await.unwrap().unwrap_close()
);
}
@ -295,7 +397,7 @@ pub async fn test_query_over_websocket() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -317,7 +419,7 @@ pub async fn test_query_over_websocket() {
"id": "1",
"payload": { "data": { "value": 999 } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
assert_eq!(
@ -325,6 +427,6 @@ pub async fn test_query_over_websocket() {
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}

View File

@ -38,7 +38,8 @@ pub async fn test_subscription_ws_transport() {
.unwrap();
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "connection_ack",
}),
@ -59,7 +60,8 @@ pub async fn test_subscription_ws_transport() {
for i in 0..10 {
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "data",
"id": "1",
@ -69,7 +71,8 @@ pub async fn test_subscription_ws_transport() {
}
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap().unwrap_text())
.unwrap(),
serde_json::json!({
"type": "complete",
"id": "1",
@ -135,7 +138,7 @@ pub async fn test_subscription_ws_transport_with_token() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -158,7 +161,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -167,7 +170,7 @@ pub async fn test_subscription_ws_transport_with_token() {
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -223,7 +226,7 @@ pub async fn test_subscription_ws_transport_error() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -246,7 +249,7 @@ pub async fn test_subscription_ws_transport_error() {
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -263,7 +266,67 @@ pub async fn test_subscription_ws_transport_error() {
}],
},
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
#[async_std::test]
pub async fn test_subscription_too_many_initialisation_requests_error() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = i32> {
futures_util::stream::once(async move { 10 })
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_error",
"payload": {
"message": "Too many initialisation requests."
},
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
@ -295,7 +358,7 @@ pub async fn test_query_over_websocket() {
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
@ -317,7 +380,7 @@ pub async fn test_query_over_websocket() {
"id": "1",
"payload": { "data": { "value": 999 } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
assert_eq!(
@ -325,6 +388,6 @@ pub async fn test_query_over_websocket() {
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}