Replace multipart parser with multer.
This commit is contained in:
parent
e2f6475aa6
commit
44b1eb4fb4
|
@ -43,6 +43,8 @@ fnv = "1.0.6"
|
|||
regex = "1.3.5"
|
||||
tracing = "0.1.13"
|
||||
indexmap = "1.3.2"
|
||||
async-stream = "0.2.1"
|
||||
multer = "1.0.3"
|
||||
bson = { version = "0.14.1", optional = true }
|
||||
uuid = { version = "0.8.1", optional = true }
|
||||
url = { version = "2.1.1", optional = true }
|
||||
|
|
|
@ -7,11 +7,10 @@ mod subscription;
|
|||
use actix_web::dev::{Payload, PayloadStream};
|
||||
use actix_web::{http, web, Error, FromRequest, HttpRequest};
|
||||
use async_graphql::http::StreamBody;
|
||||
use async_graphql::{IntoQueryBuilder, IntoQueryBuilderOpts, ParseRequestError, QueryBuilder};
|
||||
use async_graphql::{IntoQueryBuilder, IntoQueryBuilderOpts, QueryBuilder};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{Future, SinkExt, StreamExt, TryFutureExt};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub use subscription::WSSubscription;
|
||||
|
||||
/// Extractor for GraphQL request
|
||||
|
@ -56,12 +55,7 @@ impl FromRequest for GQLRequest {
|
|||
(content_type, StreamBody::new(rx))
|
||||
.into_query_builder_opts(&config)
|
||||
.map_ok(GQLRequest)
|
||||
.map_err(|err| match err {
|
||||
ParseRequestError::TooManyFiles | ParseRequestError::TooLarge => {
|
||||
actix_web::error::ErrorPayloadTooLarge(err)
|
||||
}
|
||||
_ => actix_web::error::ErrorBadRequest(err),
|
||||
})
|
||||
.map_err(actix_web::error::ErrorBadRequest)
|
||||
.await
|
||||
})
|
||||
}
|
||||
|
|
|
@ -25,12 +25,11 @@ impl GQLRequestExt for Request {
|
|||
&self,
|
||||
opts: IntoQueryBuilderOpts,
|
||||
) -> Result<QueryBuilder, ParseRequestError> {
|
||||
let body = self.body().as_ref();
|
||||
let ct = self
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|value| value.to_str().ok());
|
||||
(ct, AllowStdIo::new(Cursor::new(body)))
|
||||
(ct, AllowStdIo::new(Cursor::new(self.body().to_vec())))
|
||||
.into_query_builder_opts(&opts)
|
||||
.await
|
||||
}
|
||||
|
|
|
@ -207,8 +207,8 @@ fn upload() -> Result<()> {
|
|||
};
|
||||
|
||||
let mut content = String::new();
|
||||
file.into_read().read_to_string(&mut content).ok();
|
||||
assert_eq!(content, "test\r\n".to_owned());
|
||||
file.into_read().read_to_string(&mut content).unwrap();
|
||||
assert_eq!(content, "test".to_owned());
|
||||
|
||||
file_info
|
||||
}
|
||||
|
|
|
@ -303,7 +303,7 @@ pub enum ParseRequestError {
|
|||
InvalidFilesMap(serde_json::Error),
|
||||
|
||||
#[error("Invalid multipart data")]
|
||||
InvalidMultipart,
|
||||
InvalidMultipart(String),
|
||||
|
||||
#[error("Missing \"operators\" part")]
|
||||
MissingOperatorsPart,
|
||||
|
@ -316,12 +316,6 @@ pub enum ParseRequestError {
|
|||
|
||||
#[error("Missing files")]
|
||||
MissingFiles,
|
||||
|
||||
#[error("Too many files")]
|
||||
TooManyFiles,
|
||||
|
||||
#[error("The file size is too large")]
|
||||
TooLarge,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
use crate::http::multipart::{Multipart, PartData};
|
||||
use crate::http::GQLRequest;
|
||||
use crate::query::{IntoQueryBuilder, IntoQueryBuilderOpts};
|
||||
use crate::{ParseRequestError, QueryBuilder};
|
||||
use futures::{AsyncRead, AsyncReadExt};
|
||||
use bytes::Bytes;
|
||||
use futures::{AsyncRead, AsyncReadExt, Stream};
|
||||
use mime::Mime;
|
||||
use multer::{Constraints, Multipart, SizeLimit};
|
||||
use std::collections::HashMap;
|
||||
use std::io::{Seek, SeekFrom, Write};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<CT, Body> IntoQueryBuilder for (Option<CT>, Body)
|
||||
where
|
||||
CT: AsRef<str> + Send,
|
||||
Body: AsyncRead + Send + Unpin,
|
||||
Body: AsyncRead + Send + Unpin + 'static,
|
||||
{
|
||||
async fn into_query_builder_opts(
|
||||
mut self,
|
||||
|
@ -29,55 +31,101 @@ where
|
|||
})
|
||||
{
|
||||
// multipart
|
||||
let mut multipart = Multipart::parse(
|
||||
self.1,
|
||||
boundary.as_str(),
|
||||
opts.max_file_size,
|
||||
opts.max_num_files,
|
||||
)
|
||||
.await?;
|
||||
let gql_request: GQLRequest = {
|
||||
let part = multipart
|
||||
.remove("operations")
|
||||
.ok_or_else(|| ParseRequestError::MissingOperatorsPart)?;
|
||||
let reader = part.create_reader()?;
|
||||
serde_json::from_reader(reader).map_err(ParseRequestError::InvalidRequest)?
|
||||
};
|
||||
let mut map: HashMap<String, Vec<String>> = {
|
||||
let part = multipart
|
||||
.remove("map")
|
||||
.ok_or_else(|| ParseRequestError::MissingMapPart)?;
|
||||
let reader = part.create_reader()?;
|
||||
serde_json::from_reader(reader).map_err(ParseRequestError::InvalidFilesMap)?
|
||||
};
|
||||
let stream = reader_stream(self.1);
|
||||
|
||||
let mut builder = gql_request.into_query_builder().await?;
|
||||
let mut multipart = Multipart::new_with_constraints(
|
||||
stream,
|
||||
boundary,
|
||||
Constraints::new().size_limit({
|
||||
let mut limit = SizeLimit::new();
|
||||
if let (Some(max_file_size), Some(max_num_files)) =
|
||||
(opts.max_file_size, opts.max_file_size)
|
||||
{
|
||||
limit = limit.whole_stream(max_file_size * max_num_files);
|
||||
}
|
||||
if let Some(max_file_size) = opts.max_file_size {
|
||||
limit = limit.per_field(max_file_size);
|
||||
}
|
||||
limit
|
||||
}),
|
||||
);
|
||||
|
||||
// read files
|
||||
for part in &multipart.parts {
|
||||
if let Some(name) = &part.name {
|
||||
if let Some(var_paths) = map.remove(name) {
|
||||
for var_path in var_paths {
|
||||
if let (Some(filename), PartData::File(content)) =
|
||||
(&part.filename, &part.data)
|
||||
{
|
||||
builder.set_upload(
|
||||
&var_path,
|
||||
filename.clone(),
|
||||
part.content_type.clone(),
|
||||
content.try_clone().unwrap(),
|
||||
);
|
||||
let mut builder = None;
|
||||
let mut map = None;
|
||||
|
||||
while let Some(mut field) = multipart
|
||||
.next_field()
|
||||
.await
|
||||
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?
|
||||
{
|
||||
match field.name() {
|
||||
Some("operations") => {
|
||||
let request_str = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?;
|
||||
let request: GQLRequest = serde_json::from_str(&request_str)
|
||||
.map_err(ParseRequestError::InvalidRequest)?;
|
||||
builder = Some(request.into_query_builder().await?);
|
||||
}
|
||||
Some("map") => {
|
||||
let map_str = field
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| ParseRequestError::InvalidMultipart(err.to_string()))?;
|
||||
map = Some(
|
||||
serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
|
||||
.map_err(ParseRequestError::InvalidFilesMap)?,
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
let builder = match &mut builder {
|
||||
Some(builder) => builder,
|
||||
None => return Err(ParseRequestError::MissingOperatorsPart),
|
||||
};
|
||||
let map = match &mut map {
|
||||
Some(map) => map,
|
||||
None => return Err(ParseRequestError::MissingMapPart),
|
||||
};
|
||||
if let Some(name) = field.name() {
|
||||
if let Some(filename) = field.file_name().map(ToString::to_string) {
|
||||
if let Some(var_paths) = map.remove(name) {
|
||||
let content_type =
|
||||
field.content_type().map(|mime| mime.to_string());
|
||||
let mut file =
|
||||
tempfile::tempfile().map_err(ParseRequestError::Io)?;
|
||||
while let Some(chunk) = field.chunk().await.unwrap() {
|
||||
println!("{:?}", chunk);
|
||||
file.write(&chunk).map_err(ParseRequestError::Io)?;
|
||||
}
|
||||
file.seek(SeekFrom::Start(0))?;
|
||||
for var_path in var_paths {
|
||||
builder.set_upload(
|
||||
&var_path,
|
||||
filename.clone(),
|
||||
content_type.clone(),
|
||||
file.try_clone().unwrap(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !map.is_empty() {
|
||||
return Err(ParseRequestError::MissingFiles);
|
||||
if let Some(map) = &map {
|
||||
if !map.is_empty() {
|
||||
return Err(ParseRequestError::MissingFiles);
|
||||
}
|
||||
} else {
|
||||
return Err(ParseRequestError::MissingMapPart);
|
||||
}
|
||||
|
||||
Ok(builder)
|
||||
Ok(match builder {
|
||||
Some(builder) => builder,
|
||||
None => return Err(ParseRequestError::MissingOperatorsPart),
|
||||
})
|
||||
} else {
|
||||
let mut data = Vec::new();
|
||||
self.1
|
||||
|
@ -90,3 +138,17 @@ where
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn reader_stream(
|
||||
mut r: impl AsyncRead + Send + Unpin + 'static,
|
||||
) -> impl Stream<Item = std::io::Result<Bytes>> + 'static {
|
||||
async_stream::try_stream! {
|
||||
let mut buf = [0u8; 2048];
|
||||
while let size = r.read(&mut buf[..]).await? {
|
||||
if size == 0 {
|
||||
return;
|
||||
}
|
||||
yield Bytes::from(buf[..size].to_vec());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,8 @@
|
|||
|
||||
mod graphiql_source;
|
||||
mod into_query_builder;
|
||||
mod multipart;
|
||||
mod playground_source;
|
||||
mod stream_body;
|
||||
mod token_reader;
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
|
|
|
@ -1,326 +0,0 @@
|
|||
use super::token_reader::*;
|
||||
use crate::ParseRequestError;
|
||||
use futures::io::BufReader;
|
||||
use futures::{AsyncBufRead, AsyncRead};
|
||||
use http::{header::HeaderName, HeaderMap, HeaderValue};
|
||||
use itertools::Itertools;
|
||||
use std::fs::File;
|
||||
use std::io::{Cursor, Read, Seek, SeekFrom, Write};
|
||||
use std::str::FromStr;
|
||||
|
||||
const MAX_HEADERS: usize = 16;
|
||||
|
||||
pub enum PartData {
|
||||
Bytes(Vec<u8>),
|
||||
File(File),
|
||||
}
|
||||
|
||||
pub struct Part {
|
||||
pub name: Option<String>,
|
||||
pub filename: Option<String>,
|
||||
pub content_type: Option<String>,
|
||||
pub size: usize,
|
||||
pub data: PartData,
|
||||
}
|
||||
|
||||
impl Part {
|
||||
pub fn create_reader(self) -> Result<Box<dyn Read>, std::io::Error> {
|
||||
let reader: Box<dyn Read> = match self.data {
|
||||
PartData::Bytes(bytes) => Box::new(Cursor::new(bytes)),
|
||||
PartData::File(content) => Box::new(content),
|
||||
};
|
||||
Ok(reader)
|
||||
}
|
||||
}
|
||||
|
||||
struct ContentDisposition {
|
||||
name: Option<String>,
|
||||
filename: Option<String>,
|
||||
}
|
||||
|
||||
impl ContentDisposition {
|
||||
fn parse(value: &str) -> Result<ContentDisposition, ParseRequestError> {
|
||||
let name = regex::Regex::new("name=\"(?P<name>.*?)\"")
|
||||
.unwrap()
|
||||
.captures(value)
|
||||
.and_then(|caps| caps.name("name").map(|m| m.as_str().to_string()));
|
||||
let filename = regex::Regex::new("filename=\"(?P<filename>.*?)\"")
|
||||
.unwrap()
|
||||
.captures(value)
|
||||
.and_then(|caps| caps.name("filename").map(|m| m.as_str().to_string()));
|
||||
Ok(ContentDisposition { name, filename })
|
||||
}
|
||||
}
|
||||
|
||||
pub struct Multipart {
|
||||
pub parts: Vec<Part>,
|
||||
}
|
||||
|
||||
impl Multipart {
|
||||
pub async fn parse<R: AsyncRead + Unpin>(
|
||||
reader: R,
|
||||
boundary: &str,
|
||||
max_file_size: Option<usize>,
|
||||
max_num_files: Option<usize>,
|
||||
) -> Result<Multipart, ParseRequestError> {
|
||||
let mut reader = BufReader::new(reader);
|
||||
let mut parts = Vec::new();
|
||||
let boundary = format!("--{}", boundary);
|
||||
let max_num_files = max_num_files.unwrap_or(std::usize::MAX);
|
||||
let max_file_size = max_file_size.unwrap_or(std::usize::MAX);
|
||||
let mut current_num_files = 0;
|
||||
|
||||
// first part
|
||||
reader.except_token(boundary.as_bytes()).await?;
|
||||
reader.except_token(b"\r\n").await?;
|
||||
let headers = Self::parse_headers(&mut reader).await?;
|
||||
parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?);
|
||||
Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?;
|
||||
|
||||
// next parts
|
||||
loop {
|
||||
if reader.except_token(b"\r\n").await.is_err() {
|
||||
reader.except_token(b"--\r\n").await?;
|
||||
break;
|
||||
}
|
||||
|
||||
let headers = Self::parse_headers(&mut reader).await?;
|
||||
parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?);
|
||||
Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?;
|
||||
}
|
||||
|
||||
Ok(Multipart { parts })
|
||||
}
|
||||
|
||||
fn check_max_num_files(
|
||||
parts: &mut Vec<Part>,
|
||||
max_num_files: usize,
|
||||
current_num_files: &mut usize,
|
||||
) -> Result<(), ParseRequestError> {
|
||||
if parts.last().unwrap().filename.is_some() {
|
||||
*current_num_files += 1;
|
||||
if *current_num_files > max_num_files {
|
||||
return Err(ParseRequestError::TooManyFiles);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn parse_headers<R: AsyncBufRead + Unpin>(
|
||||
mut reader: R,
|
||||
) -> Result<HeaderMap, ParseRequestError> {
|
||||
let mut buf = [0; 256];
|
||||
let mut header_data = Vec::new();
|
||||
let mut state = ReadUntilState::default();
|
||||
|
||||
loop {
|
||||
let (size, found) = reader
|
||||
.read_until_token(b"\r\n\r\n", &mut buf, &mut state)
|
||||
.await?;
|
||||
header_data.extend_from_slice(&buf[..size]);
|
||||
if found {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mut headers = [httparse::EMPTY_HEADER; MAX_HEADERS];
|
||||
header_data.extend_from_slice(b"\r\n\r\n");
|
||||
let headers = match httparse::parse_headers(&header_data, &mut headers)
|
||||
.map_err(|_| ParseRequestError::InvalidMultipart)?
|
||||
{
|
||||
httparse::Status::Complete((_, headers)) => headers,
|
||||
_ => return Err(ParseRequestError::InvalidMultipart),
|
||||
};
|
||||
|
||||
let mut headers_map = HeaderMap::new();
|
||||
for httparse::Header { name, value } in headers {
|
||||
headers_map.insert(
|
||||
HeaderName::from_str(name).map_err(|_| ParseRequestError::InvalidMultipart)?,
|
||||
HeaderValue::from_bytes(value).map_err(|_| ParseRequestError::InvalidMultipart)?,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(headers_map)
|
||||
}
|
||||
|
||||
async fn parse_body<R: AsyncBufRead + Unpin>(
|
||||
mut reader: R,
|
||||
headers: &HeaderMap,
|
||||
max_file_size: usize,
|
||||
boundary: &str,
|
||||
) -> Result<Part, ParseRequestError> {
|
||||
let content_disposition = headers
|
||||
.get(http::header::CONTENT_DISPOSITION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.and_then(|value| ContentDisposition::parse(value).ok())
|
||||
.unwrap_or_else(|| ContentDisposition {
|
||||
name: None,
|
||||
filename: None,
|
||||
});
|
||||
let content_type = headers
|
||||
.get(http::header::CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
|
||||
let mut buf = [0; 4096];
|
||||
let mut state = ReadUntilState::default();
|
||||
let mut total_size = 0;
|
||||
|
||||
let part_data = if content_disposition.filename.is_some() {
|
||||
// Create a temporary file.
|
||||
let mut file = tempfile::tempfile()?;
|
||||
|
||||
loop {
|
||||
let (size, found) = reader
|
||||
.read_until_token(boundary.as_bytes(), &mut buf, &mut state)
|
||||
.await?;
|
||||
total_size += size;
|
||||
if total_size > max_file_size {
|
||||
return Err(ParseRequestError::TooLarge);
|
||||
}
|
||||
file.write_all(&buf[..size])?;
|
||||
if found {
|
||||
break;
|
||||
}
|
||||
}
|
||||
file.seek(SeekFrom::Start(0))?;
|
||||
PartData::File(file)
|
||||
} else {
|
||||
let mut body = Vec::new();
|
||||
|
||||
loop {
|
||||
let (size, found) = reader
|
||||
.read_until_token(boundary.as_bytes(), &mut buf, &mut state)
|
||||
.await?;
|
||||
total_size += size;
|
||||
body.extend_from_slice(&buf[..size]);
|
||||
if found {
|
||||
break;
|
||||
}
|
||||
}
|
||||
PartData::Bytes(body)
|
||||
};
|
||||
|
||||
Ok(Part {
|
||||
name: content_disposition.name,
|
||||
filename: content_disposition.filename,
|
||||
content_type,
|
||||
size: total_size,
|
||||
data: part_data,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn remove(&mut self, name: &str) -> Option<Part> {
|
||||
if let Some((pos, _)) = self.parts.iter().find_position(|part| {
|
||||
if let Some(part_name) = &part.name {
|
||||
part_name == name
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}) {
|
||||
Some(self.parts.remove(pos))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_parse() {
|
||||
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
|
||||
Content-Type: text/plain; charset=utf-8\r\n\r\n\
|
||||
test\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Type: text/plain; charset=utf-8\r\n\r\n\
|
||||
data\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
|
||||
let multipart = Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(multipart.parts.len(), 2);
|
||||
|
||||
let part_1 = &multipart.parts[0];
|
||||
assert_eq!(part_1.name.as_deref(), Some("file"));
|
||||
assert_eq!(part_1.filename.as_deref(), Some("fn.txt"));
|
||||
assert_eq!(
|
||||
part_1.content_type.as_deref(),
|
||||
Some("text/plain; charset=utf-8")
|
||||
);
|
||||
|
||||
let part_2 = &multipart.parts[1];
|
||||
assert!(part_2.name.is_none());
|
||||
assert!(part_2.filename.is_none());
|
||||
assert_eq!(
|
||||
part_2.content_type.as_deref(),
|
||||
Some("text/plain; charset=utf-8")
|
||||
);
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_parse_limit_file_size() {
|
||||
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
|
||||
Content-Type: text/plain; charset=utf-8\r\n\r\n\
|
||||
12345\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Type: text/plain; charset=utf-8\r\n\r\n\
|
||||
data\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(5), None,)
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(6), None,)
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(4), None,)
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_parse_limit_num_files() {
|
||||
let data: &[u8] = b"--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Disposition: form-data; name=\"file\"; filename=\"fn.txt\"\r\n\
|
||||
Content-Type: text/plain; charset=utf-8\r\n\r\n\
|
||||
12345\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Disposition: form-data; name=\"file1\"; filename=\"fn1.txt\"\r\n\r\n\
|
||||
data\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0\r\n\
|
||||
Content-Disposition: form-data; name=\"file2\"\r\n\r\n\
|
||||
data\
|
||||
--abbc761f78ff4d7cb7573b5a23f96ef0--\r\n";
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(1))
|
||||
.await
|
||||
.is_err()
|
||||
);
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(2))
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
|
||||
assert!(
|
||||
Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(3))
|
||||
.await
|
||||
.is_ok()
|
||||
);
|
||||
}
|
||||
}
|
|
@ -1,191 +0,0 @@
|
|||
use futures::io::ErrorKind;
|
||||
use futures::task::{Context, Poll};
|
||||
use futures::{AsyncBufRead, Future};
|
||||
use std::io::{Error, Result};
|
||||
use std::pin::Pin;
|
||||
|
||||
pub trait AsyncTokenReader: AsyncBufRead {
|
||||
fn read_until_token<'a>(
|
||||
&'a mut self,
|
||||
token: &'a [u8],
|
||||
buf: &'a mut [u8],
|
||||
state: &'a mut ReadUntilState,
|
||||
) -> ReadUntilToken<'a, Self> {
|
||||
ReadUntilToken {
|
||||
reader: self,
|
||||
token,
|
||||
buf,
|
||||
state,
|
||||
}
|
||||
}
|
||||
|
||||
fn except_token<'a>(&'a mut self, token: &'a [u8]) -> ExceptToken<'a, Self> {
|
||||
ExceptToken {
|
||||
reader: self,
|
||||
token,
|
||||
match_size: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: AsyncBufRead> AsyncTokenReader for R {}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ReadUntilState {
|
||||
match_size: usize,
|
||||
consume_token: Option<(usize, usize)>,
|
||||
}
|
||||
|
||||
pub struct ReadUntilToken<'a, R: ?Sized> {
|
||||
reader: &'a mut R,
|
||||
token: &'a [u8],
|
||||
buf: &'a mut [u8],
|
||||
state: &'a mut ReadUntilState,
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ReadUntilToken<'a, R> {
|
||||
type Output = Result<(usize, bool)>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = &mut *self;
|
||||
let mut rsz = 0;
|
||||
|
||||
loop {
|
||||
let nsz = this.buf.len() - rsz;
|
||||
|
||||
if let Some((pos, size)) = &mut this.state.consume_token {
|
||||
let sz = (*size - *pos).min(nsz);
|
||||
this.buf[rsz..rsz + sz].copy_from_slice(&this.token[*pos..*pos + sz]);
|
||||
*pos += sz;
|
||||
rsz += sz;
|
||||
if *pos == *size {
|
||||
this.state.consume_token = None;
|
||||
}
|
||||
if rsz == this.buf.len() {
|
||||
return Poll::Ready(Ok((rsz, false)));
|
||||
}
|
||||
} else {
|
||||
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Ready(Ok(data)) if data.is_empty() => {
|
||||
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
|
||||
}
|
||||
Poll::Ready(Ok(data)) => {
|
||||
let mut consume_size = data.len();
|
||||
for (idx, b) in data.iter().enumerate() {
|
||||
if *b == this.token[this.state.match_size] {
|
||||
this.state.match_size += 1;
|
||||
if this.state.match_size == this.token.len() {
|
||||
Pin::new(&mut this.reader).consume(idx + 1);
|
||||
this.state.match_size = 0;
|
||||
return Poll::Ready(Ok((rsz, true)));
|
||||
}
|
||||
} else if this.state.match_size > 0 {
|
||||
this.state.consume_token = Some((0, this.state.match_size));
|
||||
this.state.match_size = 0;
|
||||
consume_size = idx;
|
||||
break;
|
||||
} else {
|
||||
this.buf[rsz] = *b;
|
||||
rsz += 1;
|
||||
if rsz == this.buf.len() {
|
||||
Pin::new(&mut this.reader).consume(idx + 1);
|
||||
return Poll::Ready(Ok((rsz, false)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Pin::new(&mut this.reader).consume(consume_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ExceptToken<'a, R: ?Sized> {
|
||||
reader: &'a mut R,
|
||||
token: &'a [u8],
|
||||
match_size: usize,
|
||||
}
|
||||
|
||||
impl<'a, R: AsyncBufRead + ?Sized + Unpin> Future for ExceptToken<'a, R> {
|
||||
type Output = Result<()>;
|
||||
|
||||
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let this = &mut *self;
|
||||
|
||||
loop {
|
||||
match Pin::new(&mut this.reader).poll_fill_buf(cx) {
|
||||
Poll::Pending => return Poll::Pending,
|
||||
Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
|
||||
Poll::Ready(Ok(data)) if data.is_empty() => {
|
||||
return Poll::Ready(Err(Error::from(ErrorKind::UnexpectedEof)))
|
||||
}
|
||||
Poll::Ready(Ok(data)) => {
|
||||
for b in data {
|
||||
if *b == this.token[this.match_size] {
|
||||
this.match_size += 1;
|
||||
if this.match_size == this.token.len() {
|
||||
Pin::new(&mut this.reader).consume(this.match_size);
|
||||
return Poll::Ready(Ok(()));
|
||||
}
|
||||
} else {
|
||||
return Poll::Ready(Err(Error::from(ErrorKind::InvalidData)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::http::token_reader::{AsyncTokenReader, ReadUntilState};
|
||||
use futures::io::BufReader;
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_read_until_token() {
|
||||
let data: &[u8] = b"12AB567890ABC12345ABC6";
|
||||
let mut reader = BufReader::new(data);
|
||||
let mut buf = [0; 3];
|
||||
let mut state = ReadUntilState::default();
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((3, false))));
|
||||
assert_eq!(&buf, b"12A");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((3, false))));
|
||||
assert_eq!(&buf, b"B56");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((3, false))));
|
||||
assert_eq!(&buf, b"789");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((1, true))));
|
||||
assert_eq!(&buf[..1], b"0");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((3, false))));
|
||||
assert_eq!(&buf, b"123");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Ok((2, true))));
|
||||
assert_eq!(&buf[..2], b"45");
|
||||
|
||||
let res = reader.read_until_token(b"ABC", &mut buf, &mut state).await;
|
||||
assert!(matches!(res, Err(_)));
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
async fn test_read_expect_token() {
|
||||
let data: &[u8] = b"ABCABC";
|
||||
let mut reader = BufReader::new(data);
|
||||
assert!(reader.except_token(b"ABC").await.is_ok());
|
||||
assert!(reader.except_token(b"ABC").await.is_ok());
|
||||
assert!(reader.except_token(b"ABC").await.is_err());
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user