From 7275c7ded89cd8da95f6c11bc409a2aa388c405c Mon Sep 17 00:00:00 2001 From: Sunli Date: Thu, 18 Nov 2021 09:25:32 +0800 Subject: [PATCH] Make `GraphQLWebSocket::new` use generic stream. Add `GraphQLWebSocket::new_with_pair` method. --- CHANGELOG.md | 5 ++ integrations/axum/src/subscription.rs | 70 +++++++++++++++++++++------ integrations/poem/src/subscription.rs | 70 +++++++++++++++++++++------ integrations/warp/src/subscription.rs | 68 +++++++++++++++++++------- 4 files changed, 164 insertions(+), 49 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6aa0ee73..512ccb56 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [3.0.3] 2021-11-18 + +- [integrations] Make `GraphQLWebSocket::new` use generic stream. +- [integrations] Add `GraphQLWebSocket::new_with_pair` method. + ## [3.0.2] 2021-11-16 - Add `url`, `regex` and `ip` validators. diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs index a88ca8cc..d7856b6f 100644 --- a/integrations/axum/src/subscription.rs +++ b/integrations/axum/src/subscription.rs @@ -7,12 +7,14 @@ use async_graphql::futures_util::task::{Context, Poll}; use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; use axum::body::{box_body, BoxBody, HttpBody}; -use axum::extract::ws::{CloseFrame, Message, WebSocket}; +use axum::extract::ws::{CloseFrame, Message}; use axum::extract::{FromRequest, RequestParts, WebSocketUpgrade}; use axum::http::{self, Request, Response, StatusCode}; use axum::response::IntoResponse; +use axum::Error; use futures_util::future::{BoxFuture, Ready}; -use futures_util::{future, SinkExt, StreamExt}; +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{future, Sink, SinkExt, Stream, StreamExt}; use tower_service::Service; /// A GraphQL protocol extractor. @@ -116,30 +118,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready { +pub struct GraphQLWebSocket { + sink: Sink, + stream: Stream, schema: Schema, - stream: WebSocket, data: Data, on_connection_init: OnConnInit, protocol: GraphQLProtocol, } -impl - GraphQLWebSocket +impl + GraphQLWebSocket< + SplitSink, + SplitStream, + Query, + Mutation, + Subscription, + DefaultOnConnInitType, + > where + S: Stream> + Sink, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, { /// Create a [`GraphQLWebSocket`] object. pub fn new( - stream: WebSocket, + stream: S, + schema: Schema, + protocol: GraphQLProtocol, + ) -> Self { + let (sink, stream) = stream.split(); + GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol) + } +} + +impl + GraphQLWebSocket +where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, +{ + /// Create a [`GraphQLWebSocket`] object with sink and stream objects. + pub fn new_with_pair( + sink: Sink, + stream: Stream, schema: Schema, protocol: GraphQLProtocol, ) -> Self { GraphQLWebSocket { - schema, + sink, stream, + schema, data: Data::default(), on_connection_init: default_on_connection_init, protocol, @@ -147,9 +180,11 @@ where } } -impl - GraphQLWebSocket +impl + GraphQLWebSocket where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, @@ -169,14 +204,15 @@ where pub fn on_connection_init( self, callback: OnConnInit2, - ) -> GraphQLWebSocket + ) -> GraphQLWebSocket where OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { GraphQLWebSocket { - schema: self.schema, + sink: self.sink, stream: self.stream, + schema: self.schema, data: self.data, on_connection_init: callback, protocol: self.protocol, @@ -185,9 +221,8 @@ where /// Processing subscription requests. pub async fn serve(self) { - let (mut sink, stream) = self.stream.split(); - - let input = stream + let input = self + .stream .take_while(|res| future::ready(res.is_ok())) .map(Result::unwrap) .filter_map(|msg| { @@ -199,7 +234,7 @@ where }) .map(Message::into_data); - let mut stream = + let stream = async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0) .connection_data(self.data) .on_connection_init(self.on_connection_init) @@ -211,6 +246,9 @@ where })), }); + let sink = self.sink; + futures_util::pin_mut!(stream, sink); + while let Some(item) = stream.next().await { let _ = sink.send(item).await; } diff --git a/integrations/poem/src/subscription.rs b/integrations/poem/src/subscription.rs index 61fe7043..a5d763d5 100644 --- a/integrations/poem/src/subscription.rs +++ b/integrations/poem/src/subscription.rs @@ -1,11 +1,13 @@ +use std::io::Error as IoError; use std::str::FromStr; use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::{Data, ObjectType, Schema, SubscriptionType}; use futures_util::future::{self, Ready}; -use futures_util::{Future, SinkExt, StreamExt}; +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{Future, Sink, SinkExt, Stream, StreamExt}; use poem::http::StatusCode; -use poem::web::websocket::{Message, WebSocket, WebSocketStream}; +use poem::web::websocket::{Message, WebSocket}; use poem::{http, Endpoint, FromRequest, IntoResponse, Request, RequestBody, Response, Result}; /// A GraphQL protocol extractor. @@ -106,30 +108,61 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready { +pub struct GraphQLWebSocket { + sink: Sink, + stream: Stream, schema: Schema, - stream: WebSocketStream, data: Data, on_connection_init: OnConnInit, protocol: GraphQLProtocol, } -impl - GraphQLWebSocket +impl + GraphQLWebSocket< + SplitSink, + SplitStream, + Query, + Mutation, + Subscription, + DefaultOnConnInitType, + > where + S: Stream> + Sink, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, { /// Create a [`GraphQLWebSocket`] object. pub fn new( - stream: WebSocketStream, + stream: S, + schema: Schema, + protocol: GraphQLProtocol, + ) -> Self { + let (sink, stream) = stream.split(); + GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol) + } +} + +impl + GraphQLWebSocket +where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, +{ + /// Create a [`GraphQLWebSocket`] object with sink and stream objects. + pub fn new_with_pair( + sink: Sink, + stream: Stream, schema: Schema, protocol: GraphQLProtocol, ) -> Self { GraphQLWebSocket { - schema, + sink, stream, + schema, data: Data::default(), on_connection_init: default_on_connection_init, protocol, @@ -137,9 +170,11 @@ where } } -impl - GraphQLWebSocket +impl + GraphQLWebSocket where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, @@ -159,14 +194,15 @@ where pub fn on_connection_init( self, callback: OnConnInit2, - ) -> GraphQLWebSocket + ) -> GraphQLWebSocket where OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { GraphQLWebSocket { - schema: self.schema, + sink: self.sink, stream: self.stream, + schema: self.schema, data: self.data, on_connection_init: callback, protocol: self.protocol, @@ -175,9 +211,8 @@ where /// Processing subscription requests. pub async fn serve(self) { - let (mut sink, stream) = self.stream.split(); - - let stream = stream + let stream = self + .stream .take_while(|res| future::ready(res.is_ok())) .map(Result::unwrap) .filter_map(|msg| { @@ -189,7 +224,7 @@ where }) .map(Message::into_bytes); - let mut stream = + let stream = async_graphql::http::WebSocket::new(self.schema.clone(), stream, self.protocol.0) .connection_data(self.data) .on_connection_init(self.on_connection_init) @@ -198,6 +233,9 @@ where WsMessage::Close(code, status) => Message::close_with(code, status), }); + let sink = self.sink; + futures_util::pin_mut!(stream, sink); + while let Some(item) = stream.next().await { let _ = sink.send(item).await; } diff --git a/integrations/warp/src/subscription.rs b/integrations/warp/src/subscription.rs index d3034337..d3a72d9f 100644 --- a/integrations/warp/src/subscription.rs +++ b/integrations/warp/src/subscription.rs @@ -4,10 +4,11 @@ use std::str::FromStr; use async_graphql::http::{WebSocketProtocols, WsMessage}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; use futures_util::future::Ready; -use futures_util::{future, StreamExt}; +use futures_util::stream::{SplitSink, SplitStream}; +use futures_util::{future, Sink, Stream, StreamExt}; use warp::filters::ws; -use warp::ws::WebSocket; -use warp::{Filter, Rejection, Reply}; +use warp::ws::Message; +use warp::{Error, Filter, Rejection, Reply}; /// GraphQL subscription filter /// @@ -157,29 +158,60 @@ fn default_on_connection_init(_: serde_json::Value) -> Ready { - socket: WebSocket, +pub struct GraphQLWebSocket { + sink: Sink, + stream: Stream, protocol: WebSocketProtocols, schema: Schema, data: Data, on_init: OnInit, } -impl - GraphQLWebSocket +impl + GraphQLWebSocket< + SplitSink, + SplitStream, + Query, + Mutation, + Subscription, + DefaultOnConnInitType, + > where + S: Stream> + Sink, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, { /// Create a [`GraphQLWebSocket`] object. pub fn new( - socket: WebSocket, + socket: S, + schema: Schema, + protocol: WebSocketProtocols, + ) -> Self { + let (sink, stream) = socket.split(); + GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol) + } +} + +impl + GraphQLWebSocket +where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, +{ + /// Create a [`GraphQLWebSocket`] object with sink and stream objects. + pub fn new_with_pair( + sink: Sink, + stream: Stream, schema: Schema, protocol: WebSocketProtocols, ) -> Self { GraphQLWebSocket { - socket, + sink, + stream, protocol, schema, data: Data::default(), @@ -188,9 +220,11 @@ where } } -impl - GraphQLWebSocket +impl + GraphQLWebSocket where + Sink: futures_util::sink::Sink, + Stream: futures_util::stream::Stream>, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, @@ -210,13 +244,14 @@ where pub fn on_connection_init( self, callback: OnConnInit2, - ) -> GraphQLWebSocket + ) -> GraphQLWebSocket where OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, Fut: Future> + Send + 'static, { GraphQLWebSocket { - socket: self.socket, + sink: self.sink, + stream: self.stream, schema: self.schema, data: self.data, on_init: callback, @@ -226,9 +261,8 @@ where /// Processing subscription requests. pub async fn serve(self) { - let (ws_sender, ws_receiver) = self.socket.split(); - - let stream = ws_receiver + let stream = self + .stream .take_while(|msg| future::ready(msg.is_ok())) .map(Result::unwrap) .filter(|msg| future::ready(msg.is_text() || msg.is_binary())) @@ -242,7 +276,7 @@ where WsMessage::Close(code, status) => ws::Message::close_with(code, status), }) .map(Ok) - .forward(ws_sender) + .forward(self.sink) .await; } }