diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs index 34613787..2188b2fd 100644 --- a/integrations/axum/src/extract.rs +++ b/integrations/axum/src/extract.rs @@ -1,11 +1,13 @@ -use std::fmt::Display; use std::io::ErrorKind; use async_graphql::futures_util::TryStreamExt; use async_graphql::http::MultipartOptions; use async_graphql::ParseRequestError; -use axum::extract::{BodyStream, FromRequest, RequestParts}; -use bytes::Buf; +use axum::{ + extract::{BodyStream, FromRequest, RequestParts}, + BoxError, +}; +use bytes::Bytes; use http::Method; use tokio_util::compat::TokioAsyncReadCompatExt; @@ -30,6 +32,9 @@ pub mod rejection { pub struct GraphQLRejection(pub ParseRequestError); impl IntoResponse for GraphQLRejection { + type Body = axum::body::Body; + type BodyError = ::Error; + fn into_response(self) -> http::Response { todo!() } @@ -46,8 +51,8 @@ pub mod rejection { impl FromRequest for GraphQLRequest where B: http_body::Body + Unpin + Send + Sync + 'static, - B::Data: Buf + Send, - B::Error: Display + Send, + B::Data: Into, + B::Error: Into, { type Rejection = rejection::GraphQLRejection; @@ -76,13 +81,13 @@ impl GraphQLBatchRequest { impl FromRequest for GraphQLBatchRequest where B: http_body::Body + Unpin + Send + Sync + 'static, - B::Data: Buf + Send, - B::Error: Display + Send, + B::Data: Into, + B::Error: Into, { type Rejection = rejection::GraphQLRejection; async fn from_request(req: &mut RequestParts) -> Result { - if let (Some(&Method::GET), Some(uri)) = (req.method(), req.uri()) { + if let (&Method::GET, uri) = (req.method(), req.uri()) { let res = serde_urlencoded::from_str(uri.query().unwrap_or_default()).map_err(|err| { ParseRequestError::Io(std::io::Error::new( ErrorKind::Other, diff --git a/integrations/axum/src/response.rs b/integrations/axum/src/response.rs index c92549fd..d124c358 100644 --- a/integrations/axum/src/response.rs +++ b/integrations/axum/src/response.rs @@ -1,7 +1,7 @@ use std::convert::TryFrom; use axum::body::Body; -use axum::prelude::response::IntoResponse; +use axum::response::IntoResponse; use headers::HeaderName; use http::{HeaderValue, Response}; @@ -24,6 +24,9 @@ impl From for GraphQLResponse { } impl IntoResponse for GraphQLResponse { + type Body = Body; + type BodyError = ::Error; + fn into_response(self) -> Response { let mut resp = Response::new(serde_json::to_string(&self.0).unwrap().into()); resp.headers_mut().insert( diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs index f34f14ac..b4466988 100644 --- a/integrations/axum/src/subscription.rs +++ b/integrations/axum/src/subscription.rs @@ -1,8 +1,9 @@ +use std::borrow::Cow; use std::future::Future; use async_graphql::http::{WebSocketProtocols, WsMessage}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; -use axum::ws::{Message, WebSocket}; +use axum::extract::ws::{CloseFrame, Message, WebSocket}; use futures_util::{future, SinkExt, StreamExt}; use headers::{Header, HeaderName, HeaderValue}; @@ -77,19 +78,22 @@ pub async fn graphql_subscription_with_data .take_while(|res| future::ready(res.is_ok())) .map(Result::unwrap) .filter_map(|msg| { - if msg.is_text() || msg.is_binary() { + if let Message::Text(_) | Message::Binary(_) = msg { future::ready(Some(msg)) } else { future::ready(None) } }) - .map(Message::into_bytes); + .map(Message::into_data); let mut stream = async_graphql::http::WebSocket::with_data(schema, input, initializer, protocol.0).map( |msg| match msg { - WsMessage::Text(text) => Message::text(text), - WsMessage::Close(code, status) => Message::close_with(code, status), + WsMessage::Text(text) => Message::Text(text), + WsMessage::Close(code, status) => Message::Close(Some(CloseFrame { + code, + reason: Cow::from(status), + })), }, );