diff --git a/integrations/axum/Cargo.toml b/integrations/axum/Cargo.toml index 1a110a3f..688c6571 100644 --- a/integrations/axum/Cargo.toml +++ b/integrations/axum/Cargo.toml @@ -15,7 +15,7 @@ version = "4.0.15" async-graphql = { path = "../..", version = "4.0.15", default-features = false } async-trait = "0.1.51" -axum = { version = "0.5.1", features = ["ws", "headers"] } +axum = { git = "https://github.com/tokio-rs/axum", features = ["headers"] } bytes = "1.0.1" futures-util = "0.3.0" http-body = "0.4.2" diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs index 0190b81c..98243199 100644 --- a/integrations/axum/src/extract.rs +++ b/integrations/axum/src/extract.rs @@ -2,12 +2,13 @@ use std::{io::ErrorKind, marker::PhantomData}; use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError}; use axum::{ - extract::{BodyStream, FromRequest, RequestParts}, + extract::{BodyStream, FromRequest}, http, http::Method, response::IntoResponse, BoxError, }; +use axum::http::Request; use bytes::Bytes; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -61,8 +62,9 @@ pub mod rejection { } #[async_trait::async_trait] -impl FromRequest for GraphQLRequest +impl FromRequest for GraphQLRequest where + S: Send + Sync, B: http_body::Body + Unpin + Send + Sync + 'static, B::Data: Into, B::Error: Into, @@ -70,9 +72,9 @@ where { type Rejection = R; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, state: &S) -> Result { Ok(GraphQLRequest( - GraphQLBatchRequest::::from_request(req) + GraphQLBatchRequest::::from_request(req, state) .await? .0 .into_single()?, @@ -96,8 +98,9 @@ impl GraphQLBatchRequest { } #[async_trait::async_trait] -impl FromRequest for GraphQLBatchRequest +impl FromRequest for GraphQLBatchRequest where + S: Send + Sync, B: http_body::Body + Unpin + Send + Sync + 'static, B::Data: Into, B::Error: Into, @@ -105,13 +108,13 @@ where { type Rejection = R; - async fn from_request(req: &mut RequestParts) -> Result { + async fn from_request(req: Request, state: &S) -> Result { if let (&Method::GET, uri) = (req.method(), req.uri()) { let res = async_graphql::http::parse_query_string(uri.query().unwrap_or_default()) .map_err(|err| { ParseRequestError::Io(std::io::Error::new( ErrorKind::Other, - format!("failed to parse graphql request from uri query: {}", err), + format!("failed to parse graphql request from uri query: {:?}", err), )) }); Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData)) @@ -121,7 +124,7 @@ where .get(http::header::CONTENT_TYPE) .and_then(|value| value.to_str().ok()) .map(ToString::to_string); - let body_stream = BodyStream::from_request(req) + let body_stream = BodyStream::from_request(req, state) .await .map_err(|_| { ParseRequestError::Io(std::io::Error::new( diff --git a/integrations/axum/src/lib.rs b/integrations/axum/src/lib.rs index db48c948..b2d92b5e 100644 --- a/integrations/axum/src/lib.rs +++ b/integrations/axum/src/lib.rs @@ -4,8 +4,6 @@ mod extract; mod response; -mod subscription; pub use extract::{GraphQLBatchRequest, GraphQLRequest}; pub use response::GraphQLResponse; -pub use subscription::{GraphQLProtocol, GraphQLSubscription, GraphQLWebSocket}; diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs deleted file mode 100644 index 0f554f1c..00000000 --- a/integrations/axum/src/subscription.rs +++ /dev/null @@ -1,265 +0,0 @@ -use std::{borrow::Cow, convert::Infallible, future::Future, str::FromStr}; - -use async_graphql::{ - futures_util::task::{Context, Poll}, - http::{WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}, - Data, ObjectType, Result, Schema, SubscriptionType, -}; -use axum::{ - body::{boxed, BoxBody, HttpBody}, - extract::{ - ws::{CloseFrame, Message}, - FromRequest, RequestParts, WebSocketUpgrade, - }, - http::{self, Request, Response, StatusCode}, - response::IntoResponse, - Error, -}; -use futures_util::{ - future, - future::{BoxFuture, Ready}, - stream::{SplitSink, SplitStream}, - Sink, SinkExt, Stream, StreamExt, -}; -use tower_service::Service; - -/// A GraphQL protocol extractor. -/// -/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header. -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct GraphQLProtocol(WebSocketProtocols); - -#[async_trait::async_trait] -impl FromRequest for GraphQLProtocol { - type Rejection = StatusCode; - - async fn from_request(req: &mut RequestParts) -> Result { - req.headers() - .get(http::header::SEC_WEBSOCKET_PROTOCOL) - .and_then(|value| value.to_str().ok()) - .and_then(|protocols| { - protocols - .split(',') - .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) - }) - .map(Self) - .ok_or(StatusCode::BAD_REQUEST) - } -} - -/// A GraphQL subscription service. -pub struct GraphQLSubscription { - schema: Schema, -} - -impl Clone for GraphQLSubscription -where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - fn clone(&self) -> Self { - Self { - schema: self.schema.clone(), - } - } -} - -impl GraphQLSubscription -where - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - /// Create a GraphQL subscription service. - pub fn new(schema: Schema) -> Self { - Self { schema } - } -} - -impl Service> - for GraphQLSubscription -where - B: HttpBody + Send + 'static, - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - type Response = Response; - type Error = Infallible; - type Future = BoxFuture<'static, Result>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - let schema = self.schema.clone(); - - Box::pin(async move { - let mut parts = RequestParts::new(req); - let protocol = match GraphQLProtocol::from_request(&mut parts).await { - Ok(protocol) => protocol, - Err(err) => return Ok(err.into_response().map(boxed)), - }; - let upgrade = match WebSocketUpgrade::from_request(&mut parts).await { - Ok(protocol) => protocol, - Err(err) => return Ok(err.into_response().map(boxed)), - }; - - let schema = schema.clone(); - - let resp = upgrade - .protocols(ALL_WEBSOCKET_PROTOCOLS) - .on_upgrade(move |stream| GraphQLWebSocket::new(stream, schema, protocol).serve()); - Ok(resp.into_response().map(boxed)) - }) - } -} - -type DefaultOnConnInitType = fn(serde_json::Value) -> Ready>; - -fn default_on_connection_init(_: serde_json::Value) -> Ready> { - futures_util::future::ready(Ok(Data::default())) -} - -/// A Websocket connection for GraphQL subscription. -pub struct GraphQLWebSocket { - sink: Sink, - stream: Stream, - schema: Schema, - data: Data, - on_connection_init: OnConnInit, - protocol: GraphQLProtocol, -} - -impl - GraphQLWebSocket< - SplitSink, - SplitStream, - Query, - Mutation, - Subscription, - DefaultOnConnInitType, - > -where - S: Stream> + Sink, - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - /// Create a [`GraphQLWebSocket`] object. - pub fn new( - stream: S, - schema: Schema, - protocol: GraphQLProtocol, - ) -> Self { - let (sink, stream) = stream.split(); - GraphQLWebSocket::new_with_pair(sink, stream, schema, protocol) - } -} - -impl - GraphQLWebSocket -where - Sink: futures_util::sink::Sink, - Stream: futures_util::stream::Stream>, - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, -{ - /// Create a [`GraphQLWebSocket`] object with sink and stream objects. - pub fn new_with_pair( - sink: Sink, - stream: Stream, - schema: Schema, - protocol: GraphQLProtocol, - ) -> Self { - GraphQLWebSocket { - sink, - stream, - schema, - data: Data::default(), - on_connection_init: default_on_connection_init, - protocol, - } - } -} - -impl - GraphQLWebSocket -where - Sink: futures_util::sink::Sink, - Stream: futures_util::stream::Stream>, - Query: ObjectType + 'static, - Mutation: ObjectType + 'static, - Subscription: SubscriptionType + 'static, - OnConnInit: FnOnce(serde_json::Value) -> OnConnInitFut + Send + 'static, - OnConnInitFut: Future> + Send + 'static, -{ - /// Specify the initial subscription context data, usually you can get - /// something from the incoming request to create it. - #[must_use] - pub fn with_data(self, data: Data) -> Self { - Self { data, ..self } - } - - /// Specify a callback function to be called when the connection is - /// initialized. - /// - /// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`]. - /// The data returned by this callback function will be merged with the data - /// specified by [`with_data`]. - pub fn on_connection_init( - self, - callback: OnConnInit2, - ) -> GraphQLWebSocket - where - OnConnInit2: FnOnce(serde_json::Value) -> Fut + Send + 'static, - Fut: Future> + Send + 'static, - { - GraphQLWebSocket { - sink: self.sink, - stream: self.stream, - schema: self.schema, - data: self.data, - on_connection_init: callback, - protocol: self.protocol, - } - } - - /// Processing subscription requests. - pub async fn serve(self) { - let input = self - .stream - .take_while(|res| future::ready(res.is_ok())) - .map(Result::unwrap) - .filter_map(|msg| { - if let Message::Text(_) | Message::Binary(_) = msg { - future::ready(Some(msg)) - } else { - future::ready(None) - } - }) - .map(Message::into_data); - - let stream = - async_graphql::http::WebSocket::new(self.schema.clone(), input, self.protocol.0) - .connection_data(self.data) - .on_connection_init(self.on_connection_init) - .map(|msg| match msg { - WsMessage::Text(text) => Message::Text(text), - WsMessage::Close(code, status) => Message::Close(Some(CloseFrame { - code, - reason: Cow::from(status), - })), - }); - - let sink = self.sink; - futures_util::pin_mut!(stream, sink); - - while let Some(item) = stream.next().await { - let _ = sink.send(item).await; - } - } -}