async-graphql/src/http/websocket.rs
2020-10-16 07:49:22 +01:00

196 lines
6.9 KiB
Rust

//! WebSocket transport for subscription
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::stream::Stream;
use pin_project_lite::pin_project;
use serde::{Deserialize, Serialize};
use crate::{Data, Error, ObjectType, Request, Response, Result, Schema, SubscriptionType};
pin_project! {
/// A GraphQL connection over websocket.
///
/// [Reference](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md).
pub struct WebSocket<S, F, Query, Mutation, Subscription> {
data_initializer: Option<F>,
data: Arc<Data>,
schema: Schema<Query, Mutation, Subscription>,
streams: HashMap<String, Pin<Box<dyn Stream<Item = Response> + Send>>>,
#[pin]
stream: S,
}
}
impl<S, Query, Mutation, Subscription>
WebSocket<S, fn(serde_json::Value) -> Result<Data>, Query, Mutation, Subscription>
{
/// Create a new websocket.
#[must_use]
pub fn new(schema: Schema<Query, Mutation, Subscription>, stream: S) -> Self {
Self {
data_initializer: None,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
}
}
}
impl<S, F, Query, Mutation, Subscription> WebSocket<S, F, Query, Mutation, Subscription> {
/// Create a new websocket with a data initialization function.
///
/// This function, if present, will be called with the data sent by the client in the
/// [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init).
/// From that point on the returned data will be accessible to all requests.
#[must_use]
pub fn with_data(
schema: Schema<Query, Mutation, Subscription>,
stream: S,
data_initializer: Option<F>,
) -> Self {
Self {
data_initializer,
data: Arc::default(),
schema,
streams: HashMap::new(),
stream,
}
}
}
impl<S, F, Query, Mutation, Subscription> Stream for WebSocket<S, F, Query, Mutation, Subscription>
where
S: Stream,
S::Item: AsRef<[u8]>,
F: FnOnce(serde_json::Value) -> Result<Data>,
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
type Item = String;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.project();
match this.stream.poll_next(cx) {
Poll::Ready(message) => {
let message = match message {
Some(message) => message,
None => return Poll::Ready(None),
};
let message: ClientMessage = match serde_json::from_slice(message.as_ref()) {
Ok(message) => message,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionError {
payload: Error::new(e.to_string()),
})
.unwrap(),
))
}
};
match message {
ClientMessage::ConnectionInit { payload } => {
if let Some(payload) = payload {
if let Some(data_initializer) = this.data_initializer.take() {
*this.data = Arc::new(match data_initializer(payload) {
Ok(data) => data,
Err(e) => {
return Poll::Ready(Some(
serde_json::to_string(
&ServerMessage::ConnectionError { payload: e },
)
.unwrap(),
))
}
});
}
}
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::ConnectionAck).unwrap(),
));
}
ClientMessage::Start {
id,
payload: request,
} => {
this.streams.insert(
id,
Box::pin(
this.schema
.execute_stream_with_ctx_data(request, Arc::clone(this.data)),
),
);
}
ClientMessage::Stop { id } => {
if this.streams.remove(id).is_some() {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id }).unwrap(),
));
}
}
ClientMessage::ConnectionTerminate => return Poll::Ready(None),
}
}
Poll::Pending => {}
}
for (id, stream) in &mut *this.streams {
match Pin::new(stream).poll_next(cx) {
Poll::Ready(Some(payload)) => {
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Data {
id,
payload: Box::new(payload),
})
.unwrap(),
));
}
Poll::Ready(None) => {
let id = id.clone();
this.streams.remove(&id);
return Poll::Ready(Some(
serde_json::to_string(&ServerMessage::Complete { id: &id }).unwrap(),
));
}
Poll::Pending => {}
}
}
Poll::Pending
}
}
#[derive(Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ClientMessage<'a> {
ConnectionInit { payload: Option<serde_json::Value> },
Start { id: String, payload: Request },
Stop { id: &'a str },
ConnectionTerminate,
}
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ServerMessage<'a> {
ConnectionError { payload: Error },
ConnectionAck,
Data { id: &'a str, payload: Box<Response> },
// Not used by this library, as it's not necessary to send
// Error {
// id: &'a str,
// payload: serde_json::Value,
// },
Complete { id: &'a str },
// Not used by this library
// #[serde(rename = "ka")]
// KeepAlive
}