diff --git a/integrations/axum/Cargo.toml b/integrations/axum/Cargo.toml index 363b3803..14139e6a 100644 --- a/integrations/axum/Cargo.toml +++ b/integrations/axum/Cargo.toml @@ -15,12 +15,11 @@ categories = ["network-programming", "asynchronous"] async-graphql = { path = "../..", version = "=2.11.2" } async-trait = "0.1.51" -axum = { version = "0.3", features = ["ws", "headers"] } +axum = { version = "0.3.2", features = ["ws", "headers"] } bytes = "1.0.1" -headers = "0.3.4" -http = "0.2.4" http-body = "0.4.2" serde_json = "1.0.66" serde_urlencoded = "0.7.0" tokio-util = { version = "0.6.7", features = ["io", "compat"] } futures-util = "0.3.13" +tower-service = "0.3" diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs index 4d759ff3..0db2af3d 100644 --- a/integrations/axum/src/extract.rs +++ b/integrations/axum/src/extract.rs @@ -3,12 +3,12 @@ use std::io::ErrorKind; use async_graphql::futures_util::TryStreamExt; use async_graphql::http::MultipartOptions; use async_graphql::ParseRequestError; +use axum::http::Method; use axum::{ extract::{BodyStream, FromRequest, RequestParts}, - BoxError, + http, BoxError, }; use bytes::Bytes; -use http::Method; use tokio_util::compat::TokioAsyncReadCompatExt; /// Extractor for GraphQL request. @@ -26,8 +26,9 @@ impl GraphQLRequest { pub mod rejection { use async_graphql::ParseRequestError; use axum::body::Body; + use axum::http; + use axum::http::StatusCode; use axum::response::IntoResponse; - use http::StatusCode; /// Rejection used for [`GraphQLRequest`](GraphQLRequest). pub struct GraphQLRejection(pub ParseRequestError); diff --git a/integrations/axum/src/lib.rs b/integrations/axum/src/lib.rs index ec6814e2..db48c948 100644 --- a/integrations/axum/src/lib.rs +++ b/integrations/axum/src/lib.rs @@ -8,6 +8,4 @@ mod subscription; pub use extract::{GraphQLBatchRequest, GraphQLRequest}; pub use response::GraphQLResponse; -pub use subscription::{ - graphql_subscription, graphql_subscription_with_data, SecWebsocketProtocol, -}; +pub use subscription::{GraphQLProtocol, GraphQLSubscription, GraphQLWebSocket}; diff --git a/integrations/axum/src/response.rs b/integrations/axum/src/response.rs index ed3d7b51..041ca8c8 100644 --- a/integrations/axum/src/response.rs +++ b/integrations/axum/src/response.rs @@ -1,7 +1,8 @@ use axum::body::Body; +use axum::http; +use axum::http::header::HeaderName; +use axum::http::{HeaderValue, Response}; use axum::response::IntoResponse; -use headers::HeaderName; -use http::{HeaderValue, Response}; /// Responder for a GraphQL response. /// diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs index b4466988..ae06cb0f 100644 --- a/integrations/axum/src/subscription.rs +++ b/integrations/axum/src/subscription.rs @@ -1,103 +1,200 @@ use std::borrow::Cow; +use std::convert::Infallible; use std::future::Future; +use std::str::FromStr; -use async_graphql::http::{WebSocketProtocols, WsMessage}; +use async_graphql::futures_util::task::{Context, Poll}; +use async_graphql::http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; +use axum::body::{box_body, BoxBody, HttpBody}; use axum::extract::ws::{CloseFrame, Message, WebSocket}; +use axum::extract::{FromRequest, RequestParts, WebSocketUpgrade}; +use axum::http::{self, Request, Response, StatusCode}; +use axum::response::IntoResponse; +use futures_util::future::{BoxFuture, Ready}; use futures_util::{future, SinkExt, StreamExt}; -use headers::{Header, HeaderName, HeaderValue}; +use tower_service::Service; -/// The Sec-Websocket-Protocol header. -#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] -pub struct SecWebsocketProtocol(pub WebSocketProtocols); - -impl Header for SecWebsocketProtocol { - fn name() -> &'static HeaderName { - &http::header::SEC_WEBSOCKET_PROTOCOL - } - - fn decode<'i, I>(values: &mut I) -> Result - where - Self: Sized, - I: Iterator, - { - match values.next() { - Some(value) => { - let value = value.to_str().map_err(|_| headers::Error::invalid())?; - Ok(SecWebsocketProtocol( - value - .parse() - .ok() - .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS), - )) - } - None => Err(headers::Error::invalid()), - } - } - - fn encode>(&self, values: &mut E) { - values.extend(std::iter::once(HeaderValue::from_static( - self.0.sec_websocket_protocol(), - ))) - } -} - -/// GraphQL subscription handler -pub async fn graphql_subscription( - websocket: WebSocket, - schema: Schema, - protocol: SecWebsocketProtocol, -) where - Query: ObjectType + Sync + Send + 'static, - Mutation: ObjectType + Sync + Send + 'static, - Subscription: SubscriptionType + Send + Sync + 'static, -{ - graphql_subscription_with_data(websocket, schema, protocol, |_| async { - Ok(Default::default()) - }) - .await -} - -/// GraphQL subscription handler +/// A GraphQL protocol extractor. /// -/// Specifies that a function converts the init payload to data. -pub async fn graphql_subscription_with_data( - websocket: WebSocket, +/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub struct GraphQLProtocol(WebSocketProtocols); + +#[async_trait::async_trait] +impl FromRequest for GraphQLProtocol { + type Rejection = StatusCode; + + async fn from_request(req: &mut RequestParts) -> Result { + req.headers() + .and_then(|headers| headers.get(http::header::SEC_WEBSOCKET_PROTOCOL)) + .and_then(|value| value.to_str().ok()) + .and_then(|protocols| { + protocols + .split(',') + .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) + }) + .map(Self) + .ok_or_else(|| StatusCode::BAD_REQUEST) + } +} + +/// A GraphQL subscription service. +pub struct GraphQLSubscription { schema: Schema, - protocol: SecWebsocketProtocol, - initializer: F, -) where +} + +impl GraphQLSubscription { + /// Create a GraphQL subscription service. + pub fn new(schema: Schema) -> Self { + Self { schema } + } +} + +impl Service> + for GraphQLSubscription +where + B: HttpBody + Send + 'static, Query: ObjectType + 'static, Mutation: ObjectType + 'static, Subscription: SubscriptionType + 'static, - F: FnOnce(serde_json::Value) -> R + Send + 'static, - R: Future> + Send + 'static, { - let (mut sink, stream) = websocket.split(); - let input = stream - .take_while(|res| future::ready(res.is_ok())) - .map(Result::unwrap) - .filter_map(|msg| { - if let Message::Text(_) | Message::Binary(_) = msg { - future::ready(Some(msg)) - } else { - future::ready(None) - } + type Response = Response; + type Error = Infallible; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, req: Request) -> Self::Future { + let schema = self.schema.clone(); + + Box::pin(async move { + let mut parts = RequestParts::new(req); + let protocol = match GraphQLProtocol::from_request(&mut parts).await { + Ok(protocol) => protocol, + Err(err) => return Ok(err.into_response().map(box_body)), + }; + let upgrade = match WebSocketUpgrade::from_request(&mut parts).await { + Ok(protocol) => protocol, + Err(err) => return Ok(err.into_response().map(box_body)), + }; + + let schema = schema.clone(); + + let resp = upgrade + .protocols(ALL_WEBSOCKET_PROTOCOLS) + .on_upgrade(move |stream| GraphQLWebSocket::new(stream, schema, protocol).serve()); + Ok(resp.into_response().map(box_body)) }) - .map(Message::into_data); - - let mut stream = - async_graphql::http::WebSocket::with_data(schema, input, initializer, protocol.0).map( - |msg| match msg { - WsMessage::Text(text) => Message::Text(text), - WsMessage::Close(code, status) => Message::Close(Some(CloseFrame { - code, - reason: Cow::from(status), - })), - }, - ); - - while let Some(item) = stream.next().await { - let _ = sink.send(item).await; + } +} + +type DefaultOnConnInitType = fn(serde_json::Value) -> Ready>; + +fn default_on_connection_init(_: serde_json::Value) -> Ready> { + futures_util::future::ready(Ok(Data::default())) +} + +/// A Websocket connection for GraphQL subscription. +pub struct GraphQLWebSocket { + schema: Schema, + stream: WebSocket, + data: Data, + on_connection_init: OnConnInit, + protocol: GraphQLProtocol, +} + +impl + GraphQLWebSocket +where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, +{ + /// Create a [`GraphQLWebSocket`] object. + pub fn new( + stream: WebSocket, + schema: Schema, + protocol: GraphQLProtocol, + ) -> Self { + GraphQLWebSocket { + schema, + stream, + data: Data::default(), + on_connection_init: default_on_connection_init, + protocol, + } + } +} + +impl + GraphQLWebSocket +where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + OnConnInit: Fn(serde_json::Value) -> OnConnInitFut + Send + Sync + 'static, + OnConnInitFut: Future> + Send + 'static, +{ + /// Specify the initial subscription context data, usually you can get something from the + /// incoming request to create it. + pub fn with_data(self, data: Data) -> Self { + Self { data, ..self } + } + + /// Specify a callback function to be called when the connection is initialized. + /// + /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`]. + /// The data returned by this callback function will be merged with the data specified by [`with_data`]. + pub fn on_connection_init( + self, + callback: OnConnInit2, + ) -> GraphQLWebSocket + where + OnConnInit2: Fn(serde_json::Value) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + GraphQLWebSocket { + schema: self.schema, + stream: self.stream, + data: self.data, + on_connection_init: callback, + protocol: self.protocol, + } + } + + /// Processing subscription requests. + pub async fn serve(self) { + let (mut sink, stream) = self.stream.split(); + + let input = stream + .take_while(|res| future::ready(res.is_ok())) + .map(Result::unwrap) + .filter_map(|msg| { + if let Message::Text(_) | Message::Binary(_) = msg { + future::ready(Some(msg)) + } else { + future::ready(None) + } + }) + .map(Message::into_data); + + let mut stream = + async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0) + .connection_data(self.data) + .on_connection_init(self.on_connection_init) + .map(|msg| match msg { + WsMessage::Text(text) => Message::Text(text), + WsMessage::Close(code, status) => Message::Close(Some(CloseFrame { + code, + reason: Cow::from(status), + })), + }); + + while let Some(item) = stream.next().await { + let _ = sink.send(item).await; + } } } diff --git a/integrations/poem/Cargo.toml b/integrations/poem/Cargo.toml index 4137d5a7..a15ce794 100644 --- a/integrations/poem/Cargo.toml +++ b/integrations/poem/Cargo.toml @@ -18,3 +18,4 @@ poem = { version = "1.0.23", features = ["websocket"] } futures-util = { version = "0.3.13", default-features = false } serde_json = "1.0.66" tokio-util = { version = "0.6.7", features = ["compat"] } + diff --git a/integrations/poem/src/subscription.rs b/integrations/poem/src/subscription.rs index 216de868..e4f64323 100644 --- a/integrations/poem/src/subscription.rs +++ b/integrations/poem/src/subscription.rs @@ -6,9 +6,7 @@ use futures_util::future::{self, Ready}; use futures_util::{Future, SinkExt, StreamExt}; use poem::http::StatusCode; use poem::web::websocket::{Message, WebSocket, WebSocketStream}; -use poem::{ - http, Endpoint, Error, FromRequest, IntoResponse, Request, RequestBody, Response, Result, -}; +use poem::{http, Endpoint, FromRequest, IntoResponse, Request, RequestBody, Response, Result}; /// A GraphQL protocol extractor. /// @@ -18,7 +16,7 @@ pub struct GraphQLProtocol(WebSocketProtocols); #[poem::async_trait] impl<'a> FromRequest<'a> for GraphQLProtocol { - type Error = Error; + type Error = StatusCode; async fn from_request(req: &'a Request, _body: &mut RequestBody) -> Result { req.headers() @@ -30,7 +28,7 @@ impl<'a> FromRequest<'a> for GraphQLProtocol { .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) }) .map(Self) - .ok_or_else(|| Error::new(StatusCode::BAD_REQUEST)) + .ok_or_else(|| StatusCode::BAD_REQUEST) } } @@ -97,7 +95,6 @@ where .protocols(ALL_WEBSOCKET_PROTOCOLS) .on_upgrade(move |stream| GraphQLWebSocket::new(stream, schema, protocol).serve()) .into_response(); - Ok(resp) } }