Make GraphQLWebSocket::new use generic stream.

Add `GraphQLWebSocket::new_with_pair` method.
This commit is contained in:
Sunli 2021-11-18 09:25:32 +08:00
parent d97714c560
commit 7275c7ded8
4 changed files with 164 additions and 49 deletions

View File

@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [3.0.3] 2021-11-18
- [integrations] Make `GraphQLWebSocket::new` use generic stream.
- [integrations] Add `GraphQLWebSocket::new_with_pair` method.
## [3.0.2] 2021-11-16 ## [3.0.2] 2021-11-16
- Add `url`, `regex` and `ip` validators. - Add `url`, `regex` and `ip` validators.

View File

@ -7,12 +7,14 @@ use async_graphql::futures_util::task::{Context, Poll};
use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use axum::body::{box_body, BoxBody, HttpBody}; use axum::body::{box_body, BoxBody, HttpBody};
use axum::extract::ws::{CloseFrame, Message, WebSocket}; use axum::extract::ws::{CloseFrame, Message};
use axum::extract::{FromRequest, RequestParts, WebSocketUpgrade}; use axum::extract::{FromRequest, RequestParts, WebSocketUpgrade};
use axum::http::{self, Request, Response, StatusCode}; use axum::http::{self, Request, Response, StatusCode};
use axum::response::IntoResponse; use axum::response::IntoResponse;
use axum::Error;
use futures_util::future::{BoxFuture, Ready}; use futures_util::future::{BoxFuture, Ready};
use futures_util::{future, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{future, Sink, SinkExt, Stream, StreamExt};
use tower_service::Service; use tower_service::Service;
/// A GraphQL protocol extractor. /// A GraphQL protocol extractor.
@ -116,30 +118,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
} }
/// A Websocket connection for GraphQL subscription. /// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> { pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit> {
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
stream: WebSocket,
data: Data, data: Data,
on_connection_init: OnConnInit, on_connection_init: OnConnInit,
protocol: GraphQLProtocol, protocol: GraphQLProtocol,
} }
impl<Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType> GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where where
S: Stream<Item = Result<Message, Error>> + Sink<Message>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
{ {
/// Create a [`GraphQLWebSocket`] object. /// Create a [`GraphQLWebSocket`] object.
pub fn new( pub fn new(
stream: WebSocket, stream: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
let (sink, stream) = stream.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol, protocol: GraphQLProtocol,
) -> Self { ) -> Self {
GraphQLWebSocket { GraphQLWebSocket {
schema, sink,
stream, stream,
schema,
data: Data::default(), data: Data::default(),
on_connection_init: default_on_connection_init, on_connection_init: default_on_connection_init,
protocol, protocol,
@ -147,9 +180,11 @@ where
} }
} }
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut> impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
@ -169,14 +204,15 @@ where
pub fn on_connection_init<OnConnInit2, Fut>( pub fn on_connection_init<OnConnInit2, Fut>(
self, self,
callback: OnConnInit2, callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2> ) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static, Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{ {
GraphQLWebSocket { GraphQLWebSocket {
schema: self.schema, sink: self.sink,
stream: self.stream, stream: self.stream,
schema: self.schema,
data: self.data, data: self.data,
on_connection_init: callback, on_connection_init: callback,
protocol: self.protocol, protocol: self.protocol,
@ -185,9 +221,8 @@ where
/// Processing subscription requests. /// Processing subscription requests.
pub async fn serve(self) { pub async fn serve(self) {
let (mut sink, stream) = self.stream.split(); let input = self
.stream
let input = stream
.take_while(|res| future::ready(res.is_ok())) .take_while(|res| future::ready(res.is_ok()))
.map(Result::unwrap) .map(Result::unwrap)
.filter_map(|msg| { .filter_map(|msg| {
@ -199,7 +234,7 @@ where
}) })
.map(Message::into_data); .map(Message::into_data);
let mut stream = let stream =
async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0) async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0)
.connection_data(self.data) .connection_data(self.data)
.on_connection_init(self.on_connection_init) .on_connection_init(self.on_connection_init)
@ -211,6 +246,9 @@ where
})), })),
}); });
let sink = self.sink;
futures_util::pin_mut!(stream, sink);
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
let _ = sink.send(item).await; let _ = sink.send(item).await;
} }

View File

