Merge pull request #350 from pfernie/graphql-ws-subproto

Detect subprotocol for GraphQL Over WebSocket subscriptions
This commit is contained in:
Sunli 2020-12-05 08:34:08 +08:00 committed by GitHub
commit f2f22d2d5a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 537 additions and 65 deletions

@ -1 +1 @@
Subproject commit 71d7318ca5bd5f9b77ffc7136990c80db35a2b46
Subproject commit b51536fb0a933d25d9fcc91a08c2a53daf74daf8

View File

@ -1,19 +1,26 @@
use std::str::FromStr;
use std::time::{Duration, Instant};
use actix::{
Actor, ActorContext, ActorFuture, ActorStream, AsyncContext, ContextFutureSpawner,
StreamHandler, WrapFuture, WrapStream,
};
use actix_http::ws;
use actix_http::error::PayloadError;
use actix_http::{ws, Error};
use actix_web::web::Bytes;
use actix_web::{HttpRequest, HttpResponse};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::WebSocket;
use async_graphql::http::{WebSocket, WebSocketProtocols};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use std::time::{Duration, Instant};
use futures_util::stream::Stream;
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// Actor for subscription via websocket
pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Option<Schema<Query, Mutation, Subscription>>,
schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols,
last_heartbeat: Instant,
messages: Option<async_channel::Sender<Vec<u8>>>,
initializer: Option<Box<dyn FnOnce(serde_json::Value) -> Result<Data> + Send + Sync>>,
@ -26,26 +33,58 @@ where
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
/// Create an actor for subscription connection via websocket.
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self {
schema: Some(schema),
last_heartbeat: Instant::now(),
messages: None,
initializer: None,
continuation: Vec::new(),
}
/// Start an actor for subscription connection via websocket.
pub fn start<T>(
schema: Schema<Query, Mutation, Subscription>,
request: &HttpRequest,
stream: T,
) -> Result<HttpResponse, Error>
where
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
Self::start_with_initializer(schema, request, stream, |_| Ok(Default::default()))
}
/// Set a context data initialization function.
pub fn initializer<F>(self, f: F) -> Self
/// Start an actor for subscription connection via websocket with an initialization function.
pub fn start_with_initializer<T, F>(
schema: Schema<Query, Mutation, Subscription>,
request: &HttpRequest,
stream: T,
initializer: F,
) -> Result<HttpResponse, Error>
where
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
F: FnOnce(serde_json::Value) -> Result<Data> + Send + Sync + 'static,
{
Self {
initializer: Some(Box::new(f)),
..self
}
let protocol = match request
.headers()
.get("sec-websocket-protocol")
.and_then(|value| value.to_str().ok())
.and_then(|protocols| {
protocols
.split(',')
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
}) {
Some(protocol) => protocol,
None => {
// default to the prior standard
WebSocketProtocols::SubscriptionsTransportWS
}
};
actix_web_actors::ws::start_with_protocols(
Self {
schema,
protocol,
last_heartbeat: Instant::now(),
messages: None,
initializer: Some(Box::new(initializer)),
continuation: Vec::new(),
},
&["graphql-transport-ws", "graphql-ws"],
request,
stream,
)
}
fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
@ -71,13 +110,18 @@ where
let (tx, rx) = async_channel::unbounded();
WebSocket::with_data(self.schema.take().unwrap(), rx, self.initializer.take())
.into_actor(self)
.map(|response, _act, ctx| {
ctx.text(response);
})
.finish()
.spawn(ctx);
WebSocket::with_data(
self.schema.clone(),
rx,
self.initializer.take(),
self.protocol,
)
.into_actor(self)
.map(|response, _act, ctx| {
ctx.text(response);
})
.finish()
.spawn(ctx);
self.messages = Some(tx);
}

View File

