Rework websocket

This commit is contained in:
Sunli 2021-11-09 17:01:51 +08:00
parent 396e7bc175
commit b75aa6c25b
4 changed files with 127 additions and 96 deletions

View File

@ -41,6 +41,10 @@ impl Data {
pub fn insert<D: Any + Send + Sync>(&mut self, data: D) { pub fn insert<D: Any + Send + Sync>(&mut self, data: D) {
self.0.insert(TypeId::of::<D>(), Box::new(data)); self.0.insert(TypeId::of::<D>(), Box::new(data));
} }
pub(crate) fn merge(&mut self, other: Data) {
self.0.extend(other.0);
}
} }
impl Debug for Data { impl Debug for Data {

View File

@ -8,10 +8,7 @@ use std::task::{Context, Poll};
use futures_util::stream::Stream; use futures_util::stream::Stream;
use futures_util::FutureExt; use futures_util::FutureExt;
use futures_util::{ use futures_util::{future::BoxFuture, StreamExt};
future::{BoxFuture, Ready},
StreamExt,
};
use pin_project_lite::pin_project; use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -60,14 +57,18 @@ impl WsMessage {
} }
} }
type BoxInitializer =
Box<(dyn FnOnce(serde_json::Value) -> BoxFuture<'static, Result<Data>> + Send + 'static)>;
pin_project! { pin_project! {
/// A GraphQL connection over websocket. /// A GraphQL connection over websocket.
/// ///
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). /// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
/// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md). /// [Reference](https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md).
pub struct WebSocket<S, F, Query, Mutation, Subscription> { pub struct WebSocket<S, Query, Mutation, Subscription> {
data_initializer: Option<F>, data_initializer: Option<BoxInitializer>,
init_fut: Option<BoxFuture<'static, Result<Data>>>, init_fut: Option<BoxFuture<'static, Result<Data>>>,
connection_data: Option<Data>,
data: Option<Arc<Data>>, data: Option<Arc<Data>>,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>, streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
@ -80,80 +81,20 @@ pin_project! {
type MessageMapStream<S> = type MessageMapStream<S> =
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>; futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
impl<S, Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription> WebSocket<S, Query, Mutation, Subscription>
WebSocket<
MessageMapStream<S>,
fn(serde_json::Value) -> Ready<Result<Data>>,
Query,
Mutation,
Subscription,
>
where
S: Stream,
S::Item: AsRef<[u8]>,
{
/// Create a new websocket.
#[must_use]
pub fn new(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
protocol: Protocols,
) -> Self {
Self::with_data(
schema,
stream,
|_| futures_util::future::ready(Ok(Default::default())),
protocol,
)
}
}
impl<S, F, Query, Mutation, Subscription>
WebSocket<MessageMapStream<S>, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
{
/// Create a new websocket with a data initialization function.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
#[must_use]
pub fn with_data(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: F,
protocol: Protocols,
) -> Self {
// let stream = stream.map(|message| serde_json::from_slice(message.as_ref()));
let stream = stream
.map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
Self::with_message_stream(schema, stream, data_initializer, protocol)
}
}
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription>
where where
S: Stream<Item = serde_json::Result<ClientMessage>>, S: Stream<Item = serde_json::Result<ClientMessage>>,
{ {
/// Create a new websocket with a data initialization function from a stream of ClientMessage /// Create a new websocket from [`ClientMessage`] stream.
/// structs. pub fn from_message_stream(
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
#[must_use]
pub fn with_message_stream(
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
stream: S, stream: S,
data_initializer: F,
protocol: Protocols, protocol: Protocols,
) -> Self { ) -> Self {
WebSocket { WebSocket {
data_initializer: Some(data_initializer), data_initializer: Some(Box::new(|_| Box::pin(async move { Ok(Data::default()) }))),
init_fut: None, init_fut: None,
connection_data: None,
data: None, data: None,
schema, schema,
streams: HashMap::new(), streams: HashMap::new(),
@ -161,14 +102,56 @@ where
protocol, protocol,
} }
} }
/// Specify a connection initializer.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
pub fn with_initializer<F, R>(mut self, initializer: F) -> Self
where
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{
self.data_initializer = Some(Box::new(move |value| Box::pin(initializer(value))));
self
}
/// Specify a connection data.
///
/// This data usually comes from HTTP requests.
/// When the `GQL_CONNECTION_INIT` message is received, this data will be merged with the data
/// returned by the closure specified by `with_initializer` into the final subscription context data.
pub fn connection_data(mut self, data: Data) -> Self {
self.connection_data = Some(data);
self
}
} }
impl<S, F, R, Query, Mutation, Subscription> Stream impl<S, Query, Mutation, Subscription> WebSocket<MessageMapStream<S>, Query, Mutation, Subscription>
for WebSocket<S, F, Query, Mutation, Subscription> where
S: Stream,
S::Item: AsRef<[u8]>,
{
/// Create a new websocket from bytes stream.
pub fn new(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
protocol: Protocols,
) -> Self
where
S: Stream,
S::Item: AsRef<[u8]>,
{
let stream = stream
.map(ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>);
WebSocket::from_message_stream(schema, stream, protocol)
}
}
impl<S, Query, Mutation, Subscription> Stream for WebSocket<S, Query, Mutation, Subscription>
where where
S: Stream<Item = serde_json::Result<ClientMessage>>, S: Stream<Item = serde_json::Result<ClientMessage>>,
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
@ -266,7 +249,9 @@ where
*this.init_fut = None; *this.init_fut = None;
return match res { return match res {
Ok(data) => { Ok(data) => {
*this.data = Some(Arc::new(data)); let mut ctx_data = this.connection_data.take().unwrap_or_default();
ctx_data.merge(data);
*this.data = Some(Arc::new(ctx_data));
Poll::Ready(Some(WsMessage::Text( Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(), serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
))) )))

View File

@ -111,10 +111,8 @@ pub async fn test_subscription_ws_transport_with_token() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::with_data( let mut stream = http::WebSocket::new(schema.clone(), rx, WebSocketProtocols::GraphQLWS)
schema, .with_initializer(|value| async {
rx,
|value| async {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Payload { struct Payload {
token: String, token: String,
@ -124,9 +122,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut data = Data::default(); let mut data = Data::default();
data.insert(Token(payload.token)); data.insert(Token(payload.token));
Ok(data) Ok(data)
}, });
WebSocketProtocols::GraphQLWS,
);
tx.send( tx.send(
serde_json::to_string(&value!({ serde_json::to_string(&value!({
@ -176,6 +172,60 @@ pub async fn test_subscription_ws_transport_with_token() {
})), })),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap() serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
); );
let (mut tx, rx) = mpsc::unbounded();
let mut data = Data::default();
data.insert(Token("123456".to_string()));
let mut stream =
http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS).connection_data(data);
tx.send(
serde_json::to_string(&value!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(value!({
"type": "connection_ack",
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
tx.send(
serde_json::to_string(&value!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(value!({
"type": "next",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
}
assert_eq!(
Some(value!({
"type": "complete",
"id": "1",
})),
serde_json::from_str(&stream.next().await.unwrap().unwrap_text()).unwrap()
);
} }
#[tokio::test] #[tokio::test]
@ -296,12 +346,8 @@ pub async fn test_subscription_init_error() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::with_data( let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::GraphQLWS)
schema, .with_initializer(|_| async move { Err("Error!".into()) });
rx,
|_| async move { Err("Error!".into()) },
WebSocketProtocols::GraphQLWS,
);
tx.send( tx.send(
serde_json::to_string(&value!({ serde_json::to_string(&value!({

View File

@ -107,10 +107,8 @@ pub async fn test_subscription_ws_transport_with_token() {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot); let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut tx, rx) = mpsc::unbounded(); let (mut tx, rx) = mpsc::unbounded();
let mut stream = http::WebSocket::with_data( let mut stream = http::WebSocket::new(schema, rx, WebSocketProtocols::SubscriptionsTransportWS)
schema, .with_initializer(|value| async {
rx,
|value| async {
#[derive(serde::Deserialize)] #[derive(serde::Deserialize)]
struct Payload { struct Payload {
token: String, token: String,
@ -120,9 +118,7 @@ pub async fn test_subscription_ws_transport_with_token() {
let mut data = Data::default(); let mut data = Data::default();
data.insert(Token(payload.token)); data.insert(Token(payload.token));
Ok(data) Ok(data)
}, });
WebSocketProtocols::SubscriptionsTransportWS,
);
tx.send( tx.send(
serde_json::to_string(&value!({ serde_json::to_string(&value!({