async-graphql/src/http/multipart.rs

155 lines
4.7 KiB
Rust
Raw Normal View History

2020-09-10 08:54:38 +00:00
use crate::{ParseRequestError, Request};
use bytes::Bytes;
2020-09-10 07:04:24 +00:00
use futures::io::AsyncRead;
use futures::stream::Stream;
2020-09-10 08:54:38 +00:00
use multer::{Constraints, Multipart, SizeLimit};
2020-09-14 01:46:22 +00:00
use pin_project_lite::pin_project;
2020-09-10 07:04:24 +00:00
use std::collections::HashMap;
2020-09-10 08:54:38 +00:00
use std::io::{self, Seek, SeekFrom, Write};
2020-09-10 07:04:24 +00:00
use std::pin::Pin;
use std::task::{Context, Poll};
2020-09-10 07:04:24 +00:00
/// Options for `receive_multipart`.
///
/// Both limits are optional; the derived `Default` leaves them as `None`
/// (unset). Use the builder-style methods of the same names to set them.
#[derive(Debug, Default, Clone)]
#[non_exhaustive]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
pub struct MultipartOptions {
    /// The maximum file size.
    ///
    /// `None` means no size limit has been configured.
    pub max_file_size: Option<usize>,
    /// The maximum number of files.
    ///
    /// `None` means no file-count limit has been configured.
    pub max_num_files: Option<usize>,
}
impl MultipartOptions {
/// Set maximum file size.
pub fn max_file_size(self, size: usize) -> Self {
MultipartOptions {
max_file_size: Some(size),
..self
}
}
2020-09-10 07:04:24 +00:00
2020-09-11 23:43:50 +00:00
/// Set maximum number of files.
pub fn max_num_files(self, n: usize) -> Self {
MultipartOptions {
max_num_files: Some(n),
..self
}
}
2020-09-10 07:04:24 +00:00
}
/// Receive a multipart request.
2020-09-15 18:32:13 +00:00
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
2020-09-10 07:04:24 +00:00
pub async fn receive_multipart(
2020-09-11 07:54:56 +00:00
body: impl AsyncRead + Send + 'static,
2020-09-10 07:04:24 +00:00
boundary: impl Into<String>,
opts: MultipartOptions,
2020-09-10 08:54:38 +00:00
) -> Result<Request, ParseRequestError> {
2020-09-10 07:04:24 +00:00
let mut multipart = Multipart::new_with_constraints(
ReaderStream::new(body),
2020-09-10 07:04:24 +00:00
boundary,
Constraints::new().size_limit({
let mut limit = SizeLimit::new();
if let (Some(max_file_size), Some(max_num_files)) =
(opts.max_file_size, opts.max_file_size)
{
limit = limit.whole_stream((max_file_size * max_num_files) as u64);
}
if let Some(max_file_size) = opts.max_file_size {
limit = limit.per_field(max_file_size as u64);
}
limit
}),
);
2020-09-10 08:54:38 +00:00
let mut request = None;
2020-09-10 07:04:24 +00:00
let mut map = None;
let mut files = Vec::new();
while let Some(mut field) = multipart.next_field().await? {
match field.name() {
Some("operations") => {
let request_str = field.text().await?;
2020-09-10 08:54:38 +00:00
request = Some(
2020-09-12 16:07:46 +00:00
serde_json::from_str::<Request>(&request_str)
.map_err(ParseRequestError::InvalidRequest)?,
2020-09-10 08:54:38 +00:00
);
2020-09-10 07:04:24 +00:00
}
Some("map") => {
let map_str = field.text().await?;
map = Some(
serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
.map_err(ParseRequestError::InvalidFilesMap)?,
);
}
_ => {
if let Some(name) = field.name().map(ToString::to_string) {
if let Some(filename) = field.file_name().map(ToString::to_string) {
2020-09-10 08:54:38 +00:00
let content_type = field.content_type().map(|mime| mime.to_string());
let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
2020-09-10 07:04:24 +00:00
while let Some(chunk) = field.chunk().await.unwrap() {
file.write(&chunk).map_err(ParseRequestError::Io)?;
}
file.seek(SeekFrom::Start(0))?;
files.push((name, filename, content_type, file));
}
}
}
}
}
2020-09-10 08:54:38 +00:00
let mut request: Request = request.ok_or(ParseRequestError::MissingOperatorsPart)?;
2020-09-10 07:04:24 +00:00
let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
for (name, filename, content_type, file) in files {
if let Some(var_paths) = map.remove(&name) {
for var_path in var_paths {
2020-09-10 08:54:38 +00:00
request.set_upload(
2020-09-10 07:04:24 +00:00
&var_path,
filename.clone(),
content_type.clone(),
file.try_clone().unwrap(),
);
}
}
}
if !map.is_empty() {
return Err(ParseRequestError::MissingFiles);
}
2020-09-10 08:54:38 +00:00
Ok(request)
2020-09-10 07:04:24 +00:00
}
// Adapter that exposes an `AsyncRead` as a `Stream` of `io::Result<Bytes>`
// chunks (see the `Stream` impl for `ReaderStream` in this file).
pin_project! {
struct ReaderStream<T> {
// Fixed 2048-byte scratch buffer that each poll reads into; the filled
// prefix is then copied out into a fresh `Bytes`.
buf: [u8; 2048],
// Marked `#[pin]` so the projection hands out `Pin<&mut T>`, which is what
// `poll_read` is called on.
#[pin]
reader: T,
}
}
impl<T> ReaderStream<T> {
fn new(reader: T) -> Self {
Self {
buf: [0; 2048],
reader,
}
}
}
impl<T: AsyncRead> Stream for ReaderStream<T> {
type Item = io::Result<Bytes>;
fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
let this = self.project();
2020-09-10 07:04:24 +00:00
Poll::Ready(
match futures::ready!(this.reader.poll_read(cx, this.buf)?) {
2020-09-10 07:04:24 +00:00
0 => None,
size => Some(Ok(Bytes::copy_from_slice(&this.buf[..size]))),
2020-09-14 01:46:22 +00:00
},
2020-09-10 07:04:24 +00:00
)
}
2020-09-10 07:04:24 +00:00
}