2021-08-01 09:44:28 +00:00
|
|
|
use std::io::ErrorKind;
|
|
|
|
|
|
|
|
use async_graphql::futures_util::TryStreamExt;
|
|
|
|
use async_graphql::http::MultipartOptions;
|
|
|
|
use async_graphql::ParseRequestError;
|
2021-11-12 04:58:32 +00:00
|
|
|
use axum::http::Method;
|
2021-08-21 21:43:20 +00:00
|
|
|
use axum::{
|
|
|
|
extract::{BodyStream, FromRequest, RequestParts},
|
2021-11-12 04:58:32 +00:00
|
|
|
http, BoxError,
|
2021-08-21 21:43:20 +00:00
|
|
|
};
|
|
|
|
use bytes::Bytes;
|
2021-08-01 09:44:28 +00:00
|
|
|
use tokio_util::compat::TokioAsyncReadCompatExt;
|
|
|
|
|
|
|
|
/// Extractor for GraphQL request.
|
|
|
|
pub struct GraphQLRequest(pub async_graphql::Request);
|
|
|
|
|
|
|
|
impl GraphQLRequest {
|
|
|
|
/// Unwraps the value to `async_graphql::Request`.
|
|
|
|
#[must_use]
|
|
|
|
pub fn into_inner(self) -> async_graphql::Request {
|
|
|
|
self.0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Rejection response types.
|
|
|
|
pub mod rejection {
|
|
|
|
use async_graphql::ParseRequestError;
|
2021-12-06 04:57:15 +00:00
|
|
|
use axum::body::{boxed, Body, BoxBody};
|
2021-11-12 04:58:32 +00:00
|
|
|
use axum::http;
|
|
|
|
use axum::http::StatusCode;
|
2021-08-01 09:44:28 +00:00
|
|
|
use axum::response::IntoResponse;
|
|
|
|
|
|
|
|
/// Rejection used for [`GraphQLRequest`](GraphQLRequest).
|
|
|
|
pub struct GraphQLRejection(pub ParseRequestError);
|
|
|
|
|
|
|
|
impl IntoResponse for GraphQLRejection {
|
2021-12-06 04:57:15 +00:00
|
|
|
fn into_response(self) -> http::Response<BoxBody> {
|
2021-09-10 02:23:16 +00:00
|
|
|
match self.0 {
|
|
|
|
ParseRequestError::PayloadTooLarge => http::Response::builder()
|
|
|
|
.status(StatusCode::PAYLOAD_TOO_LARGE)
|
2021-12-06 04:57:15 +00:00
|
|
|
.body(boxed(Body::empty()))
|
2021-09-10 02:23:16 +00:00
|
|
|
.unwrap(),
|
|
|
|
bad_request => http::Response::builder()
|
|
|
|
.status(StatusCode::BAD_REQUEST)
|
2021-12-06 04:57:15 +00:00
|
|
|
.body(boxed(Body::from(format!("{:?}", bad_request))))
|
2021-09-10 02:23:16 +00:00
|
|
|
.unwrap(),
|
|
|
|
}
|
2021-08-01 09:44:28 +00:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
impl From<ParseRequestError> for GraphQLRejection {
|
|
|
|
fn from(err: ParseRequestError) -> Self {
|
|
|
|
GraphQLRejection(err)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait::async_trait]
|
|
|
|
impl<B> FromRequest<B> for GraphQLRequest
|
|
|
|
where
|
|
|
|
B: http_body::Body + Unpin + Send + Sync + 'static,
|
2021-08-21 21:43:20 +00:00
|
|
|
B::Data: Into<Bytes>,
|
|
|
|
B::Error: Into<BoxError>,
|
2021-08-01 09:44:28 +00:00
|
|
|
{
|
|
|
|
type Rejection = rejection::GraphQLRejection;
|
|
|
|
|
|
|
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
|
|
|
Ok(GraphQLRequest(
|
|
|
|
GraphQLBatchRequest::from_request(req)
|
|
|
|
.await?
|
|
|
|
.0
|
|
|
|
.into_single()?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
/// Extractor for GraphQL batch request.
|
|
|
|
pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
|
|
|
|
|
|
|
|
impl GraphQLBatchRequest {
|
|
|
|
/// Unwraps the value to `async_graphql::BatchRequest`.
|
|
|
|
#[must_use]
|
|
|
|
pub fn into_inner(self) -> async_graphql::BatchRequest {
|
|
|
|
self.0
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
#[async_trait::async_trait]
|
|
|
|
impl<B> FromRequest<B> for GraphQLBatchRequest
|
|
|
|
where
|
|
|
|
B: http_body::Body + Unpin + Send + Sync + 'static,
|
2021-08-21 21:43:20 +00:00
|
|
|
B::Data: Into<Bytes>,
|
|
|
|
B::Error: Into<BoxError>,
|
2021-08-01 09:44:28 +00:00
|
|
|
{
|
|
|
|
type Rejection = rejection::GraphQLRejection;
|
|
|
|
|
|
|
|
async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
|
2021-08-21 21:43:20 +00:00
|
|
|
if let (&Method::GET, uri) = (req.method(), req.uri()) {
|
2021-08-01 09:44:28 +00:00
|
|
|
let res = serde_urlencoded::from_str(uri.query().unwrap_or_default()).map_err(|err| {
|
|
|
|
ParseRequestError::Io(std::io::Error::new(
|
|
|
|
ErrorKind::Other,
|
|
|
|
format!("failed to parse graphql request from uri query: {}", err),
|
|
|
|
))
|
|
|
|
});
|
|
|
|
Ok(Self(async_graphql::BatchRequest::Single(res?)))
|
|
|
|
} else {
|
|
|
|
let content_type = req
|
|
|
|
.headers()
|
|
|
|
.and_then(|headers| headers.get(http::header::CONTENT_TYPE))
|
|
|
|
.and_then(|value| value.to_str().ok())
|
|
|
|
.map(ToString::to_string);
|
|
|
|
let body_stream = BodyStream::from_request(req)
|
|
|
|
.await
|
|
|
|
.map_err(|_| {
|
|
|
|
ParseRequestError::Io(std::io::Error::new(
|
|
|
|
ErrorKind::Other,
|
|
|
|
"body has been taken by another extractor".to_string(),
|
|
|
|
))
|
|
|
|
})?
|
|
|
|
.map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
|
|
|
|
let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
|
|
|
|
Ok(Self(
|
|
|
|
async_graphql::http::receive_batch_body(
|
|
|
|
content_type,
|
|
|
|
body_reader,
|
|
|
|
MultipartOptions::default(),
|
|
|
|
)
|
|
|
|
.await?,
|
|
|
|
))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|