@ -64,28 +64,47 @@ where
Subscription: SubscriptionType + Send + Sync + 'static,
F: FnOnce(serde_json::Value) -> Result<Data> + Send + Sync + Clone + 'static,
{
warp::ws().map(move |ws: ws::Ws| {
let schema = schema.clone();
let initializer = initializer.clone();
use async_graphql::http::WebSocketProtocols;
use std::str::FromStr;
let reply = ws.on_upgrade(move |websocket| {
let (ws_sender, ws_receiver) = websocket.split();
warp::ws()
.and(warp::header::optional::<String>("sec-websocket-protocol"))
.map(move |ws: ws::Ws, protocols: Option<String>| {
let schema = schema.clone();
let initializer = initializer.clone();
async move {
let _ = async_graphql::http::WebSocket::with_data(
schema,
ws_receiver
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.map(ws::Message::into_bytes),
initializer,
)
.map(ws::Message::text)
.map(Ok)
.forward(ws_sender)
.await;
}
});
warp::reply::with_header(reply, "Sec-WebSocket-Protocol", "graphql-ws")
})
let protocol = protocols
.and_then(|protocols| {
protocols
.split(',')
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
})
.unwrap_or(WebSocketProtocols::SubscriptionsTransportWS);
let reply = ws.on_upgrade(move |websocket| {
let (ws_sender, ws_receiver) = websocket.split();
async move {
let _ = async_graphql::http::WebSocket::with_data(
schema,
ws_receiver
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.map(ws::Message::into_bytes),
initializer,
protocol,
)
.map(ws::Message::text)
.map(Ok)
.forward(ws_sender)
.await;
}
});
warp::reply::with_header(
reply,
"Sec-WebSocket-Protocol",
protocol.sec_websocket_protocol(),
)
})
}

View File

@ -14,7 +14,7 @@ pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")]
pub use multipart::MultipartOptions;
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::WebSocket;
pub use websocket::{Protocols as WebSocketProtocols, WebSocket};
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]

View File

