Remove code duplication for batch requests

This commit is contained in:
Koxiaet 2020-09-17 19:43:03 +01:00
parent 5cf527aeef
commit 3c04411780
5 changed files with 70 additions and 162 deletions

View File

@ -315,6 +315,7 @@ impl QueryError {
/// An error parsing the request.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ParseRequestError { pub enum ParseRequestError {
/// An IO error occurred. /// An IO error occurred.
#[error("{0}")] #[error("{0}")]
@ -353,6 +354,10 @@ pub enum ParseRequestError {
/// The request's payload is too large, and this server rejected it. /// The request's payload is too large, and this server rejected it.
#[error("Payload too large")] #[error("Payload too large")]
PayloadTooLarge, PayloadTooLarge,
/// The request is a batch request, but the server does not support batch requests.
#[error("Batch requests are not supported")]
UnsupportedBatch,
} }
#[cfg(feature = "multipart")] #[cfg(feature = "multipart")]

View File

@ -1,141 +0,0 @@
#[cfg(feature = "multipart")]
use crate::http::{multipart::ReaderStream, MultipartOptions};
use crate::{BatchRequest, ParseRequestError};
use futures::{AsyncRead, AsyncReadExt};
#[cfg(feature = "multipart")]
use multer::{Constraints, Multipart, SizeLimit};
#[cfg(feature = "multipart")]
use std::{
collections::HashMap,
io::{Seek, SeekFrom, Write},
};
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
/// Receive a GraphQL batch request from a content type and body.
///
/// When the content type parses as a multipart boundary, the body is treated
/// as a `multipart/form-data` upload; otherwise it is parsed as plain JSON.
pub async fn receive_batch_body(
    content_type: Option<impl AsRef<str>>,
    body: impl AsyncRead + Send + 'static,
    opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
    match content_type.map(multer::parse_boundary) {
        Some(Ok(boundary)) => receive_batch_multipart(body, boundary, opts).await,
        // No content type, or one without a multipart boundary: fall back to JSON.
        _ => receive_batch_json(body).await,
    }
}
/// Receive a GraphQL batch request from a body as JSON.
///
/// Reads the entire body into memory, then deserializes it as a
/// [`BatchRequest`]. I/O failures map to [`ParseRequestError::Io`],
/// malformed JSON to [`ParseRequestError::InvalidRequest`].
pub async fn receive_batch_json(
    body: impl AsyncRead + Send + 'static,
) -> Result<BatchRequest, ParseRequestError> {
    futures::pin_mut!(body);
    let mut buf = Vec::new();
    body.read_to_end(&mut buf)
        .await
        .map_err(ParseRequestError::Io)?;
    serde_json::from_slice::<BatchRequest>(&buf).map_err(ParseRequestError::InvalidRequest)
}
#[cfg(feature = "multipart")]
/// Parse a `multipart/form-data` body into a [`BatchRequest`], spooling each
/// uploaded file to a temporary file and attaching it to the variable path(s)
/// listed in the request's "map" part.
///
/// Errors: [`ParseRequestError::MissingOperatorsPart`] / `MissingMapPart` when
/// the mandatory parts are absent, `MissingFiles` when the map references a
/// file that was never sent, plus I/O and JSON-decoding failures.
async fn receive_batch_multipart(
    body: impl AsyncRead + Send + 'static,
    boundary: impl Into<String>,
    opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
    let mut multipart = Multipart::new_with_constraints(
        ReaderStream::new(body),
        boundary,
        Constraints::new().size_limit({
            let mut limit = SizeLimit::new();
            // Fix: the second tuple element previously read `opts.max_file_size`
            // again, so `max_num_files` never contributed to the stream limit.
            if let (Some(max_file_size), Some(max_num_files)) =
                (opts.max_file_size, opts.max_num_files)
            {
                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
            }
            if let Some(max_file_size) = opts.max_file_size {
                limit = limit.per_field(max_file_size as u64);
            }
            limit
        }),
    );

    let mut batch_request = None;
    let mut map = None;
    let mut files = Vec::new();

    while let Some(mut field) = multipart.next_field().await? {
        match field.name() {
            // "operations": the JSON-encoded request or list of requests.
            Some("operations") => {
                let request_str = field.text().await?;
                batch_request = Some(
                    serde_json::from_str::<BatchRequest>(&request_str)
                        .map_err(ParseRequestError::InvalidRequest)?,
                );
            }
            // "map": file-part name -> list of variable paths to bind it to.
            Some("map") => {
                let map_str = field.text().await?;
                map = Some(
                    serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
                        .map_err(ParseRequestError::InvalidFilesMap)?,
                );
            }
            // Any other named part carrying a filename is an uploaded file.
            _ => {
                if let Some(name) = field.name().map(ToString::to_string) {
                    if let Some(filename) = field.file_name().map(ToString::to_string) {
                        let content_type = field.content_type().map(|mime| mime.to_string());
                        let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
                        // Fix: propagate stream errors with `?` instead of
                        // panicking via `.unwrap()`, and use `write_all` so a
                        // short write cannot silently truncate the upload.
                        while let Some(chunk) = field.chunk().await? {
                            file.write_all(&chunk).map_err(ParseRequestError::Io)?;
                        }
                        // Rewind so the handler reads the upload from the start.
                        file.seek(SeekFrom::Start(0))?;
                        files.push((name, filename, content_type, file));
                    }
                }
            }
        }
    }

    let mut batch_request: BatchRequest =
        batch_request.ok_or(ParseRequestError::MissingOperatorsPart)?;
    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;

    for (name, filename, content_type, file) in files {
        if let Some(var_paths) = map.remove(&name) {
            for var_path in var_paths {
                match &mut batch_request {
                    BatchRequest::Single(request) => {
                        request.set_upload(
                            &var_path,
                            filename.clone(),
                            content_type.clone(),
                            file.try_clone().unwrap(),
                        );
                    }
                    BatchRequest::Batch(requests) => {
                        // Batched uploads address a request as "<index>.<path>".
                        let mut s = var_path.splitn(2, '.');
                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
                        let path = s.next();
                        if let (Some(idx), Some(path)) = (idx, path) {
                            if let Some(request) = requests.get_mut(idx) {
                                request.set_upload(
                                    path,
                                    filename.clone(),
                                    content_type.clone(),
                                    file.try_clone().unwrap(),
                                );
                            }
                        }
                    }
                }
            }
        }
    }

    // Anything left in the map references a file part that never arrived.
    if !map.is_empty() {
        return Err(ParseRequestError::MissingFiles);
    }

    Ok(batch_request)
}

