Rework websocket
This commit is contained in:
parent
396e7bc175
commit
b75aa6c25b
|
@ -41,6 +41,10 @@ impl Data {
|
||||||
pub fn insert<D: Any + Send + Sync>(&mut self, data: D) {
|
pub fn insert<D: Any + Send + Sync>(&mut self, data: D) {
|
||||||
self.0.insert(TypeId::of::<D>(), Box::new(data));
|
self.0.insert(TypeId::of::<D>(), Box::new(data));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn merge(&mut self, other: Data) {
|
||||||
|
self.0.extend(other.0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Debug for Data {
|
impl Debug for Data {
|
||||||
|
|
|
@ -8,10 +8,7 @@ use std::task::{Context, Poll};
|
||||||
|
|
||||||
use futures_util::stream::Stream;
|
use futures_util::stream::Stream;
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
use futures_util::{
|
use futures_util::{future::BoxFuture, StreamExt};
|
||||||
future::{BoxFuture, Ready},
|
|
||||||
StreamExt,
|
|
||||||
};
|
|
||||||
use pin_project_lite::pin_project;
|
use pin_project_lite::pin_project;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
@ -60,14 +57,18 @@ impl WsMessage {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type BoxInitializer =
|
||||||
|
Box<(dyn FnOnce(serde_json::Value) -> BoxFuture<'static, Result<Data>> + Send + 'static)>;
|
||||||
|
|
||||||
pin_project! {
|
pin_project! {
|
||||||
/// A GraphQL connection over websocket.
|
/// A GraphQL connection over websocket.
|
||||||
///
|
///
|
||||||
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
|
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
|
||||||
/// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
|
/// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
|
||||||
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
|
pub struct WebSocket<S, Query, Mutation, Subscription> {
|
||||||
data_initializer: Option<F>,
|
data_initializer: Option<BoxInitializer>,
|
||||||
init_fut: Option<BoxFuture<'static, Result<Data>>>,
|
init_fut: Option<BoxFuture<'static, Result<Data>>>,
|
||||||
|
connection_data: Option<Data>,
|
||||||
data: Option<Arc<Data>>,
|
data: Option<Arc<Data>>,
|
||||||
schema: Schema<Query, Mutation, Subscription>,
|
schema: Schema<Query, Mutation, Subscription>,
|
||||||
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
|
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
|
||||||
|
@ -80,80 +81,20 @@ pin_project! {
|
||||||
type MessageMapStream<S> =
|
type MessageMapStream<S> =
|
||||||
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
|
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
|
||||||
|
|
||||||
impl<S, Query, Mutation, Subscription>
|
impl<S, Query, Mutation, Subscription> WebSocket<S, Query, Mutation, Subscription>
|
||||||
WebSocket<
|
|
||||||
MessageMapStream<S>,
|
|
||||||
fn(serde_json::Value) -> Ready<Result<Data>>,
|
|
||||||
Query,
|
|
||||||
Mutation,
|
|
||||||
Subscription,
|
|
||||||
>
|
|
||||||
where
|
|
||||||
S: Stream,
|
|
||||||
S::Item: AsRef<[u8]>,
|
|
||||||
{
|
|
||||||
/// Create a new websocket.
|
|
||||||
#[must_use]
|
|
||||||
pub fn new(
|
|
||||||
schema: Schema<Query, Mutation, Subscription>,
|
|
||||||
stream: S,
|
|
||||||
protocol: Protocols,
|
|
||||||
) -> Self {
|
|
||||||
Self::with_data(
|
|
||||||
schema,
|
|
||||||
stream,
|
|
||||||
|_| futures_util::future::ready(Ok(Default::default())),
|
|
||||||
protocol,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, F, Query, Mutation, Subscription>
|
|
||||||
WebSocket<MessageMapStream<S>, F, Query, Mutation, Subscription>
|
|
||||||
where
|
|
||||||
S: Stream,
|
|
||||||
S::Item: AsRef<[u8]>,
|
|
||||||
{
|
|
||||||
/// Create a new websocket with a data initialization function.
|
|
||||||
///
|
|
||||||
/// This function, if present, will be called with the data sent by the client in the
|
|
||||||
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
|
|
||||||
/// From that point on the returned data will be accessible to all requests.
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_data(
|
|
||||||
schema: Schema<Query, Mutation, Subscription>,
|
|
||||||
stream: S,
|
|
||||||
data_initializer: F,
|
|
||||||
protocol: Protocols,
|
|
||||||
) -> Self {
|
|
||||||
// let stream = stream.map(|message| serde_json::from_slice(message.as_ref()));
|
|
||||||
let stream = stream
|
|
||||||
.map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
|
|
||||||
|
|
||||||
Self::with_message_stream(schema, stream, data_initializer, protocol)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription>
|
|
||||||
where
|
where
|
||||||
S: Stream<Item = serde_json::Result<ClientMessage>>,
|
S: Stream<Item = serde_json::Result<ClientMessage>>,
|
||||||
{
|
{
|
||||||
/// Create a new websocket with a data initialization function from a stream of ClientMessage
|
/// Create a new websocket from [`ClientMessage`] stream.
|
||||||
/// structs.
|
pub fn from_message_stream(
|
||||||
///
|
|
||||||
/// This function, if present, will be called with the data sent by the client in the
|
|
||||||
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
|
|
||||||
/// From that point on the returned data will be accessible to all requests.
|
|
||||||
#[must_use]
|
|
||||||
pub fn with_message_stream(
|
|
||||||
schema: Schema<Query, Mutation, Subscription>,
|
schema: Schema<Query, Mutation, Subscription>,
|
||||||
stream: S,
|
stream: S,
|
||||||
data_initializer: F,
|
|
||||||
protocol: Protocols,
|
protocol: Protocols,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
WebSocket {
|
WebSocket {
|
||||||
data_initializer: Some(data_initializer),
|
data_initializer: Some(Box::new(|_| Box::pin(async move { Ok(Data::default()) }))),
|
||||||
init_fut: None,
|
init_fut: None,
|
||||||
|
connection_data: None,
|
||||||
data: None,
|
data: None,
|
||||||
schema,
|
schema,
|
||||||
streams: HashMap::new(),
|
streams: HashMap::new(),
|
||||||
|
@ -161,14 +102,56 @@ where
|
||||||
protocol,
|
protocol,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Specify a connection initializer.
|
||||||
|
///
|
||||||
|
/// This function, if present, will be called with the data sent by the client in the
|
||||||
|
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
|
||||||
|
/// From that point on the returned data will be accessible to all requests.
|
||||||
|
pub fn with_initializer<F, R>(mut self, initializer: F) -> Self
|
||||||
|
where
|
||||||
|
F: FnOnce(serde_json::Value) -> R + Send + 'static,
|
||||||
|
R: Future<Output = Result<Data>> + Send + 'static,
|
||||||
|
{
|
||||||
|
self.data_initializer = Some(Box::new(move |value| Box::pin(initializer(value))));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Specify a connection data.
|
||||||
|
///
|
||||||
|
/// This data usually comes from HTTP requests.
|
||||||
|
/// When the `GQL_CONNECTION_INIT` message is received, this data will be merged with the data
|
||||||
|
/// returned by the closure specified by `with_initializer` into the final subscription context data.
|
||||||
|
pub fn connection_data(mut self, data: Data) -> Self {
|
||||||
|
self.connection_data = Some(data);
|
||||||
|
self
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S, F, R, Query, Mutation, Subscription> Stream
|
impl<S, Query, Mutation, Subscription> WebSocket<MessageMapStream<S>, Query, Mutation, Subscription>
|
||||||
for WebSocket<S, F, Query, Mutation, Subscription>
|
where
|
||||||
|
S: Stream,
|
||||||
|
S::Item: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
/// Create a new websocket from bytes stream.
|
||||||
|
pub fn new(
|
||||||
|
schema: Schema<Query, Mutation, Subscription>,
|
||||||
|
stream: S,
|
||||||
|
protocol: Protocols,
|
||||||
|
) -> Self
|
||||||
|
where
|
||||||
|
S: Stream,
|
||||||
|
S::Item: AsRef<[u8]>,
|
||||||
|
{
|
||||||
|
let stream = stream
|
||||||
|
.map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
|
||||||
|
WebSocket::from_message_stream(schema, stream, protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S, Query, Mutation, Subscription> Stream for WebSocket<S, Query, Mutation, Subscription>
|
||||||
where
|
where
|
||||||
S: Stream<Item = serde_json::Result<ClientMessage>>,
|
S: Stream<Item = serde_json::Result<ClientMessage>>,
|
||||||
F: FnOnce(serde_json::Value) -> R + Send + 'static,
|
|
||||||
R: Future<Output = Result<Data>> + Send + 'static,
|
|
||||||
Query: ObjectType + 'static,
|
Query: ObjectType + 'static,
|
||||||
Mutation: ObjectType + 'static,
|
Mutation: ObjectType + 'static,
|
||||||
Subscription: SubscriptionType + 'static,
|
Subscription: SubscriptionType + 'static,
|
||||||
|
@ -266,7 +249,9 @@ where
|
||||||
*this.init_fut = None;
|
*this.init_fut = None;
|
||||||
return match res {
|
return match res {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
*this.data = Some(Arc::new(data));
|
let mut ctx_data = this.connection_data.take().unwrap_or_default();
|
||||||
|
ctx_data.merge(data);
|
||||||
|
*this.data = Some(Arc::new(ctx_data));
|
||||||
Poll::Ready(Some(WsMessage::Text(
|
Poll::Ready(Some(WsMessage::Text(
|
||||||
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
|
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
|
||||||
)))
|
)))
|
||||||
|
|
|
@ -111,10 +111,8 @@ pub async fn test_subscription_ws_transport_with_token() {
|
||||||
|
|
||||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||||
let (mut tx, rx) = mpsc::unbounded();
|
let (mut tx, rx) = mpsc::unbounded();
|
||||||
let mut stream = http::WebSocket::with_data(
|
let mut stream = http::WebSocket::new(schema.clone(), rx, WebSocketProtocols::GraphQLWS)
|
||||||
schema,
|
.with_initializer(|value| async {
|
||||||
rx,
|
|
||||||
|value| async {
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct Payload {
|
struct Payload {
|
||||||
token: String,
|
token: String,
|
||||||
|
@ -124,9 +122,7 @@ pub async fn test_subscription_ws_transport_with_token() {
|
||||||
let mut data = Data::default();
|
let mut data = Data::default();
|
||||||
data.insert(Token(payload.token));
|
data.insert(Token(payload.token));
|
||||||
Ok(data)
|
Ok(data)
|
||||||
},
|
});
|
||||||
WebSocketProtocols::GraphQLWS,
|
|
||||||
);
|
|
||||||
|
|
||||||
tx.send(
|
tx.send(
|
||||||
serde_json::to_string(&value!({
|
serde_json::to_string(&value!({
|
||||||
|
@ -176,6 +172,60 @@ pub async fn test_subscription_ws_transport_with_token() {
|
||||||
})),
|
})),
|
||||||
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
|
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let (mut tx, rx) = mpsc::unbounded();
|
||||||
|
let mut data = Data::default();
|
||||||
|
data.insert(Token("123456".to_string()));
|
||||||
|
let mut stream =
|
||||||
|
http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS).connection_data(data);
|
||||||
|
|
||||||
|
tx.send(
|
||||||
|
serde_json::to_string(&value!({
|
||||||
|
"type": "connection_init",
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
Some(value!({
|
||||||
|
"type": "connection_ack",
|
||||||
|
})),
|
||||||
|
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
|
tx.send(
|
||||||
|
serde_json::to_string(&value!({
|
||||||
|
"type": "start",
|
||||||
|
"id": "1",
|
||||||
|
"payload": {
|
||||||
|
"query": "subscription { values }"
|
||||||
|
},
|
||||||
|
}))
|
||||||
|
.unwrap(),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
assert_eq!(
|
||||||
|
Some(value!({
|
||||||
|
"type": "next",
|
||||||
|
"id": "1",
|
||||||
|
"payload": { "data": { "values": i } },
|
||||||
|
})),
|
||||||
|
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
Some(value!({
|
||||||
|
"type": "complete",
|
||||||
|
"id": "1",
|
||||||
|
})),
|
||||||
|
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -296,12 +346,8 @@ pub async fn test_subscription_init_error() {
|
||||||
|
|
||||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||||
let (mut tx, rx) = mpsc::unbounded();
|
let (mut tx, rx) = mpsc::unbounded();
|
||||||
let mut stream = http::WebSocket::with_data(
|
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS)
|
||||||
schema,
|
.with_initializer(|_| async move { Err("Error!".into()) });
|
||||||
rx,
|
|
||||||
|_| async move { Err("Error!".into()) },
|
|
||||||
WebSocketProtocols::GraphQLWS,
|
|
||||||
);
|
|
||||||
|
|
||||||
tx.send(
|
tx.send(
|
||||||
serde_json::to_string(&value!({
|
serde_json::to_string(&value!({
|
||||||
|
|
|
@ -107,10 +107,8 @@ pub async fn test_subscription_ws_transport_with_token() {
|
||||||
|
|
||||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||||
let (mut tx, rx) = mpsc::unbounded();
|
let (mut tx, rx) = mpsc::unbounded();
|
||||||
let mut stream = http::WebSocket::with_data(
|
let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS)
|
||||||
schema,
|
.with_initializer(|value| async {
|
||||||
rx,
|
|
||||||
|value| async {
|
|
||||||
#[derive(serde::Deserialize)]
|
#[derive(serde::Deserialize)]
|
||||||
struct Payload {
|
struct Payload {
|
||||||
token: String,
|
token: String,
|
||||||
|
@ -120,9 +118,7 @@ pub async fn test_subscription_ws_transport_with_token() {
|
||||||
let mut data = Data::default();
|
let mut data = Data::default();
|
||||||
data.insert(Token(payload.token));
|
data.insert(Token(payload.token));
|
||||||
Ok(data)
|
Ok(data)
|
||||||
},
|
});
|
||||||
WebSocketProtocols::SubscriptionsTransportWS,
|
|
||||||
);
|
|
||||||
|
|
||||||
tx.send(
|
tx.send(
|
||||||
serde_json::to_string(&value!({
|
serde_json::to_string(&value!({
|
||||||
|
|
Loading…
Reference in New Issue
Block a user