@ -22,6 +22,7 @@ pin_project! {
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
#[pin]
stream: S,
protocol: Protocols,
}
}
@ -30,13 +31,18 @@ impl<S, Query, Mutation, Subscription>
{
/// Create a new websocket.
#[must_use]
pub fn new(schema: Schema<Query, Mutation, Subscription>, stream: S) -> Self {
pub fn new(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
protocol: Protocols,
) -> Self {
Self {
data_initializer: None,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
protocol,
}
}
}
@ -52,6 +58,7 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: Option<F>,
protocol: Protocols,
) -> Self {
Self {
data_initializer,
@ -59,6 +66,7 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
schema,
streams: HashMap::new(),
stream,
protocol,
}
}
}
@ -135,6 +143,9 @@ where
));
}
}
// Note: in the revised `graphql-ws` spec, there is no equivalent to the
// `CONNECTION_TERMINATE` `client -> server` message; rather, disconnection is
// handled by disconnecting the websocket
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
}
}
@ -143,11 +154,7 @@ where
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Data {
id,
payload: Box::new(payload),
})
.unwrap(),
serde_json::to_string(&this.protocol.next_message(id, payload)).unwrap(),
));
}
Poll::Ready(None) => {
@ -165,27 +172,99 @@ where
}
}
/// Specification of which GraphQL Over WebSockets protocol is being utilized
#[derive(Copy, Clone)]
pub enum Protocols {
/// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
SubscriptionsTransportWS,
/// [graphql-ws protocol](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
GraphQLWS,
}
impl Protocols {
/// Returns the `Sec-WebSocket-Protocol` header value for the protocol
pub fn sec_websocket_protocol(&self) -> &str {
match self {
Protocols::SubscriptionsTransportWS => "graphql-ws",
Protocols::GraphQLWS => "graphql-transport-ws",
}
}
#[inline]
fn next_message<'s>(&self, id: &'s str, payload: Response) -> ServerMessage<'s> {
match self {
Protocols::SubscriptionsTransportWS => ServerMessage::Data {
id,
payload: Box::new(payload),
},
Protocols::GraphQLWS => ServerMessage::Next {
id,
payload: Box::new(payload),
},
}
}
}
impl std::str::FromStr for Protocols {
type Err = Error;
fn from_str(protocol: &str) -> Result<Self, Self::Err> {
if protocol.eq_ignore_ascii_case("graphql-ws") {
Ok(Protocols::SubscriptionsTransportWS)
} else if protocol.eq_ignore_ascii_case("graphql-transport-ws") {
Ok(Protocols::GraphQLWS)
} else {
Err(Error::new(format!(
"Unsupported Sec-WebSocket-Protocol: {}",
protocol
)))
}
}
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage<'a> {
ConnectionInit { payload: Option<serde_json::Value> },
Start { id: String, payload: Request },
Stop { id: &'a str },
ConnectionInit {
payload: Option<serde_json::Value>,
},
#[serde(alias = "subscribe")]
Start {
id: String,
payload: Request,
},
#[serde(alias = "complete")]
Stop {
id: &'a str,
},
ConnectionTerminate,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
ConnectionError { payload: Error },
ConnectionError {
payload: Error,
},
ConnectionAck,
Data { id: &'a str, payload: Box<Response> },
/// subscriptions-transport-ws protocol next payload
Data {
id: &'a str,
payload: Box<Response>,
},
/// graphql-ws protocol next payload
Next {
id: &'a str,
payload: Box<Response>,
},
// Not used by this library, as it's not necessary to send
// Error {
// id: &'a str,
// payload: serde_json::Value,
// },
Complete { id: &'a str },
Complete {
id: &'a str,
},
// Not used by this library
// #[serde(rename = "ka")]
// KeepAlive

View File

@ -0,0 +1,328 @@
use async_graphql::http::WebSocketProtocols;
use async_graphql::*;
use futures_util::stream::{Stream, StreamExt};
#[async_std::test]
pub async fn test_subscription_ws_transport() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self) -> impl Stream<Item = i32> {
futures_util::stream::iter(0..10)
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "connection_ack",
}),
);
tx.send(
serde_json::to_string(&value!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "next",
"id": "1",
"payload": { "data": { "values": i } },
}),
);
}
assert_eq!(
serde_json::from_str::<serde_json::Value>(&stream.next().await.unwrap()).unwrap(),
serde_json::json!({
"type": "complete",
"id": "1",
}),
);
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self, ctx: &Context<'_>) -> Result<impl Stream<Item = i32>> {
if ctx.data_unchecked::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures_util::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::with_data(
schema,
rx,
Some(|value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}),
WebSocketProtocols::GraphQLWS,
);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(value!({
"type": "next",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(value!({
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
#[async_std::test]
pub async fn test_subscription_ws_transport_error() {
struct Event {
value: i32,
}
#[Object]
impl Event {
async fn value(&self) -> Result<i32> {
if self.value < 5 {
Ok(self.value)
} else {
Err("TestError".into())
}
}
}
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
10
}
}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = Event> {
futures_util::stream::iter((0..10).map(|n| Event { value: n }))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
Some(value!({
"type": "next",
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(value!({
"type": "next",
"id": "1",
"payload": {
"data": null,
"errors": [{
"message": "TestError",
"locations": [{"line": 1, "column": 25}],
"path": ["events", "value"],
}],
},
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}
#[async_std::test]
pub async fn test_query_over_websocket() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
999
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "next",
"id": "1",
"payload": { "data": { "value": 999 } },
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
assert_eq!(
Some(value!({
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap()).unwrap()
);
}

View File

@ -1,3 +1,4 @@
use async_graphql::http::WebSocketProtocols;
use async_graphql::*;
use futures_util::stream::{Stream, StreamExt};
@ -23,7 +24,7 @@ pub async fn test_subscription_ws_transport() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS);
tx.send(
serde_json::to_string(&value!({
@ -115,6 +116,7 @@ pub async fn test_subscription_ws_transport_with_token() {
data.insert(Token(payload.token));
Ok(data)
}),
WebSocketProtocols::SubscriptionsTransportWS,
);
tx.send(
@ -204,7 +206,7 @@ pub async fn test_subscription_ws_transport_error() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS);
tx.send(
serde_json::to_string(&value!({
@ -276,7 +278,7 @@ pub async fn test_query_over_websocket() {
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let (tx, rx) = async_channel::unbounded();
let mut stream = http::WebSocket::new(schema, rx);
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS);
tx.send(
serde_json::to_string(&value!({