Replace multipart parser with multer.
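Parse multipart request bodies with the multer crate instead of the hand-written parser in http/multipart.rs and http/token_reader.rs, both of which are deleted. The body is wrapped into a stream of Bytes (see reader_stream below) before being handed to multer, which is why the Body bound gains 'static; the max_file_size and max_num_files options are now enforced through multer's Constraints and SizeLimit, so the dedicated TooManyFiles and TooLarge errors are folded into InvalidMultipart(String).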

sunli 2020-05-17 16:26:36 +08:00
parent e2f6475aa6
commit 44b1eb4fb4
9 changed files with 112 additions and 580 deletions

View File

@@ -43,6 +43,8 @@ fnv = "1.0.6"
regex = "1.3.5"
tracing = "0.1.13"
indexmap = "1.3.2"
async-stream = "0.2.1"
multer = "1.0.3"
bson = { version = "0.14.1", optional = true }
uuid = { version = "0.8.1", optional = true }
url = { version = "2.1.1", optional = true }

View File

@@ -7,11 +7,10 @@ mod subscription;
use actix_web::dev::{Payload, PayloadStream};
use actix_web::{http, web, Error, FromRequest, HttpRequest};
use async_graphql::http::StreamBody;
use async_graphql::{IntoQueryBuilder, IntoQueryBuilderOpts, ParseRequestError, QueryBuilder};
use async_graphql::{IntoQueryBuilder, IntoQueryBuilderOpts, QueryBuilder};
use futures::channel::mpsc;
use futures::{Future, SinkExt, StreamExt, TryFutureExt};
use std::pin::Pin;
pub use subscription::WSSubscription;
/// Extractor for GraphQL request
@@ -56,12 +55,7 @@ impl FromRequest for GQLRequest {
(content_type, StreamBody::new(rx))
.into_query_builder_opts(&config)
.map_ok(GQLRequest)
.map_err(|err| match err {
ParseRequestError::TooManyFiles | ParseRequestError::TooLarge => {
actix_web::error::ErrorPayloadTooLarge(err)
}
_ => actix_web::error::ErrorBadRequest(err),
})
.map_err(actix_web::error::ErrorBadRequest)
.await
})
}

View File

@@ -25,12 +25,11 @@ impl GQLRequestExt for Request {
&self,
opts: IntoQueryBuilderOpts,
) -> Result<QueryBuilder, ParseRequestError> {
let body = self.body().as_ref();
let ct = self
.headers()
.get("content-type")
.and_then(|value| value.to_str().ok());
(ct, AllowStdIo::new(Cursor::new(body)))
(ct, AllowStdIo::new(Cursor::new(self.body().to_vec())))
.into_query_builder_opts(&opts)
.await
}

View File

@@ -207,8 +207,8 @@ fn upload() -> Result<()> {
};
let mut content = String::new();
file.into_read().read_to_string(&mut content).ok();
assert_eq!(content, "test\r\n".to_owned());
file.into_read().read_to_string(&mut content).unwrap();
assert_eq!(content, "test".to_owned());
file_info
}

View File

@@ -303,7 +303,7 @@ pub enum ParseRequestError {
InvalidFilesMap(serde_json::Error),
#[error("Invalid multipart data")]
InvalidMultipart,
InvalidMultipart(String),
#[error("Missing \"operators\" part")]
MissingOperatorsPart,
@@ -316,12 +316,6 @@ pub enum ParseRequestError {
#[error("Missing files")]
MissingFiles,
#[error("Too many files")]
TooManyFiles,
#[error("The file size is too large")]
TooLarge,
}
#[derive(Debug, PartialEq)]

View File

@@ -1,16 +1,18 @@
use crate::http::multipart::{Multipart, PartData};
use crate::http::GQLRequest;
use crate::query::{IntoQueryBuilder, IntoQueryBuilderOpts};
use crate::{ParseRequestError, QueryBuilder};
use futures::{AsyncRead, AsyncReadExt};
use bytes::Bytes;
use futures::{AsyncRead, AsyncReadExt, Stream};
use mime::Mime;
use multer::{Constraints, Multipart, SizeLimit};
use std::collections::HashMap;
use std::io::{Seek, SeekFrom, Write};
#[async_trait::async_trait]
impl<CT, Body> IntoQueryBuilder for (Option<CT>, Body)
where
CT: AsRef<str> + Send,
Body: AsyncRead + Send + Unpin,
Body: AsyncRead + Send + Unpin + 'static,
{
async fn into_query_builder_opts(
mut self,
@@ -29,55 +31,101 @@ where
})
{
// multipart
let mut multipart = Multipart::parse(
self.1,
boundary.as_str(),
opts.max_file_size,
opts.max_num_files,
)
.await?;
let gql_request: GQLRequest = {
let part = multipart
.remove("operations")
.ok_or_else(|| ParseRequestError::MissingOperatorsPart)?;
let reader = part.create_reader()?;
serde_json::from_reader(reader).map_err(ParseRequestError::InvalidRequest)?
};
let mut map: HashMap<String, Vec<String>> = {
let part = multipart
.remove("map")
.ok_or_else(|| ParseRequestError::MissingMapPart)?;
let reader = part.create_reader()?;
serde_json::from_reader(reader).map_err(ParseRequestError::InvalidFilesMap)?
};
let stream = reader_stream(self.1);
let mut builder = gql_request.into_query_builder().await?;
let mut multipart = Multipart::new_with_constraints(
stream,
boundary,
Constraints::new().size_limit({
let mut limit = SizeLimit::new();
if let (Some(max_file_size), Some(max_num_files)) =
(opts.max_file_size, opts.max_num_files)
{
// multer's SizeLimit takes u64, while the options are usize
limit = limit.whole_stream((max_file_size * max_num_files) as u64);
}
if let Some(max_file_size) = opts.max_file_size {
limit = limit.per_field(max_file_size as u64);
}
limit
}),
);
// read files
for part in &multipart.parts {
if let Some(name) = &part.name {
if let Some(var_paths) = map.remove(name) {
for var_path in var_paths {
if let (Some(filename), PartData::File(content)) =
(&part.filename, &part.data)
{
builder.set_upload(
&var_path,
filename.clone(),
part.content_type.clone(),
content.try_clone().unwrap(),
);
let mut builder = None;
let mut map = None;
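// Per the GraphQL multipart request protocol, the "operations" and "map"
// fields must arrive before the file fields they describe; a file field seen
// first is rejected below with MissingOperatorsPart / MissingMapPart.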
while let Some(mut field) = multipart
.next_field()
.await
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?
{
match field.name() {
Some("operations") => {
let request_str = field
.text()
.await
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?;
let request: GQLRequest = serde_json::from_str(&request_str)
.map_err(ParseRequestError::InvalidRequest)?;
builder = Some(request.into_query_builder().await?);
}
Some("map") => {
let map_str = field
.text()
.await
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?;
map = Some(
serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
.map_err(ParseRequestError::InvalidFilesMap)?,
);
}
_ => {
let builder = match &mut builder {
Some(builder) => builder,
None => return Err(ParseRequestError::MissingOperatorsPart),
};
let map = match &mut map {
Some(map) => map,
None => return Err(ParseRequestError::MissingMapPart),
};
if let Some(name) = field.name() {
if let Some(filename) = field.file_name().map(ToString::to_string) {
if let Some(var_paths) = map.remove(name) {
let content_type =
field.content_type().map(|mime| mime.to_string());
let mut file =
tempfile::tempfile().map_err(ParseRequestError::Io)?;
while let Some(chunk) = field
.chunk()
.await
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?
{
file.write_all(&chunk).map_err(ParseRequestError::Io)?;
}
file.seek(SeekFrom::Start(0))?;
for var_path in var_paths {
builder.set_upload(
&var_path,
filename.clone(),
content_type.clone(),
file.try_clone().map_err(ParseRequestError::Io)?,
);
}
}
}
}
}
}
}
if !map.is_empty() {
return Err(ParseRequestError::MissingFiles);
if let Some(map) = &map {
if !map.is_empty() {
return Err(ParseRequestError::MissingFiles);
}
} else {
return Err(ParseRequestError::MissingMapPart);
}
Ok(builder)
Ok(match builder {
Some(builder) => builder,
None => return Err(ParseRequestError::MissingOperatorsPart),
})
} else {
let mut data = Vec::new();
self.1
@@ -90,3 +138,17 @@ where
}
}
}
fn reader_stream(
mut r: impl AsyncRead + Send + Unpin + 'static,
) -> impl Stream<Item = std::io::Result<Bytes>> + 'static {
async_stream::try_stream! {
let mut buf = [0u8; 2048];
loop {
let size = r.read(&mut buf[..]).await?;
if size == 0 {
break;
}
yield Bytes::from(buf[..size].to_vec());
}
}
}
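For illustration, a minimal sketch of driving the new code path end to end, written in the style of the deleted parser's tests further down. It is not part of this commit, and it assumes IntoQueryBuilderOpts implements Default (leaving both size limits unset):

use async_graphql::{IntoQueryBuilder, IntoQueryBuilderOpts};
use futures::io::AllowStdIo;
use std::io::Cursor;

#[async_std::test]
async fn parse_multipart_with_multer() {
    // Only "operations" and "map" parts; an empty map means no uploads expected.
    let body: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
    Content-Disposition: form-data; name=\"operations\"\r\n\r\n\
    {\"query\": \"{ value }\"}\r\n\
    --abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
    Content-Disposition: form-data; name=\"map\"\r\n\r\n\
    {}\r\n\
    --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
    let content_type =
        Some("multipart/form-data; boundary=abbc761f78ff4d7cb7573b5a23f96ef0");
    // The (content type, AsyncRead body) tuple implements IntoQueryBuilder.
    let _builder = (content_type, AllowStdIo::new(Cursor::new(body.to_vec())))
        .into_query_builder_opts(&IntoQueryBuilderOpts::default())
        .await
        .unwrap();
}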

View File

@@ -2,10 +2,8 @@
mod graphiql_source;
mod into_query_builder;
mod multipart;
mod playground_source;
mod stream_body;
mod token_reader;
use itertools::Itertools;

View File

@@ -1,326 +0,0 @@
use super::token_reader::*;
use crate::ParseRequestError;
use futures::io::BufReader;
use futures::{AsyncBufRead, AsyncRead};
use http::{header::HeaderName, HeaderMap, HeaderValue};
use itertools::Itertools;
use std::fs::File;
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
use std::str::FromStr;
const MAX_HEADERS: usize = 16;
pub enum PartData {
Bytes(Vec<u8>),
File(File),
}
pub struct Part {
pub name: Option<String>,
pub filename: Option<String>,
pub content_type: Option<String>,
pub size: usize,
pub data: PartData,
}
impl Part {
pub fn create_reader(self) -> Result<Box<dyn Read>, std::io::Error> {
let reader: Box<dyn Read> = match self.data {
PartData::Bytes(bytes) => Box::new(Cursor::new(bytes)),
PartData::File(content) => Box::new(content),
};
Ok(reader)
}
}
struct ContentDisposition {
name: Option<String>,
filename: Option<String>,
}
impl ContentDisposition {
fn parse(value: &str) -> Result<ContentDisposition, ParseRequestError> {
let name = regex::Regex::new("name=\"(?P<name>.*?)\"")
.unwrap()
.captures(value)
.and_then(|caps| caps.name("name").map(|m| m.as_str().to_string()));
let filename = regex::Regex::new("filename=\"(?P<filename>.*?)\"")
.unwrap()
.captures(value)
.and_then(|caps| caps.name("filename").map(|m| m.as_str().to_string()));
Ok(ContentDisposition { name, filename })
}
}
pub struct Multipart {
pub parts: Vec<Part>,
}
impl Multipart {
pub async fn parse<R: AsyncRead + Unpin>(
reader: R,
boundary: &str,
max_file_size: Option<usize>,
max_num_files: Option<usize>,
) -> Result<Multipart, ParseRequestError> {
let mut reader = BufReader::new(reader);
let mut parts = Vec::new();
let boundary = format!("--{}", boundary);
let max_num_files = max_num_files.unwrap_or(std::usize::MAX);
let max_file_size = max_file_size.unwrap_or(std::usize::MAX);
let mut current_num_files = 0;
// first part
reader.except_token(boundary.as_bytes()).await?;
reader.except_token(b"\r\n").await?;
let headers = Self::parse_headers(&mut reader).await?;
parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?);
Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?;
// next parts
loop {
if reader.except_token(b"\r\n").await.is_err() {
reader.except_token(b"--\r\n").await?;
break;
}
let headers = Self::parse_headers(&mut reader).await?;
parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?);
Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?;
}
Ok(Multipart { parts })
}
fn check_max_num_files(
parts: &mut Vec<Part>,
max_num_files: usize,
current_num_files: &mut usize,
) -> Result<(), ParseRequestError> {
if parts.last().unwrap().filename.is_some() {
*current_num_files += 1;
if *current_num_files > max_num_files {
return Err(ParseRequestError::TooManyFiles);
}
}
Ok(())
}
async fn parse_headers<R: AsyncBufRead + Unpin>(
mut reader: R,
) -> Result<HeaderMap, ParseRequestError> {
let mut buf = [0; 256];
let mut header_data = Vec::new();
let mut state = ReadUntilState::default();
loop {
let (size, found) = reader
.read_until_token(b"\r\n\r\n", &mut buf, &mut state)
.await?;
header_data.extend_from_slice(&buf[..size]);
if found {
break;
}
}
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
header_data.extend_from_slice(b"\r\n\r\n");
let headers = match httparse::parse_headers(&header_data, &mut headers)
.map_err(|_| ParseRequestError::InvalidMultipart)?
{
httparse::Status::Complete((_, headers)) => headers,
_ => return Err(ParseRequestError::InvalidMultipart),
};
let mut headers_map = HeaderMap::new();
for httparse::Header { name, value } in headers {
headers_map.insert(
HeaderName::from_str(name).map_err(|_| ParseRequestError::InvalidMultipart)?,
HeaderValue::from_bytes(value).map_err(|_| ParseRequestError::InvalidMultipart)?,
);
}
Ok(headers_map)
}
async fn parse_body<R: AsyncBufRead + Unpin>(
mut reader: R,
headers: &HeaderMap,
max_file_size: usize,
boundary: &str,
) -> Result<Part, ParseRequestError> {
let content_disposition = headers
.get(http::header::CONTENT_DISPOSITION)
.and_then(|value| value.to_str().ok())
.and_then(|value| ContentDisposition::parse(value).ok())
.unwrap_or_else(|| ContentDisposition {
name: None,
filename: None,
});
let content_type = headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(ToString::to_string);
let mut buf = [0; 4096];
let mut state = ReadUntilState::default();
let mut total_size = 0;
let part_data = if content_disposition.filename.is_some() {
// Create a temporary file.
let mut file = tempfile::tempfile()?;
loop {
let (size, found) = reader
.read_until_token(boundary.as_bytes(), &mut buf, &mut state)
.await?;
total_size += size;
if total_size > max_file_size {
return Err(ParseRequestError::TooLarge);
}
file.write_all(&buf[..size])?;
if found {
break;
}
}
file.seek(SeekFrom::Start(0))?;
PartData::File(file)
} else {
let mut body = Vec::new();
loop {
let (size, found) = reader
.read_until_token(boundary.as_bytes(), &mut buf, &mut state)
.await?;
total_size += size;
body.extend_from_slice(&buf[..size]);
if found {
break;
}
}
PartData::Bytes(body)
};
Ok(Part {
name: content_disposition.name,
filename: content_disposition.filename,
content_type,
size: total_size,
data: part_data,
})
}
pub fn remove(&mut self, name: &str) -> Option<Part> {
if let Some((pos, _)) = self.parts.iter().find_position(|part| {
if let Some(part_name) = &part.name {
part_name == name
} else {
false
}
}) {
Some(self.parts.remove(pos))
} else {
None
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[async_std::test]
async fn test_parse() {
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
Content-Type: text/plain; charset=utf-8\r\n\r\n\
test\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Type: text/plain; charset=utf-8\r\n\r\n\
data\
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
let multipart = Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, None)
.await
.unwrap();
assert_eq!(multipart.parts.len(), 2);
let part_1 = &multipart.parts[0];
assert_eq!(part_1.name.as_deref(), Some("file"));
assert_eq!(part_1.filename.as_deref(), Some("fn.txt"));
assert_eq!(
part_1.content_type.as_deref(),
Some("text/plain; charset=utf-8")
);
let part_2 = &multipart.parts[1];
assert!(part_2.name.is_none());
assert!(part_2.filename.is_none());
assert_eq!(
part_2.content_type.as_deref(),
Some("text/plain; charset=utf-8")
);
}
#[async_std::test]
async fn test_parse_limit_file_size() {
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
Content-Type: text/plain; charset=utf-8\r\n\r\n\
12345\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Type: text/plain; charset=utf-8\r\n\r\n\
data\
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(5), None,)
.await
.is_ok()
);
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(6), None,)
.await
.is_ok()
);
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(4), None,)
.await
.is_err()
);
}
#[async_std::test]
async fn test_parse_limit_num_files() {
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
Content-Type: text/plain; charset=utf-8\r\n\r\n\
12345\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Disposition: form-data; name=\"file1\"; filename=\"fn1.txt\"\r\n\r\n\
data\
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
Content-Disposition: form-data; name=\"file2\"\r\n\r\n\
data\
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(1))
.await
.is_err()
);
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(2))
.await
.is_ok()
);
assert!(
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(3))
.await
.is_ok()
);
}
}

View File

@@ -1,191 +0,0 @@
use futures::io::ErrorKind;
use futures::task::{Context, Poll};
use futures::{AsyncBufRead, Future};
use std::io::{Error, Result};
use std::pin::Pin;
pub trait AsyncTokenReader: AsyncBufRead {
fn read_until_token<'a>(
&'a mut self,
token: &'a [u8],
buf: &'a mut [u8],
state: &'a mut ReadUntilState,
) -> ReadUntilToken<'a, Self> {
ReadUntilToken {
reader: self,
token,
buf,
state,
}
}
fn except_token<'a>(&'a mut self, token: &'a [u8]) -> ExceptToken<'a, Self> {
ExceptToken {
reader: self,
token,
match_size: 0,
}
}
}
impl<R: AsyncBufRead> AsyncTokenReader for R {}
#[derive(Default)]
pub struct ReadUntilState {
match_size: usize,
consume_token: Option<(usize, usize)>,
}
pub struct ReadUntilToken<'a, R: ?Sized> {
reader: &'a mut R,
token: &'a [u8],
buf: &'a mut [u8],
state: &'a mut ReadUntilState,
}
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntilToken<'a, R> {
type Output = Result<(usize, bool)>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
let mut rsz = 0;
loop {
let nsz = this.buf.len() - rsz;
if let Some((pos, size)) = &mut this.state.consume_token {
let sz = (*size - *pos).min(nsz);
this.buf[rsz..rsz + sz].copy_from_slice(&this.token[*pos..*pos + sz]);
*pos += sz;
rsz += sz;
if *pos == *size {
this.state.consume_token = None;
}
if rsz == this.buf.len() {
return Poll::Ready(Ok((rsz, false)));
}
} else {
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Ok(data)) if data.is_empty() => {
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
}
Poll::Ready(Ok(data)) => {
let mut consume_size = data.len();
for (idx, b) in data.iter().enumerate() {
if *b == this.token[this.state.match_size] {
this.state.match_size += 1;
if this.state.match_size == this.token.len() {
Pin::new(&mut this.reader).consume(idx + 1);
this.state.match_size = 0;
return Poll::Ready(Ok((rsz, true)));
}
} else if this.state.match_size > 0 {
this.state.consume_token = Some((0, this.state.match_size));
this.state.match_size = 0;
consume_size = idx;
break;
} else {
this.buf[rsz] = *b;
rsz += 1;
if rsz == this.buf.len() {
Pin::new(&mut this.reader).consume(idx + 1);
return Poll::Ready(Ok((rsz, false)));
}
}
}
Pin::new(&mut this.reader).consume(consume_size);
}
}
}
}
}
}
pub struct ExceptToken<'a, R: ?Sized> {
reader: &'a mut R,
token: &'a [u8],
match_size: usize,
}
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ExceptToken<'a, R> {
type Output = Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = &mut *self;
loop {
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
Poll::Ready(Ok(data)) if data.is_empty() => {
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
}
Poll::Ready(Ok(data)) => {
for b in data {
if *b == this.token[this.match_size] {
this.match_size += 1;
if this.match_size == this.token.len() {
Pin::new(&mut this.reader).consume(this.match_size);
return Poll::Ready(Ok(()));
}
} else {
return Poll::Ready(Err(Error::from(ErrorKind::InvalidData)));
}
}
}
}
}
}
}
#[cfg(test)]
mod tests {
use crate::http::token_reader::{AsyncTokenReader, ReadUntilState};
use futures::io::BufReader;
#[async_std::test]
async fn test_read_until_token() {
let data: &[u8] = b"12AB567890ABC12345ABC6";
let mut reader = BufReader::new(data);
let mut buf = [0; 3];
let mut state = ReadUntilState::default();
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((3, false))));
assert_eq!(&buf, b"12A");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((3, false))));
assert_eq!(&buf, b"B56");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((3, false))));
assert_eq!(&buf, b"789");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((1, true))));
assert_eq!(&buf[..1], b"0");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((3, false))));
assert_eq!(&buf, b"123");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Ok((2, true))));
assert_eq!(&buf[..2], b"45");
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
assert!(matches!(res, Err(_)));
}
#[async_std::test]
async fn test_read_expect_token() {
let data: &[u8] = b"ABCABC";
let mut reader = BufReader::new(data);
assert!(reader.except_token(b"ABC").await.is_ok());
assert!(reader.except_token(b"ABC").await.is_ok());
assert!(reader.except_token(b"ABC").await.is_err());
}
}