diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs index e5e54201..2ec55254 100644 --- a/integrations/axum/src/extract.rs +++ b/integrations/axum/src/extract.rs @@ -1,19 +1,23 @@ -use std::io::ErrorKind; +use std::{io::ErrorKind, marker::PhantomData}; use async_graphql::{futures_util::TryStreamExt, http::MultipartOptions, ParseRequestError}; use axum::{ extract::{BodyStream, FromRequest, RequestParts}, http, http::Method, + response::IntoResponse, BoxError, }; use bytes::Bytes; use tokio_util::compat::TokioAsyncReadCompatExt; /// Extractor for GraphQL request. -pub struct GraphQLRequest(pub async_graphql::Request); +pub struct GraphQLRequest( + pub async_graphql::Request, + PhantomData, +); -impl GraphQLRequest { +impl GraphQLRequest { /// Unwraps the value to `async_graphql::Request`. #[must_use] pub fn into_inner(self) -> async_graphql::Request { @@ -57,28 +61,33 @@ pub mod rejection { } #[async_trait::async_trait] -impl FromRequest for GraphQLRequest +impl FromRequest for GraphQLRequest where B: http_body::Body + Unpin + Send + Sync + 'static, B::Data: Into, B::Error: Into, + R: IntoResponse + From, { - type Rejection = rejection::GraphQLRejection; + type Rejection = R; async fn from_request(req: &mut RequestParts) -> Result { Ok(GraphQLRequest( - GraphQLBatchRequest::from_request(req) + GraphQLBatchRequest::::from_request(req) .await? .0 .into_single()?, + PhantomData, )) } } /// Extractor for GraphQL batch request. -pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest); +pub struct GraphQLBatchRequest( + pub async_graphql::BatchRequest, + PhantomData, +); -impl GraphQLBatchRequest { +impl GraphQLBatchRequest { /// Unwraps the value to `async_graphql::BatchRequest`. #[must_use] pub fn into_inner(self) -> async_graphql::BatchRequest { @@ -87,13 +96,14 @@ impl GraphQLBatchRequest { } #[async_trait::async_trait] -impl FromRequest for GraphQLBatchRequest +impl FromRequest for GraphQLBatchRequest where B: http_body::Body + Unpin + Send + Sync + 'static, B::Data: Into, B::Error: Into, + R: IntoResponse + From, { - type Rejection = rejection::GraphQLRejection; + type Rejection = R; async fn from_request(req: &mut RequestParts) -> Result { if let (&Method::GET, uri) = (req.method(), req.uri()) { @@ -103,7 +113,7 @@ where format!("failed to parse graphql request from uri query: {}", err), )) }); - Ok(Self(async_graphql::BatchRequest::Single(res?))) + Ok(Self(async_graphql::BatchRequest::Single(res?), PhantomData)) } else { let content_type = req .headers() @@ -127,6 +137,7 @@ where MultipartOptions::default(), ) .await?, + PhantomData, )) } }