From 3c0441178061429406ae94adab43834eeea46a32 Mon Sep 17 00:00:00 2001
From: Koxiaet <38139193+Koxiaet@users.noreply.github.com>
Date: Thu, 17 Sep 2020 19:43:03 +0100
Subject: [PATCH] Remove code duplication for batch requests

---
 src/error.rs          |   5 ++
 src/http/batch.rs     | 141 ------------------------------------------
 src/http/mod.rs       |  32 +++++++---
 src/http/multipart.rs |  43 +++++++----
 src/request.rs        |  11 +++-
 5 files changed, 70 insertions(+), 162 deletions(-)
 delete mode 100644 src/http/batch.rs

diff --git a/src/error.rs b/src/error.rs
index 73811d00..51bdf211 100644
--- a/src/error.rs
+++ b/src/error.rs
@@ -315,6 +315,7 @@ impl QueryError {
 
 /// An error parsing the request.
 #[derive(Debug, Error)]
+#[non_exhaustive]
 pub enum ParseRequestError {
     /// An IO error occurred.
     #[error("{0}")]
@@ -353,6 +354,10 @@ pub enum ParseRequestError {
     /// The request's payload is too large, and this server rejected it.
     #[error("Payload too large")]
     PayloadTooLarge,
+
+    /// The request is a batch request, but the server does not support batch requests.
+    #[error("Batch requests are not supported")]
+    UnsupportedBatch,
 }
 
 #[cfg(feature = "multipart")]
diff --git a/src/http/batch.rs b/src/http/batch.rs
deleted file mode 100644
index 156b09de..00000000
--- a/src/http/batch.rs
+++ /dev/null
@@ -1,141 +0,0 @@
-#[cfg(feature = "multipart")]
-use crate::http::{multipart::ReaderStream, MultipartOptions};
-use crate::{BatchRequest, ParseRequestError};
-use futures::{AsyncRead, AsyncReadExt};
-#[cfg(feature = "multipart")]
-use multer::{Constraints, Multipart, SizeLimit};
-#[cfg(feature = "multipart")]
-use std::{
-    collections::HashMap,
-    io::{Seek, SeekFrom, Write},
-};
-
-#[cfg(feature = "multipart")]
-#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
-/// Receive a GraphQL batch request from a content type and body.
-pub async fn receive_batch_body(
-    content_type: Option<impl AsRef<str>>,
-    body: impl AsyncRead + Send + 'static,
-    opts: MultipartOptions,
-) -> Result<BatchRequest, ParseRequestError> {
-    if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
-        receive_batch_multipart(body, boundary, opts).await
-    } else {
-        receive_batch_json(body).await
-    }
-}
-
-/// Receive a GraphQL batch request from a body as JSON.
-pub async fn receive_batch_json(
-    body: impl AsyncRead + Send + 'static,
-) -> Result<BatchRequest, ParseRequestError> {
-    let mut data = Vec::new();
-    futures::pin_mut!(body);
-    body.read_to_end(&mut data)
-        .await
-        .map_err(ParseRequestError::Io)?;
-    Ok(serde_json::from_slice::<BatchRequest>(&data).map_err(ParseRequestError::InvalidRequest)?)
-}
-
-#[cfg(feature = "multipart")]
-async fn receive_batch_multipart(
-    body: impl AsyncRead + Send + 'static,
-    boundary: impl Into<String>,
-    opts: MultipartOptions,
-) -> Result<BatchRequest, ParseRequestError> {
-    let mut multipart = Multipart::new_with_constraints(
-        ReaderStream::new(body),
-        boundary,
-        Constraints::new().size_limit({
-            let mut limit = SizeLimit::new();
-            if let (Some(max_file_size), Some(max_num_files)) =
-                (opts.max_file_size, opts.max_file_size)
-            {
-                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
-            }
-            if let Some(max_file_size) = opts.max_file_size {
-                limit = limit.per_field(max_file_size as u64);
-            }
-            limit
-        }),
-    );
-
-    let mut batch_request = None;
-    let mut map = None;
-    let mut files = Vec::new();
-
-    while let Some(mut field) = multipart.next_field().await? {
-        match field.name() {
-            Some("operations") => {
-                let request_str = field.text().await?;
-                batch_request = Some(
-                    serde_json::from_str::<BatchRequest>(&request_str)
-                        .map_err(ParseRequestError::InvalidRequest)?,
-                );
-            }
-            Some("map") => {
-                let map_str = field.text().await?;
-                map = Some(
-                    serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
-                        .map_err(ParseRequestError::InvalidFilesMap)?,
-                );
-            }
-            _ => {
-                if let Some(name) = field.name().map(ToString::to_string) {
-                    if let Some(filename) = field.file_name().map(ToString::to_string) {
-                        let content_type = field.content_type().map(|mime| mime.to_string());
-                        let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
-                        while let Some(chunk) = field.chunk().await.unwrap() {
-                            file.write(&chunk).map_err(ParseRequestError::Io)?;
-                        }
-                        file.seek(SeekFrom::Start(0))?;
-                        files.push((name, filename, content_type, file));
-                    }
-                }
-            }
-        }
-    }
-
-    let mut batch_request: BatchRequest =
-        batch_request.ok_or(ParseRequestError::MissingOperatorsPart)?;
-    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
-
-    for (name, filename, content_type, file) in files {
-        if let Some(var_paths) = map.remove(&name) {
-            for var_path in var_paths {
-                match &mut batch_request {
-                    BatchRequest::Single(request) => {
-                        request.set_upload(
-                            &var_path,
-                            filename.clone(),
-                            content_type.clone(),
-                            file.try_clone().unwrap(),
-                        );
-                    }
-                    BatchRequest::Batch(requests) => {
-                        let mut s = var_path.splitn(2, '.');
-                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
-                        let path = s.next();
-
-                        if let (Some(idx), Some(path)) = (idx, path) {
-                            if let Some(request) = requests.get_mut(idx) {
-                                request.set_upload(
-                                    path,
-                                    filename.clone(),
-                                    content_type.clone(),
-                                    file.try_clone().unwrap(),
-                                );
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    }
-
-    if !map.is_empty() {
-        return Err(ParseRequestError::MissingFiles);
-    }
-
-    Ok(batch_request)
-}
diff --git a/src/http/mod.rs b/src/http/mod.rs
index 0022a343..1ca3cd44 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -1,22 +1,18 @@
 //! A helper module that supports HTTP
 
-mod batch;
 mod graphiql_source;
 #[cfg(feature = "multipart")]
 mod multipart;
 mod playground_source;
 mod websocket;
 
-#[cfg(feature = "multipart")]
-pub use batch::receive_batch_body;
-pub use batch::receive_batch_json;
 pub use graphiql_source::graphiql_source;
 #[cfg(feature = "multipart")]
 pub use multipart::MultipartOptions;
 pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
 pub use websocket::WebSocket;
 
-use crate::{ParseRequestError, Request};
+use crate::{BatchRequest, ParseRequestError, Request};
 use futures::io::AsyncRead;
 use futures::AsyncReadExt;
 
@@ -28,10 +24,23 @@ pub async fn receive_body(
     body: impl AsyncRead + Send + 'static,
     opts: MultipartOptions,
 ) -> Result<Request, ParseRequestError> {
+    receive_batch_body(content_type, body, opts)
+        .await?
+        .into_single()
+}
+
+#[cfg(feature = "multipart")]
+#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
+/// Receive a GraphQL request from a content type and body.
+pub async fn receive_batch_body(
+    content_type: Option<impl AsRef<str>>,
+    body: impl AsyncRead + Send + 'static,
+    opts: MultipartOptions,
+) -> Result<BatchRequest, ParseRequestError> {
     if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
-        multipart::receive_multipart(body, boundary, opts).await
+        multipart::receive_batch_multipart(body, boundary, opts).await
     } else {
-        receive_json(body).await
+        receive_batch_json(body).await
     }
 }
 
@@ -39,10 +48,17 @@ pub async fn receive_body(
 pub async fn receive_json(
     body: impl AsyncRead + Send + 'static,
 ) -> Result<Request, ParseRequestError> {
+    receive_batch_json(body).await?.into_single()
+}
+
+/// Receive a GraphQL batch request from a body as JSON.
+pub async fn receive_batch_json(
+    body: impl AsyncRead + Send + 'static,
+) -> Result<BatchRequest, ParseRequestError> {
     let mut data = Vec::new();
     futures::pin_mut!(body);
     body.read_to_end(&mut data)
         .await
         .map_err(ParseRequestError::Io)?;
-    Ok(serde_json::from_slice::<Request>(&data).map_err(ParseRequestError::InvalidRequest)?)
+    Ok(serde_json::from_slice::<BatchRequest>(&data).map_err(ParseRequestError::InvalidRequest)?)
 }
diff --git a/src/http/multipart.rs b/src/http/multipart.rs
index 1664aae3..540ffffa 100644
--- a/src/http/multipart.rs
+++ b/src/http/multipart.rs
@@ -1,4 +1,4 @@
-use crate::{ParseRequestError, Request};
+use crate::{BatchRequest, ParseRequestError};
 use bytes::Bytes;
 use futures::io::AsyncRead;
 use futures::stream::Stream;
@@ -38,12 +38,11 @@ impl MultipartOptions {
     }
 }
 
-/// Receive a multipart request.
-pub(crate) async fn receive_multipart(
+pub(super) async fn receive_batch_multipart(
     body: impl AsyncRead + Send + 'static,
     boundary: impl Into<String>,
     opts: MultipartOptions,
-) -> Result<Request, ParseRequestError> {
+) -> Result<BatchRequest, ParseRequestError> {
     let mut multipart = Multipart::new_with_constraints(
         ReaderStream::new(body),
         boundary,
@@ -70,7 +69,7 @@
             Some("operations") => {
                 let request_str = field.text().await?;
                 request = Some(
-                    serde_json::from_str::<Request>(&request_str)
+                    serde_json::from_str::<BatchRequest>(&request_str)
                         .map_err(ParseRequestError::InvalidRequest)?,
                 );
             }
@@ -97,18 +96,38 @@
         }
     }
 
-    let mut request: Request = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
+    let mut request: BatchRequest = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
     let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
 
     for (name, filename, content_type, file) in files {
         if let Some(var_paths) = map.remove(&name) {
             for var_path in var_paths {
-                request.set_upload(
-                    &var_path,
-                    filename.clone(),
-                    content_type.clone(),
-                    file.try_clone().unwrap(),
-                );
+                match &mut request {
+                    BatchRequest::Single(request) => {
+                        request.set_upload(
+                            &var_path,
+                            filename.clone(),
+                            content_type.clone(),
+                            file.try_clone().unwrap(),
+                        );
+                    }
+                    BatchRequest::Batch(requests) => {
+                        let mut s = var_path.splitn(2, '.');
+                        let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
+                        let path = s.next();
+
+                        if let (Some(idx), Some(path)) = (idx, path) {
+                            if let Some(request) = requests.get_mut(idx) {
+                                request.set_upload(
+                                    path,
+                                    filename.clone(),
+                                    content_type.clone(),
+                                    file.try_clone().unwrap(),
+                                );
+                            }
+                        }
+                    }
+                }
             }
         }
     }
diff --git a/src/request.rs b/src/request.rs
index 33d12f42..b8601f2c 100644
--- a/src/request.rs
+++ b/src/request.rs
@@ -1,5 +1,5 @@
 use crate::parser::types::UploadValue;
-use crate::{Data, Value, Variables};
+use crate::{Data, ParseRequestError, Value, Variables};
 use serde::{Deserialize, Deserializer};
 use std::any::Any;
 use std::fs::File;
@@ -101,6 +101,15 @@ pub enum BatchRequest {
     Batch(Vec<Request>),
 }
 
+impl BatchRequest {
+    pub(crate) fn into_single(self) -> Result<Request, ParseRequestError> {
+        match self {
+            Self::Single(req) => Ok(req),
+            Self::Batch(_) => Err(ParseRequestError::UnsupportedBatch),
+        }
+    }
+}
+
 fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
 where
     D: Deserializer<'de>,
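
Reviewer sketch (not part of the patch): after this change `receive_json` and `receive_body` are thin wrappers that parse a `BatchRequest` and call `into_single()`, so a server that has not opted into batching now gets `ParseRequestError::UnsupportedBatch` for an array payload instead of a generic parse failure. The snippet below illustrates that behaviour; the crate paths (`async_graphql::http::*`, `async_graphql::{BatchRequest, ParseRequestError}`), the exact JSON bodies, and the use of `futures::io::Cursor` as the `AsyncRead` body are assumptions made for the example, not something introduced by this diff.

// Illustrative only: exercises the reworked helpers from src/http/mod.rs.
use async_graphql::http::{receive_batch_json, receive_json};
use async_graphql::{BatchRequest, ParseRequestError};
use futures::io::Cursor;

fn main() -> Result<(), ParseRequestError> {
    futures::executor::block_on(async {
        // A single request still goes through `receive_json`, which now
        // delegates to `receive_batch_json` and unwraps the single variant.
        let single = Cursor::new(&br#"{"query": "{ value }"}"#[..]);
        let _request = receive_json(single).await?;

        // An array body parses as a batch through `receive_batch_json`...
        let batch_body = &br#"[{"query": "{ a }"}, {"query": "{ b }"}]"#[..];
        match receive_batch_json(Cursor::new(batch_body)).await? {
            BatchRequest::Batch(requests) => assert_eq!(requests.len(), 2),
            BatchRequest::Single(_) => unreachable!("an array deserializes as a batch"),
        }

        // ...while `receive_json` rejects the same body with the new error
        // variant, since `into_single` only succeeds for `BatchRequest::Single`.
        let err = receive_json(Cursor::new(batch_body)).await.unwrap_err();
        assert!(matches!(err, ParseRequestError::UnsupportedBatch));

        Ok(())
    })
}

For multipart bodies, the consolidated `receive_batch_multipart` keeps the single-request upload mapping unchanged and, for batches, expects each `map` entry to prefix the variable path with the request index before the first `.` (for example `"0.variables.file"` targets `variables.file` of the first request in the batch), which is what the `splitn(2, '.')` hunk implements.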