diff --git a/src/http/mod.rs b/src/http/mod.rs
index 533b5704..7c122bf4 100644
--- a/src/http/mod.rs
+++ b/src/http/mod.rs
@@ -3,14 +3,17 @@
 mod graphiql_source;
 mod playground_source;
 mod stream_body;
+mod multipart;
 
 use itertools::Itertools;
 
 pub use graphiql_source::graphiql_source;
 pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
 pub use stream_body::StreamBody;
+pub use multipart::{receive_multipart, MultipartOptions};
 
-use crate::{Error, GQLQueryResponse, ParseRequestError, Pos, QueryError, Result, Variables};
+use crate::{GQLQuery, ParseRequestError, Pos, QueryError, Variables};
+use futures::io::{AsyncRead, AsyncReadExt};
 use serde::ser::{SerializeMap, SerializeSeq};
 use serde::{Deserialize, Serialize, Serializer};
 
@@ -28,6 +31,24 @@ pub struct GQLRequest {
     pub variables: Option<Variables>,
 }
 
+/// Receive a GraphQL request from a content type and body.
+pub async fn receive_body(
+    content_type: Option<impl AsRef<str>>,
+    body: impl AsyncRead + Unpin + Send + 'static,
+    opts: MultipartOptions,
+) -> Result<GQLQuery, ParseRequestError> {
+    if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
+        receive_multipart(body, boundary, opts).await
+    } else {
+        futures::pin_mut!(body);
+        let mut data = Vec::new();
+        body.read_to_end(&mut data).await.map_err(ParseRequestError::Io)?;
+        Ok(GQLQuery::new_with_http_request(
+            serde_json::from_slice(&data).map_err(ParseRequestError::InvalidRequest)?,
+        ))
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
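Below is a minimal sketch of driving the new `receive_body` entry point with an in-memory body. The `async_graphql::http` import path is an assumption based on the re-exports above, and `futures::executor::block_on` stands in for whatever executor the host application uses.

use async_graphql::http::{receive_body, MultipartOptions}; // assumed re-export path
use futures::executor::block_on;

fn main() {
    // A plain JSON request body; `&[u8]` implements `futures::io::AsyncRead`.
    let body: &[u8] = br#"{"query":"{ hello }"}"#;

    // With no multipart content type, the plain-JSON branch is taken.
    let query = block_on(receive_body(
        None::<&str>, // no Content-Type header
        body,
        MultipartOptions::default(),
    ));
    assert!(query.is_ok());
}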
diff --git a/src/http/multipart.rs b/src/http/multipart.rs
new file mode 100644
index 00000000..796037cb
--- /dev/null
+++ b/src/http/multipart.rs
@@ -0,0 +1,119 @@
+use futures::stream::{self, Stream};
+use futures::io::AsyncRead;
+use crate::{GQLQuery, ParseRequestError};
+use std::collections::HashMap;
+use multer::{Multipart, Constraints, SizeLimit};
+use std::io::{self, Seek, SeekFrom, Write};
+use std::task::Poll;
+use bytes::Bytes;
+use std::pin::Pin;
+
+/// Options for `receive_multipart`.
+#[derive(Default, Clone)]
+#[non_exhaustive]
+pub struct MultipartOptions {
+    /// Maximum file size.
+    pub max_file_size: Option<usize>,
+
+    /// Maximum number of files.
+    pub max_num_files: Option<usize>,
+}
+
+/// Receive a multipart request.
+pub async fn receive_multipart(
+    body: impl AsyncRead + Unpin + Send + 'static,
+    boundary: impl Into<String>,
+    opts: MultipartOptions,
+) -> Result<GQLQuery, ParseRequestError> {
+    let mut multipart = Multipart::new_with_constraints(
+        reader_stream(body),
+        boundary,
+        Constraints::new().size_limit({
+            let mut limit = SizeLimit::new();
+            if let (Some(max_file_size), Some(max_num_files)) =
+                (opts.max_file_size, opts.max_num_files)
+            {
+                limit = limit.whole_stream((max_file_size * max_num_files) as u64);
+            }
+            if let Some(max_file_size) = opts.max_file_size {
+                limit = limit.per_field(max_file_size as u64);
+            }
+            limit
+        }),
+    );
+
+    let mut query = None;
+    let mut map = None;
+    let mut files = Vec::new();
+
+    while let Some(mut field) = multipart.next_field().await? {
+        match field.name() {
+            Some("operations") => {
+                let request_str = field.text().await?;
+                query = Some(GQLQuery::new_with_http_request(
+                    serde_json::from_str(&request_str)
+                        .map_err(ParseRequestError::InvalidRequest)?,
+                ));
+            }
+            Some("map") => {
+                let map_str = field.text().await?;
+                map = Some(
+                    serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
+                        .map_err(ParseRequestError::InvalidFilesMap)?,
+                );
+            }
+            _ => {
+                if let Some(name) = field.name().map(ToString::to_string) {
+                    if let Some(filename) = field.file_name().map(ToString::to_string) {
+                        let content_type =
+                            field.content_type().map(|mime| mime.to_string());
+                        let mut file =
+                            tempfile::tempfile().map_err(ParseRequestError::Io)?;
+                        while let Some(chunk) = field.chunk().await? {
+                            file.write_all(&chunk).map_err(ParseRequestError::Io)?;
+                        }
+                        file.seek(SeekFrom::Start(0)).map_err(ParseRequestError::Io)?;
+                        files.push((name, filename, content_type, file));
+                    }
+                }
+            }
+        }
+    }
+
+    let mut query = query.ok_or(ParseRequestError::MissingOperatorsPart)?;
+    let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
+
+    for (name, filename, content_type, file) in files {
+        if let Some(var_paths) = map.remove(&name) {
+            for var_path in var_paths {
+                query.set_upload(
+                    &var_path,
+                    filename.clone(),
+                    content_type.clone(),
+                    file.try_clone().unwrap(),
+                );
+            }
+        }
+    }
+
+    if !map.is_empty() {
+        return Err(ParseRequestError::MissingFiles);
+    }
+
+    Ok(query)
+}
+
+fn reader_stream(
+    mut reader: impl AsyncRead + Unpin + Send + 'static,
+) -> impl Stream<Item = io::Result<Bytes>> + Unpin + Send + 'static {
+    let mut buf = [0u8; 2048];
+
+    stream::poll_fn(move |cx| {
+        Poll::Ready(
+            match futures::ready!(Pin::new(&mut reader).poll_read(cx, &mut buf)?) {
+                0 => None,
+                size => Some(Ok(Bytes::copy_from_slice(&buf[..size]))),
+            },
+        )
+    })
+}
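The loop above implements the GraphQL multipart request spec: an `operations` part carries the request document, and a `map` part routes each additional named part into upload variable paths. A sketch of the wire format it accepts follows; the import path is assumed as before, and the mutation and the `variables.file` path are purely illustrative.

use async_graphql::http::{receive_multipart, MultipartOptions}; // assumed re-export path
use futures::executor::block_on;

fn main() {
    // A hand-built multipart body with boundary "X": operations, then map,
    // then one file part named "0" that the map routes to `variables.file`.
    let body: &[u8] = concat!(
        "--X\r\n",
        "Content-Disposition: form-data; name=\"operations\"\r\n\r\n",
        r#"{"query":"mutation ($file: Upload!) { upload(file: $file) }","variables":{"file":null}}"#,
        "\r\n--X\r\n",
        "Content-Disposition: form-data; name=\"map\"\r\n\r\n",
        r#"{"0":["variables.file"]}"#,
        "\r\n--X\r\n",
        "Content-Disposition: form-data; name=\"0\"; filename=\"a.txt\"\r\n",
        "Content-Type: text/plain\r\n\r\n",
        "hello\r\n",
        "--X--\r\n",
    )
    .as_bytes();

    let query = block_on(receive_multipart(body, "X", MultipartOptions::default()));
    assert!(query.is_ok());
}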
diff --git a/src/query.rs b/src/query.rs
index 240b6530..7319f4c3 100644
--- a/src/query.rs
+++ b/src/query.rs
@@ -20,16 +20,6 @@ use std::io::{Seek, SeekFrom, Write};
 use std::sync::atomic::AtomicUsize;
 use std::sync::Arc;
 
-/// Options for `GQLQuery::receive_multipart`
-#[derive(Default, Clone)]
-pub struct ReceiveMultipartOptions {
-    /// Maximum file size.
-    pub max_file_size: Option<usize>,
-
-    /// Maximum number of files.
-    pub max_num_files: Option<usize>,
-}
-
 pub struct GQLQuery {
     pub(crate) query: String,
     pub(crate) operation_name: Option<String>,
@@ -63,100 +53,6 @@ impl GQLQuery {
         }
     }
 
-    pub async fn receive_request(
-        content_type: Option<impl AsRef<str>>,
-        mut body: impl AsyncRead,
-        opts: ReceiveMultipartOptions,
-    ) -> std::result::Result<Self, ParseRequestError> {
-        if let Some(boundary) = content_type.and_then(|ct| multer::parse_boundary(ct).ok()) {
-            // multipart
-            let mut multipart = Multipart::new_with_constraints(
-                reader_stream(body),
-                boundary,
-                Constraints::new().size_limit({
-                    let mut limit = SizeLimit::new();
-                    if let (Some(max_file_size), Some(max_num_files)) =
-                        (opts.max_file_size, opts.max_file_size)
-                    {
-                        limit = limit.whole_stream((max_file_size * max_num_files) as u64);
-                    }
-                    if let Some(max_file_size) = opts.max_file_size {
-                        limit = limit.per_field(max_file_size as u64);
-                    }
-                    limit
-                }),
-            );
-
-            let mut query = None;
-            let mut map = None;
-            let mut files = Vec::new();
-
-            while let Some(mut field) = multipart.next_field().await? {
-                match field.name() {
-                    Some("operations") => {
-                        let request_str = field.text().await?;
-                        query = Some(Self::new_with_http_request(
-                            serde_json::from_str(&request_str)
-                                .map_err(ParseRequestError::InvalidRequest)?,
-                        ));
-                    }
-                    Some("map") => {
-                        let map_str = field.text().await?;
-                        map = Some(
-                            serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
-                                .map_err(ParseRequestError::InvalidFilesMap)?,
-                        );
-                    }
-                    _ => {
-                        if let Some(name) = field.name().map(ToString::to_string) {
-                            if let Some(filename) = field.file_name().map(ToString::to_string) {
-                                let content_type =
-                                    field.content_type().map(|mime| mime.to_string());
-                                let mut file =
-                                    tempfile::tempfile().map_err(ParseRequestError::Io)?;
-                                while let Some(chunk) = field.chunk().await.unwrap() {
-                                    file.write(&chunk).map_err(ParseRequestError::Io)?;
-                                }
-                                file.seek(SeekFrom::Start(0))?;
-                                files.push((name, filename, content_type, file));
-                            }
-                        }
-                    }
-                }
-            }
-
-            let mut query = query.ok_or(ParseRequestError::MissingOperatorsPart)?;
-            let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
-
-            for (name, filename, content_type, file) in files {
-                if let Some(var_paths) = map.remove(&name) {
-                    for var_path in var_paths {
-                        query.set_upload(
-                            &var_path,
-                            filename.clone(),
-                            content_type.clone(),
-                            file.try_clone().unwrap(),
-                        );
-                    }
-                }
-            }
-
-            if !map.is_empty() {
-                return Err(ParseRequestError::MissingFiles);
-            }
-
-            Ok(query)
-        } else {
-            let mut data = Vec::new();
-            body.read_to_end(&mut data)
-                .await
-                .map_err(ParseRequestError::Io)?;
-            Ok(Self::new_with_http_request(
-                serde_json::from_slice(&data).map_err(ParseRequestError::InvalidRequest)?,
-            ))
-        }
-    }
-
     /// Specify the operation name.
     pub fn operation_name<T: Into<String>>(self, name: T) -> Self {
         Self {
@@ -258,18 +154,3 @@ impl From for GQLQueryResponse {
         }
     }
 }
-
-fn reader_stream(
-    mut reader: impl AsyncRead + Unpin + Send + 'static,
-) -> impl Stream<Item = io::Result<Bytes>> + Unpin + Send + 'static {
-    let mut buf = [0u8; 2048];
-
-    stream::poll_fn(move |cx| {
-        Poll::Ready(
-            match futures::ready!(Pin::new(&mut reader).poll_read(cx, &mut buf)?) {
-                0 => None,
-                size => Some(Ok(Bytes::copy_from_slice(&buf[..size]))),
-            },
-        )
-    })
-}
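The `reader_stream` helper that moves out of query.rs above adapts any `AsyncRead` into the `Stream` of `Bytes` chunks that multer consumes, polling reads into a fixed 2 KiB scratch buffer until EOF. An isolated, runnable sketch of the same adapter, using only `futures` and `bytes`:

use bytes::Bytes;
use futures::executor::block_on_stream;
use futures::io::AsyncRead;
use futures::stream::{self, Stream};
use std::io;
use std::pin::Pin;
use std::task::Poll;

// Same pattern as the helper in the diff: poll the reader into a scratch
// buffer and yield each non-empty read as an owned `Bytes` chunk.
fn reader_stream(
    mut reader: impl AsyncRead + Unpin + Send + 'static,
) -> impl Stream<Item = io::Result<Bytes>> + Unpin + Send + 'static {
    let mut buf = [0u8; 2048];
    stream::poll_fn(move |cx| {
        Poll::Ready(
            // `?` surfaces a read error as the next stream item; a zero-byte
            // read signals EOF and ends the stream with `None`.
            match futures::ready!(Pin::new(&mut reader).poll_read(cx, &mut buf)?) {
                0 => None,
                size => Some(Ok(Bytes::copy_from_slice(&buf[..size]))),
            },
        )
    })
}

fn main() {
    // A short in-memory reader comes back as a single chunk.
    let chunks: Vec<_> = block_on_stream(reader_stream(&b"hello"[..])).collect();
    assert_eq!(chunks.len(), 1);
    assert_eq!(chunks[0].as_ref().unwrap(), &Bytes::from_static(b"hello"));
}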