View File

@ -1,22 +1,18 @@
//! A helper module that supports HTTP //! A helper module that supports HTTP
mod batch;
mod graphiql_source; mod graphiql_source;
#[cfg(feature = "multipart")] #[cfg(feature = "multipart")]
mod multipart; mod multipart;
mod playground_source; mod playground_source;
mod websocket; mod websocket;
#[cfg(feature = "multipart")]
pub use batch::receive_batch_body;
pub use batch::receive_batch_json;
pub use graphiql_source::graphiql_source; pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")] #[cfg(feature = "multipart")]
pub use multipart::MultipartOptions; pub use multipart::MultipartOptions;
pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::WebSocket; pub use websocket::WebSocket;
use crate::{ParseRequestError, Request}; use crate::{BatchRequest, ParseRequestError, Request};
use futures::io::AsyncRead; use futures::io::AsyncRead;
use futures::AsyncReadExt; use futures::AsyncReadExt;
@ -28,10 +24,23 @@ pub async fn receive_body(
body: impl AsyncRead + Send + 'static, body: impl AsyncRead + Send + 'static,
opts: MultipartOptions, opts: MultipartOptions,
) -> Result<Request, ParseRequestError> { ) -> Result<Request, ParseRequestError> {
receive_batch_body(content_type, body, opts)
.await?
.into_single()
}
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
/// Receive a GraphQL batch request from a content type and body.
pub async fn receive_batch_body(
content_type: Option<impl AsRef<str>>,
body: impl AsyncRead + Send + 'static,
opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) { if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
multipart::receive_multipart(body, boundary, opts).await multipart::receive_batch_multipart(body, boundary, opts).await
} else { } else {
receive_json(body).await receive_batch_json(body).await
} }
} }
@ -39,10 +48,17 @@ pub async fn receive_body(
pub async fn receive_json( pub async fn receive_json(
body: impl AsyncRead + Send + 'static, body: impl AsyncRead + Send + 'static,
) -> Result<Request, ParseRequestError> { ) -> Result<Request, ParseRequestError> {
receive_batch_json(body).await?.into_single()
}
/// Receive a GraphQL batch request from a body as JSON.
pub async fn receive_batch_json(
body: impl AsyncRead + Send + 'static,
) -> Result<BatchRequest, ParseRequestError> {
let mut data = Vec::new(); let mut data = Vec::new();
futures::pin_mut!(body); futures::pin_mut!(body);
body.read_to_end(&mut data) body.read_to_end(&mut data)
.await .await
.map_err(ParseRequestError::Io)?; .map_err(ParseRequestError::Io)?;
Ok(serde_json::from_slice::<Request>(&data).map_err(ParseRequestError::InvalidRequest)?) Ok(serde_json::from_slice::<BatchRequest>(&data).map_err(ParseRequestError::InvalidRequest)?)
} }

View File

@ -1,4 +1,4 @@
use crate::{ParseRequestError, Request}; use crate::{BatchRequest, ParseRequestError};
use bytes::Bytes; use bytes::Bytes;
use futures::io::AsyncRead; use futures::io::AsyncRead;
use futures::stream::Stream; use futures::stream::Stream;
@ -38,12 +38,11 @@ impl MultipartOptions {
} }
} }
/// Receive a multipart request. pub(super) async fn receive_batch_multipart(
pub(crate) async fn receive_multipart(
body: impl AsyncRead + Send + 'static, body: impl AsyncRead + Send + 'static,
boundary: impl Into<String>, boundary: impl Into<String>,
opts: MultipartOptions, opts: MultipartOptions,
) -> Result<Request, ParseRequestError> { ) -> Result<BatchRequest, ParseRequestError> {
let mut multipart = Multipart::new_with_constraints( let mut multipart = Multipart::new_with_constraints(
ReaderStream::new(body), ReaderStream::new(body),
boundary, boundary,
@ -70,7 +69,7 @@ pub(crate) async fn receive_multipart(
Some("operations") => { Some("operations") => {
let request_str = field.text().await?; let request_str = field.text().await?;
request = Some( request = Some(
serde_json::from_str::<Request>(&request_str) serde_json::from_str::<BatchRequest>(&request_str)
.map_err(ParseRequestError::InvalidRequest)?, .map_err(ParseRequestError::InvalidRequest)?,
); );
} }
@ -97,18 +96,38 @@ pub(crate) async fn receive_multipart(
} }
} }
let mut request: Request = request.ok_or(ParseRequestError::MissingOperatorsPart)?; let mut request: BatchRequest = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?; let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
for (name, filename, content_type, file) in files { for (name, filename, content_type, file) in files {
if let Some(var_paths) = map.remove(&name) { if let Some(var_paths) = map.remove(&name) {
for var_path in var_paths { for var_path in var_paths {
request.set_upload( match &mut request {
&var_path, BatchRequest::Single(request) => {
filename.clone(), request.set_upload(
content_type.clone(), &var_path,
file.try_clone().unwrap(), filename.clone(),
); content_type.clone(),
file.try_clone().unwrap(),
);
}
BatchRequest::Batch(requests) => {
let mut s = var_path.splitn(2, '.');
let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
let path = s.next();
if let (Some(idx), Some(path)) = (idx, path) {
if let Some(request) = requests.get_mut(idx) {
request.set_upload(
path,
filename.clone(),
content_type.clone(),
file.try_clone().unwrap(),
);
}
}
}
}
} }
} }
} }

View File

@ -1,5 +1,5 @@
use crate::parser::types::UploadValue; use crate::parser::types::UploadValue;
use crate::{Data, Value, Variables}; use crate::{Data, ParseRequestError, Value, Variables};
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use std::any::Any; use std::any::Any;
use std::fs::File; use std::fs::File;
@ -101,6 +101,15 @@ pub enum BatchRequest {
Batch(Vec<Request>), Batch(Vec<Request>),
} }
impl BatchRequest {
pub(crate) fn into_single(self) -> Result<Request, ParseRequestError> {
match self {
Self::Single(req) => Ok(req),
Self::Batch(_) => Err(ParseRequestError::UnsupportedBatch),
}
}
}
fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error> fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
where where
D: Deserializer<'de>, D: Deserializer<'de>,