Merge pull request #501 from D1plo1d/feature/websocket-file-uploads

Added `Websocket::with_message_stream` for client message customization
This commit is contained in:
Sunli 2021-05-08 14:10:28 +08:00 committed by GitHub
commit 6dd7245914
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 86 additions and 15 deletions

View File

@ -12,7 +12,12 @@ use crate::{BatchRequest, ParseRequestError, Request};
pub use graphiql_source::graphiql_source;
pub use multipart::MultipartOptions;
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::{Protocols as WebSocketProtocols, WebSocket, WsMessage};
pub use websocket::{
Protocols as WebSocketProtocols,
WebSocket,
WsMessage,
ClientMessage,
};
/// Receive a GraphQL request from a content type and body.
pub async fn receive_body(

View File

@ -1,12 +1,12 @@
//! WebSocket transport for subscription
use std::collections::HashMap;
use std::{collections::HashMap};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::future::{BoxFuture, Ready};
use futures_util::{StreamExt, future::{BoxFuture, Ready}};
use futures_util::stream::Stream;
use futures_util::FutureExt;
use pin_project_lite::pin_project;
@ -70,8 +70,14 @@ pin_project! {
}
}
type MessageMapStream<S> =
futures_util::stream::Map<S, fn(<S as Stream>::Item) -> serde_json::Result<ClientMessage>>;
impl<S, Query, Mutation, Subscription>
WebSocket<S, fn(serde_json::Value) -> Ready<Result<Data>>, Query, Mutation, Subscription>
WebSocket<MessageMapStream<S>, fn(serde_json::Value) -> Ready<Result<Data>>, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
{
/// Create a new websocket.
#[must_use]
@ -89,7 +95,12 @@ impl<S, Query, Mutation, Subscription>
}
}
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription> {
impl<S, F, Query, Mutation, Subscription>
WebSocket<MessageMapStream<S>, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
{
/// Create a new websocket with a data initialization function.
///
/// This function, if present, will be called with the data sent by the client in the
@ -101,8 +112,42 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
stream: S,
data_initializer: F,
protocol: Protocols,
) -> Self {
Self {
) -> Self
{
// let stream = stream.map(|message| serde_json::from_slice(message.as_ref()));
let stream = stream.map(
ClientMessage::from_bytes as fn(S::Item) -> serde_json::Result<ClientMessage>,
);
Self::with_message_stream(
schema,
stream,
data_initializer,
protocol,
)
}
}
impl<S, F, Query, Mutation, Subscription>
WebSocket<S, F, Query, Mutation, Subscription>
where
S: Stream<Item = serde_json::Result<ClientMessage>>,
{
/// Create a new websocket with a data initialization function from a stream of ClientMessage
/// structs.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
#[must_use]
pub fn with_message_stream(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: F,
protocol: Protocols,
) -> Self
{
WebSocket {
data_initializer: Some(data_initializer),
init_fut: None,
data: None,
@ -117,8 +162,7 @@ impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subsc
impl<S, F, R, Query, Mutation, Subscription> Stream
for WebSocket<S, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
S: Stream<Item = serde_json::Result<ClientMessage>>,
F: FnOnce(serde_json::Value) -> R + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
Query: ObjectType + 'static,
@ -137,9 +181,11 @@ where
None => return Poll::Ready(None),
};
let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
let message: ClientMessage = match message {
Ok(message) => message,
Err(err) => return Poll::Ready(Some(WsMessage::Close(1002, err.to_string()))),
Err(err) => {
return Poll::Ready(Some(WsMessage::Close(1002, err.to_string())))
}
};
match message {
@ -189,9 +235,9 @@ where
}
}
ClientMessage::Stop { id } => {
if this.streams.remove(id).is_some() {
if this.streams.remove(&id).is_some() {
return Poll::Ready(Some(WsMessage::Text(
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
)));
}
}
@ -294,24 +340,44 @@ impl std::str::FromStr for Protocols {
}
}
/// A websocket message received from the client
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage<'a> {
pub enum ClientMessage {
/// A new connection
ConnectionInit {
/// Optional init payload from the client
payload: Option<serde_json::Value>,
},
/// The start of a Websocket subscription
#[serde(alias = "subscribe")]
Start {
/// Message ID
id: String,
/// The GraphQL Request - this can be modified by protocol implementors to add files
/// uploads.
payload: Request,
},
/// The end of a Websocket subscription
#[serde(alias = "complete")]
Stop {
id: &'a str,
/// Message ID
id: String,
},
/// Connection terminated by the client
ConnectionTerminate,
}
impl ClientMessage {
/// Creates a ClientMessage from an array of bytes
pub fn from_bytes<T>(message: T) -> serde_json::Result<Self>
where
T: AsRef<[u8]>,
{
serde_json::from_slice(message.as_ref())
}
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {