Make `GraphQLWebSocket::new` use generic stream.

Add `GraphQLWebSocket::new_with_pair` method.
This commit is contained in:
Sunli 2021-11-18 09:25:32 +08:00
parent 8243921fa4
commit bd621f70c1
4 changed files with 164 additions and 49 deletions

View File

@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [3.0.3] 2021-11-18
- [integrations] Make `GraphQLWebSocket::new` use generic stream.
- [integrations] Add `GraphQLWebSocket::new_with_pair` method.
## [3.0.2] 2021-11-16
- Add `url`, `regex` and `ip` validators.

View File

@ -7,12 +7,14 @@ use async_graphql::futures_util::task::{Context, Poll};
use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use axum::body::{box_body, BoxBody, HttpBody};
use axum::extract::ws::{CloseFrame, Message, WebSocket};
use axum::extract::ws::{CloseFrame, Message};
use axum::extract::{FromRequest, RequestParts, WebSocketUpgrade};
use axum::http::{self, Request, Response, StatusCode};
use axum::response::IntoResponse;
use axum::Error;
use futures_util::future::{BoxFuture, Ready};
use futures_util::{future, SinkExt, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{future, Sink, SinkExt, Stream, StreamExt};
use tower_service::Service;
/// A GraphQL protocol extractor.
@ -116,30 +118,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
}
/// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> {
pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit> {
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>,
stream: WebSocket,
data: Data,
on_connection_init: OnConnInit,
protocol: GraphQLProtocol,
}
impl<Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType>
impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where
S: Stream<Item = Result<Message, Error>> + Sink<Message>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object.
pub fn new(
stream: WebSocket,
stream: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
let (sink, stream) = stream.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
GraphQLWebSocket {
schema,
sink,
stream,
schema,
data: Data::default(),
on_connection_init: default_on_connection_init,
protocol,
@ -147,9 +180,11 @@ where
}
}
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit>
impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
@ -169,14 +204,15 @@ where
pub fn on_connection_init<OnConnInit2, Fut>(
self,
callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2>
) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{
GraphQLWebSocket {
schema: self.schema,
sink: self.sink,
stream: self.stream,
schema: self.schema,
data: self.data,
on_connection_init: callback,
protocol: self.protocol,
@ -185,9 +221,8 @@ where
/// Processing subscription requests.
pub async fn serve(self) {
let (mut sink, stream) = self.stream.split();
let input = stream
let input = self
.stream
.take_while(|res| future::ready(res.is_ok()))
.map(Result::unwrap)
.filter_map(|msg| {
@ -199,7 +234,7 @@ where
})
.map(Message::into_data);
let mut stream =
let stream =
async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0)
.connection_data(self.data)
.on_connection_init(self.on_connection_init)
@ -211,6 +246,9 @@ where
})),
});
let sink = self.sink;
futures_util::pin_mut!(stream, sink);
while let Some(item) = stream.next().await {
let _ = sink.send(item).await;
}

View File

@ -1,11 +1,13 @@
use std::io::Error as IoError;
use std::str::FromStr;
use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS};
use async_graphql::{Data, ObjectType, Schema, SubscriptionType};
use futures_util::future::{self, Ready};
use futures_util::{Future, SinkExt, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{Future, Sink, SinkExt, Stream, StreamExt};
use poem::http::StatusCode;
use poem::web::websocket::{Message, WebSocket, WebSocketStream};
use poem::web::websocket::{Message, WebSocket};
use poem::{http, Endpoint, FromRequest, IntoResponse, Request, RequestBody, Response, Result};
/// A GraphQL protocol extractor.
@ -106,30 +108,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
}
/// A Websocket connection for GraphQL subscription.
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit> {
pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit> {
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>,
stream: WebSocketStream,
data: Data,
on_connection_init: OnConnInit,
protocol: GraphQLProtocol,
}
impl<Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType>
impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where
S: Stream<Item = Result<Message, IoError>> + Sink<Message>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object.
pub fn new(
stream: WebSocketStream,
stream: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
let (sink, stream) = stream.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>,
protocol: GraphQLProtocol,
) -> Self {
GraphQLWebSocket {
schema,
sink,
stream,
schema,
data: Data::default(),
on_connection_init: default_on_connection_init,
protocol,
@ -137,9 +170,11 @@ where
}
}
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit>
impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, IoError>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
@ -159,14 +194,15 @@ where
pub fn on_connection_init<OnConnInit2, Fut>(
self,
callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2>
) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{
GraphQLWebSocket {
schema: self.schema,
sink: self.sink,
stream: self.stream,
schema: self.schema,
data: self.data,
on_connection_init: callback,
protocol: self.protocol,
@ -175,9 +211,8 @@ where
/// Processing subscription requests.
pub async fn serve(self) {
let (mut sink, stream) = self.stream.split();
let stream = stream
let stream = self
.stream
.take_while(|res| future::ready(res.is_ok()))
.map(Result::unwrap)
.filter_map(|msg| {
@ -189,7 +224,7 @@ where
})
.map(Message::into_bytes);
let mut stream =
let stream =
async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol.0)
.connection_data(self.data)
.on_connection_init(self.on_connection_init)
@ -198,6 +233,9 @@ where
WsMessage::Close(code, status) => Message::close_with(code, status),
});
let sink = self.sink;
futures_util::pin_mut!(stream, sink);
while let Some(item) = stream.next().await {
let _ = sink.send(item).await;
}

View File

@ -4,10 +4,11 @@ use std::str::FromStr;
use async_graphql::http::{WebSocketProtocols, WsMessage};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::future::Ready;
use futures_util::{future, StreamExt};
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{future, Sink, Stream, StreamExt};
use warp::filters::ws;
use warp::ws::WebSocket;
use warp::{Filter, Rejection, Reply};
use warp::ws::Message;
use warp::{Error, Filter, Rejection, Reply};
/// GraphQL subscription filter
///
@ -157,29 +158,60 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Resu
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// });
/// ```
pub struct GraphQLWebSocket<Query, Mutation, Subscription, OnInit> {
socket: WebSocket,
pub struct GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnInit> {
sink: Sink,
stream: Stream,
protocol: WebSocketProtocols,
schema: Schema<Query, Mutation, Subscription>,
data: Data,
on_init: OnInit,
}
impl<Query, Mutation, Subscription>
GraphQLWebSocket<Query, Mutation, Subscription, DefaultOnConnInitType>
impl<S, Query, Mutation, Subscription>
GraphQLWebSocket<
SplitSink<S, Message>,
SplitStream<S>,
Query,
Mutation,
Subscription,
DefaultOnConnInitType,
>
where
S: Stream<Item = Result<Message, Error>> + Sink<Message>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object.
pub fn new(
socket: WebSocket,
socket: S,
schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols,
) -> Self {
let (sink, stream) = socket.split();
GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol)
}
}
impl<Sink, Stream, Query, Mutation, Subscription>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, DefaultOnConnInitType>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
{
/// Create a [`GraphQLWebSocket`] object with sink and stream objects.
pub fn new_with_pair(
sink: Sink,
stream: Stream,
schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols,
) -> Self {
GraphQLWebSocket {
socket,
sink,
stream,
protocol,
schema,
data: Data::default(),
@ -188,9 +220,11 @@ where
}
}
impl<Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit>
impl<Sink, Stream, Query, Mutation, Subscription, OnConnInit, OnConnInitFut>
GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit>
where
Sink: futures_util::sink::Sink<Message>,
Stream: futures_util::stream::Stream<Item = Result<Message, Error>>,
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
@ -210,13 +244,14 @@ where
pub fn on_connection_init<OnConnInit2, Fut>(
self,
callback: OnConnInit2,
) -> GraphQLWebSocket<Query, Mutation, Subscription, OnConnInit2>
) -> GraphQLWebSocket<Sink, Stream, Query, Mutation, Subscription, OnConnInit2>
where
OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{
GraphQLWebSocket {
socket: self.socket,
sink: self.sink,
stream: self.stream,
schema: self.schema,
data: self.data,
on_init: callback,
@ -226,9 +261,8 @@ where
/// Processing subscription requests.
pub async fn serve(self) {
let (ws_sender, ws_receiver) = self.socket.split();
let stream = ws_receiver
let stream = self
.stream
.take_while(|msg| future::ready(msg.is_ok()))
.map(Result::unwrap)
.filter(|msg| future::ready(msg.is_text() || msg.is_binary()))
@ -242,7 +276,7 @@ where
WsMessage::Close(code, status) => ws::Message::close_with(code, status),
})
.map(Ok)
.forward(ws_sender)
.forward(self.sink)
.await;
}
}