@ -1,11 +1,13 @@
use std::io::Error as IoError;
use std::str::FromStr; use std::str::FromStr;
use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS};
use async_graphql::{Data, ObjectType, Schema, SubscriptionType}; use async_graphql::{Data, ObjectType, Schema, SubscriptionType};
use futures_util::future::{self, Ready}; use futures_util::future::{self, Ready};
use futures_util::{Future, SinkExt, StreamExt}; use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
use poem::http::StatusCode; use poem::http::StatusCode;
use poem::web::websocket::{Message, WebSocket, WebSocketStream}; use poem::web::websocket::{Message, WebSocket};
use poem::{http, Endpoint, FromRequest, IntoResponse, Request, RequestBody, Response, Result}; use poem::{http, Endpoint, FromRequest, IntoResponse, Request, RequestBody, Response, Result};
/// A GraphQL protocol extractor. /// A GraphQL protocol extractor.
@ -106,30 +108,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
} }
/// A Websocket connection for GraphQL subscription. /// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> { pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit> {
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
stream: WebSocketStream,
data: Data, data: Data,
on_connection_init: OnConnInit, on_connection_init: OnConnInit,
protocol: GraphQLProtocol, protocol: GraphQLProtocol,
} }
impl<Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType> GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where where
S: Stream<Item = Result<Message, IoError>> + Sink<Message>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
{ {
/// Create a [`GraphQLWebSocket`] object. /// Create a [`GraphQLWebSocket`] object.
pub fn new( pub fn new(
stream: WebSocketStream, stream: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
let (sink, stream) = stream.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol, protocol: GraphQLProtocol,
) -> Self { ) -> Self {
GraphQLWebSocket { GraphQLWebSocket {
schema, sink,
stream, stream,
schema,
data: Data::default(), data: Data::default(),
on_connection_init: default_on_connection_init, on_connection_init: default_on_connection_init,
protocol, protocol,
@ -137,9 +170,11 @@ where
} }
} }
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut> impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
@ -159,14 +194,15 @@ where
pub fn on_connection_init<OnConnInit2, Fut>( pub fn on_connection_init<OnConnInit2, Fut>(
self, self,
callback: OnConnInit2, callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2> ) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static, Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{ {
GraphQLWebSocket { GraphQLWebSocket {
schema: self.schema, sink: self.sink,
stream: self.stream, stream: self.stream,
schema: self.schema,
data: self.data, data: self.data,
on_connection_init: callback, on_connection_init: callback,
protocol: self.protocol, protocol: self.protocol,
@ -175,9 +211,8 @@ where
/// Processing subscription requests. /// Processing subscription requests.
pub async fn serve(self) { pub async fn serve(self) {
let (mut sink, stream) = self.stream.split(); let stream = self
.stream
let stream = stream
.take_while(|res| future::ready(res.is_ok())) .take_while(|res| future::ready(res.is_ok()))
.map(Result::unwrap) .map(Result::unwrap)
.filter_map(|msg| { .filter_map(|msg| {
@ -189,7 +224,7 @@ where
}) })
.map(Message::into_bytes); .map(Message::into_bytes);
let mut stream = let stream =
async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol.0) async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol.0)
.connection_data(self.data) .connection_data(self.data)
.on_connection_init(self.on_connection_init) .on_connection_init(self.on_connection_init)
@ -198,6 +233,9 @@ where
WsMessage::Close(code, status) => Message::close_with(code, status), WsMessage::Close(code, status) => Message::close_with(code, status),
}); });
let sink = self.sink;
futures_util::pin_mut!(stream, sink);
while let Some(item) = stream.next().await { while let Some(item) = stream.next().await {
let _ = sink.send(item).await; let _ = sink.send(item).await;
} }

View File

@ -4,10 +4,11 @@ use std::str::FromStr;
use async_graphql::http::{WebSocketProtocols, WsMessage}; use async_graphql::http::{WebSocketProtocols, WsMessage};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::future::Ready; use futures_util::future::Ready;
use futures_util::{future, StreamExt}; use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{future, Sink, Stream, StreamExt};
use warp::filters::ws; use warp::filters::ws;
use warp::ws::WebSocket; use warp::ws::Message;
use warp::{Filter, Rejection, Reply}; use warp::{Error, Filter, Rejection, Reply};
/// GraphQL subscription filter /// GraphQL subscription filter
/// ///
@ -157,29 +158,60 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await; /// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// }); /// });
/// ``` /// ```
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnInit> { pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnInit> {
socket: WebSocket, sink: Sink,
stream: Stream,
protocol: WebSocketProtocols, protocol: WebSocketProtocols,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
data: Data, data: Data,
on_init: OnInit, on_init: OnInit,
} }
impl<Query, Mutation, Subscription> impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType> GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where where
S: Stream<Item = Result<Message, Error>> + Sink<Message>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
{ {
/// Create a [`GraphQLWebSocket`] object. /// Create a [`GraphQLWebSocket`] object.
pub fn new( pub fn new(
socket: WebSocket, socket: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols,
) -> Self {
let (sink, stream) = socket.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols, protocol: WebSocketProtocols,
) -> Self { ) -> Self {
GraphQLWebSocket { GraphQLWebSocket {
socket, sink,
stream,
protocol, protocol,
schema, schema,
data: Data::default(), data: Data::default(),
@ -188,9 +220,11 @@ where
} }
} }
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut> impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
@ -210,13 +244,14 @@ where
pub fn on_connection_init<OnConnInit2, Fut>( pub fn on_connection_init<OnConnInit2, Fut>(
self, self,
callback: OnConnInit2, callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2> ) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static, Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{ {
GraphQLWebSocket { GraphQLWebSocket {
socket: self.socket, sink: self.sink,
stream: self.stream,
schema: self.schema, schema: self.schema,
data: self.data, data: self.data,
on_init: callback, on_init: callback,
@ -226,9 +261,8 @@ where
/// Processing subscription requests. /// Processing subscription requests.
pub async fn serve(self) { pub async fn serve(self) {
let (ws_sender, ws_receiver) = self.socket.split(); let stream = self
.stream
let stream = ws_receiver
.take_while(|msg| future::ready(msg.is_ok())) .take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap) .map(Result::unwrap)
.filter(|msg| future::ready(msg.is_text() || msg.is_binary())) .filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
@ -242,7 +276,7 @@ where
WsMessage::Close(code, status) => ws::Message::close_with(code, status), WsMessage::Close(code, status) => ws::Message::close_with(code, status),
}) })
.map(Ok) .map(Ok)
.forward(ws_sender) .forward(self.sink)
.await; .await;
} }
} }