From 0e371afb3ccdb7438b772f6d754abb81b454d457 Mon Sep 17 00:00:00 2001 From: Sunli Date: Mon, 11 May 2020 17:13:50 +0800 Subject: [PATCH 1/2] It not finished yet. --- async-graphql-parser/src/lib.rs | 2 +- async-graphql-parser/src/value.rs | 18 ++++++++++++++++++ src/types/upload.rs | 22 +++++++++++++++++----- 3 files changed, 36 insertions(+), 6 deletions(-) diff --git a/async-graphql-parser/src/lib.rs b/async-graphql-parser/src/lib.rs index da16399a..b9ecec42 100644 --- a/async-graphql-parser/src/lib.rs +++ b/async-graphql-parser/src/lib.rs @@ -10,4 +10,4 @@ mod value; pub use pos::{Pos, Positioned}; pub use query_parser::{parse_query, parse_value, Error, Result}; -pub use value::Value; +pub use value::{UploadValue, Value}; diff --git a/async-graphql-parser/src/value.rs b/async-graphql-parser/src/value.rs index 0e2186c1..fbbfe527 100644 --- a/async-graphql-parser/src/value.rs +++ b/async-graphql-parser/src/value.rs @@ -1,5 +1,19 @@ use std::collections::BTreeMap; use std::fmt; +use std::fmt::Formatter; +use std::io::Read; + +pub struct UploadValue { + pub filename: String, + pub content_type: Option, + pub path: Option>, +} + +impl fmt::Debug for UploadValue { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!("Upload({})", self.filename) + } +} /// Represents a GraphQL value #[derive(Clone, Debug)] @@ -14,6 +28,7 @@ pub enum Value { Enum(String), List(Vec), Object(BTreeMap), + Upload(UploadValue), } impl PartialEq for Value { @@ -54,6 +69,7 @@ impl PartialEq for Value { } true } + (Upload(a), Upload(b)) => a.filename == b.filename, _ => false, } } @@ -112,6 +128,7 @@ impl fmt::Display for Value { } write!(f, "}}") } + Value::Upload(upload) => write!(f, "null"), } } } @@ -136,6 +153,7 @@ impl From for serde_json::Value { .map(|(name, value)| (name, value.into())) .collect(), ), + Value::Upload(_) => serde_json::Value::Null, } } } diff --git a/src/types/upload.rs b/src/types/upload.rs index c5a4b1e6..dcccd863 100644 --- a/src/types/upload.rs +++ b/src/types/upload.rs @@ -1,5 +1,8 @@ use crate::{registry, InputValueError, InputValueResult, InputValueType, Type, Value}; +use async_graphql_parser::UploadValue; +use futures::AsyncRead; use std::borrow::Cow; +use std::io::Read; use std::path::PathBuf; /// Uploaded file @@ -41,15 +44,24 @@ use std::path::PathBuf; /// --form 'map={ "0": ["variables.file"] }' \ /// --form '0=@myFile.txt' /// ``` -pub struct Upload { +pub struct Upload(UploadValue); + +impl Upload { /// Filename - pub filename: String, + pub fn filename(&self) -> &str { + self.0.filename.as_str() + } /// Content type, such as `application/json`, `image/jpg` ... - pub content_type: Option, + pub fn content_type(&self) -> Option<&str> { + self.0.content_type.as_deref() + } - /// Temporary file path - pub path: PathBuf, + /// Convert to an asynchronous stream + pub fn into_async_read(self) -> impl AsyncRead {} + + /// Convert to a synchronized stream + pub fn into_read(self) -> impl Read {} } impl<'a> Type for Upload { From 8e9aff105ee1db49842bc8b0ad2e510b3e16898a Mon Sep 17 00:00:00 2001 From: sunli Date: Mon, 11 May 2020 21:47:24 +0800 Subject: [PATCH 2/2] Support Upload Stream #15 I think the previous implementation is not elegant enough, the `QueryBuilder::set_files_holder` function looks disgusting, so I refactored it. By the way, the performance of parsing InputValue has been optimized, and unnecessary clones have been removed. --- Cargo.toml | 2 +- async-graphql-derive/src/enum.rs | 2 +- async-graphql-derive/src/input_object.rs | 12 +- async-graphql-derive/src/interface.rs | 2 +- async-graphql-derive/src/lib.rs | 2 +- async-graphql-derive/src/object.rs | 2 +- async-graphql-derive/src/union.rs | 2 +- async-graphql-derive/src/utils.rs | 1 + async-graphql-parser/src/value.rs | 18 ++- async-graphql-tide/tests/graphql.rs | 20 ++- docs/en/src/custom_scalars.md | 2 +- docs/zh-CN/src/custom_scalars.md | 2 +- src/base.rs | 19 ++- src/context.rs | 53 ++++---- src/error.rs | 6 +- src/http/into_query_builder.rs | 13 +- src/http/multipart.rs | 151 +++++++---------------- src/query.rs | 19 +-- src/scalars/any.rs | 4 +- src/scalars/bool.rs | 13 +- src/scalars/bson.rs | 6 +- src/scalars/chrono_tz.rs | 4 +- src/scalars/datetime.rs | 4 +- src/scalars/floats.rs | 8 +- src/scalars/id.rs | 6 +- src/scalars/integers.rs | 12 +- src/scalars/json.rs | 4 +- src/scalars/string.rs | 6 +- src/scalars/url.rs | 6 +- src/scalars/uuid.rs | 4 +- src/types/connection/cursor.rs | 6 +- src/types/enum.rs | 8 +- src/types/list.rs | 2 +- src/types/optional.rs | 2 +- src/types/upload.rs | 43 ++----- 35 files changed, 189 insertions(+), 277 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index f80d9600..3c6c81a6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -35,7 +35,7 @@ chrono = "0.4.10" slab = "0.4.2" once_cell = "1.3.1" itertools = "0.9.0" -tempdir = "0.3.7" +tempfile = "3.1.0" httparse = "1.3.4" mime = "0.3.16" http = "0.2.1" diff --git a/async-graphql-derive/src/enum.rs b/async-graphql-derive/src/enum.rs index 40ac5e1a..1613d480 100644 --- a/async-graphql-derive/src/enum.rs +++ b/async-graphql-derive/src/enum.rs @@ -156,7 +156,7 @@ pub fn generate(enum_args: &args::Enum, input: &DeriveInput) -> Result #crate_name::InputValueResult { + fn parse(value: #crate_name::Value) -> #crate_name::InputValueResult { #crate_name::EnumType::parse_enum(value) } } diff --git a/async-graphql-derive/src/input_object.rs b/async-graphql-derive/src/input_object.rs index ed6a87ba..d9c9a328 100644 --- a/async-graphql-derive/src/input_object.rs +++ b/async-graphql-derive/src/input_object.rs @@ -81,17 +81,17 @@ pub fn generate(object_args: &args::InputObject, input: &DeriveInput) -> Result< get_fields.push(quote! { let #ident:#ty = { match obj.get(#name) { - Some(value) => #crate_name::InputValueType::parse(value)?, + Some(value) => #crate_name::InputValueType::parse(value.clone())?, None => { let default = #default_repr; - #crate_name::InputValueType::parse(&default)? + #crate_name::InputValueType::parse(default)? } } }; }); } else { get_fields.push(quote! { - let #ident:#ty = #crate_name::InputValueType::parse(obj.get(#name).unwrap_or(&#crate_name::Value::Null))?; + let #ident:#ty = #crate_name::InputValueType::parse(obj.get(#name).cloned().unwrap_or(#crate_name::Value::Null))?; }); } @@ -129,14 +129,14 @@ pub fn generate(object_args: &args::InputObject, input: &DeriveInput) -> Result< } impl #crate_name::InputValueType for #ident { - fn parse(value: &#crate_name::Value) -> #crate_name::InputValueResult { + fn parse(value: #crate_name::Value) -> #crate_name::InputValueResult { use #crate_name::Type; - if let #crate_name::Value::Object(obj) = value { + if let #crate_name::Value::Object(obj) = &value { #(#get_fields)* Ok(Self { #(#fields),* }) } else { - Err(#crate_name::InputValueError::ExpectedType) + Err(#crate_name::InputValueError::ExpectedType(value)) } } } diff --git a/async-graphql-derive/src/interface.rs b/async-graphql-derive/src/interface.rs index 1c4cd195..83b15137 100644 --- a/async-graphql-derive/src/interface.rs +++ b/async-graphql-derive/src/interface.rs @@ -69,7 +69,7 @@ pub fn generate(interface_args: &args::Interface, input: &DeriveInput) -> Result }; if let Type::Path(p) = &field.ty { // This validates that the field type wasn't already used - if enum_items.insert(p) == false { + if !enum_items.insert(p) { return Err(Error::new_spanned( field, "This type already used in another variant", diff --git a/async-graphql-derive/src/lib.rs b/async-graphql-derive/src/lib.rs index 607d0726..d29ffaf0 100644 --- a/async-graphql-derive/src/lib.rs +++ b/async-graphql-derive/src/lib.rs @@ -165,7 +165,7 @@ pub fn Scalar(args: TokenStream, input: TokenStream) -> TokenStream { } impl #generic #crate_name::InputValueType for #self_ty #where_clause { - fn parse(value: &#crate_name::Value) -> #crate_name::InputValueResult { + fn parse(value: #crate_name::Value) -> #crate_name::InputValueResult { <#self_ty as #crate_name::ScalarType>::parse(value) } } diff --git a/async-graphql-derive/src/object.rs b/async-graphql-derive/src/object.rs index 46659433..8d5d1963 100644 --- a/async-graphql-derive/src/object.rs +++ b/async-graphql-derive/src/object.rs @@ -126,7 +126,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result< }); key_getter.push(quote! { params.get(#name).and_then(|value| { - let value: Option<#ty> = #crate_name::InputValueType::parse(value).ok(); + let value: Option<#ty> = #crate_name::InputValueType::parse(value.clone()).ok(); value }) }); diff --git a/async-graphql-derive/src/union.rs b/async-graphql-derive/src/union.rs index fcb653c1..303f6948 100644 --- a/async-graphql-derive/src/union.rs +++ b/async-graphql-derive/src/union.rs @@ -61,7 +61,7 @@ pub fn generate(union_args: &args::Interface, input: &DeriveInput) -> Result TokenStream } } } + Value::Upload(_) => quote! { #crate_name::Value::Null }, } } diff --git a/async-graphql-parser/src/value.rs b/async-graphql-parser/src/value.rs index fbbfe527..5991671e 100644 --- a/async-graphql-parser/src/value.rs +++ b/async-graphql-parser/src/value.rs @@ -1,17 +1,27 @@ use std::collections::BTreeMap; use std::fmt; use std::fmt::Formatter; -use std::io::Read; +use std::fs::File; pub struct UploadValue { pub filename: String, pub content_type: Option, - pub path: Option>, + pub content: File, } impl fmt::Debug for UploadValue { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { - write!("Upload({})", self.filename) + write!(f, "Upload({})", self.filename) + } +} + +impl Clone for UploadValue { + fn clone(&self) -> Self { + Self { + filename: self.filename.clone(), + content_type: self.content_type.clone(), + content: self.content.try_clone().unwrap(), + } } } @@ -128,7 +138,7 @@ impl fmt::Display for Value { } write!(f, "}}") } - Value::Upload(upload) => write!(f, "null"), + Value::Upload(_) => write!(f, "null"), } } } diff --git a/async-graphql-tide/tests/graphql.rs b/async-graphql-tide/tests/graphql.rs index d4ee6e6e..52638dac 100644 --- a/async-graphql-tide/tests/graphql.rs +++ b/async-graphql-tide/tests/graphql.rs @@ -1,6 +1,7 @@ mod test_utils; use serde_json::json; use smol::{Task, Timer}; +use std::io::Read; use std::time::Duration; use tide::Request; @@ -167,21 +168,18 @@ fn upload() -> Result<()> { #[Object] impl MutationRoot { async fn single_upload(&self, file: Upload) -> FileInfo { - println!("single_upload: filename={}", file.filename); - println!("single_upload: content_type={:?}", file.content_type); - println!("single_upload: path={:?}", file.path); - - let file_path = file.path.clone(); - let content = Task::blocking(async move { std::fs::read_to_string(file_path) }) - .await - .ok(); - assert_eq!(content, Some("test\r\n".to_owned())); + println!("single_upload: filename={}", file.filename()); + println!("single_upload: content_type={:?}", file.content_type()); let file_info = FileInfo { - filename: file.filename, - mime_type: file.content_type, + filename: file.filename().into(), + mime_type: file.content_type().map(ToString::to_string), }; + let mut content = String::new(); + file.into_read().read_to_string(&mut content).ok(); + assert_eq!(content, "test\r\n".to_owned()); + file_info } } diff --git a/docs/en/src/custom_scalars.md b/docs/en/src/custom_scalars.md index 3fe4c82d..e6f2498d 100644 --- a/docs/en/src/custom_scalars.md +++ b/docs/en/src/custom_scalars.md @@ -19,7 +19,7 @@ impl ScalarType for StringNumber { "StringNumber" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { if let Value::String(value) = value { // Parse the integer value value.parse().map(StringNumber)? diff --git a/docs/zh-CN/src/custom_scalars.md b/docs/zh-CN/src/custom_scalars.md index 6bfa7549..1605ee5d 100644 --- a/docs/zh-CN/src/custom_scalars.md +++ b/docs/zh-CN/src/custom_scalars.md @@ -19,7 +19,7 @@ impl ScalarType for StringNumber { "StringNumber" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { if let Value::String(value) = value { // 解析整数 value.parse().map(StringNumber)? diff --git a/src/base.rs b/src/base.rs index fbfccfbd..29da53f2 100644 --- a/src/base.rs +++ b/src/base.rs @@ -52,7 +52,7 @@ pub trait Type { /// Represents a GraphQL input value pub trait InputValueType: Type + Sized { /// Parse from `Value` - fn parse(value: &Value) -> InputValueResult; + fn parse(value: Value) -> InputValueResult; } /// Represents a GraphQL output value @@ -128,11 +128,11 @@ pub trait InputObjectType: InputValueType {} /// "MyInt" /// } /// -/// fn parse(value: &Value) -> InputValueResult { +/// fn parse(value: Value) -> InputValueResult { /// if let Value::Int(n) = value { -/// Ok(MyInt(*n as i32)) +/// Ok(MyInt(n as i32)) /// } else { -/// Err(InputValueError::ExpectedType) +/// Err(InputValueError::ExpectedType(value)) /// } /// } /// @@ -151,16 +151,13 @@ pub trait ScalarType: Sized + Send { } /// Parse a scalar value, return `Some(Self)` if successful, otherwise return `None`. - fn parse(value: &Value) -> InputValueResult; + fn parse(value: Value) -> InputValueResult; /// Checks for a valid scalar value. /// - /// The default implementation is to try to parse it, and in some cases you can implement this on your own to improve performance. - fn is_valid(value: &Value) -> bool { - match Self::parse(value) { - Ok(_) => true, - _ => false, - } + /// Implementing this function can find incorrect input values during the verification phase, which can improve performance. + fn is_valid(_value: &Value) -> bool { + true } /// Convert the scalar value to json value. diff --git a/src/context.rs b/src/context.rs index 5f43a26b..03fed6ed 100644 --- a/src/context.rs +++ b/src/context.rs @@ -3,11 +3,12 @@ use crate::parser::ast::{Directive, Field, FragmentDefinition, SelectionSet, Var use crate::registry::Registry; use crate::{InputValueType, QueryError, Result, Schema, Type}; use crate::{Pos, Positioned, Value}; +use async_graphql_parser::UploadValue; use fnv::FnvHashMap; use std::any::{Any, TypeId}; use std::collections::{BTreeMap, HashMap}; +use std::fs::File; use std::ops::{Deref, DerefMut}; -use std::path::Path; use std::sync::atomic::AtomicUsize; use std::sync::Arc; @@ -56,9 +57,9 @@ impl Variables { pub(crate) fn set_upload( &mut self, var_path: &str, - filename: &str, - content_type: Option<&str>, - path: &Path, + filename: String, + content_type: Option, + content: File, ) { let mut it = var_path.split('.').peekable(); @@ -76,7 +77,11 @@ impl Variables { if let Value::List(ls) = current { if let Some(value) = ls.get_mut(idx as usize) { if !has_next { - *value = Value::String(file_string(filename, content_type, path)); + *value = Value::Upload(UploadValue { + filename, + content_type, + content, + }); return; } else { current = value; @@ -88,7 +93,11 @@ impl Variables { } else if let Value::Object(obj) = current { if let Some(value) = obj.get_mut(s) { if !has_next { - *value = Value::String(file_string(filename, content_type, path)); + *value = Value::Upload(UploadValue { + filename, + content_type, + content, + }); return; } else { current = value; @@ -101,14 +110,6 @@ impl Variables { } } -fn file_string(filename: &str, content_type: Option<&str>, path: &Path) -> String { - if let Some(content_type) = content_type { - format!("file:{}:{}|", filename, content_type) + &path.display().to_string() - } else { - format!("file:{}|", filename) + &path.display().to_string() - } -} - #[derive(Default)] /// Schema/Context data pub struct Data(FnvHashMap>); @@ -411,16 +412,12 @@ impl<'a, T> ContextBase<'a, T> { if directive.name.as_str() == "skip" { if let Some(value) = directive.get_argument("if") { match InputValueType::parse( - &self.resolve_input_value(value.clone_inner(), value.position())?, + self.resolve_input_value(value.clone_inner(), value.position())?, ) { Ok(true) => return Ok(true), Ok(false) => {} Err(err) => { - return Err(err.into_error( - value.pos, - bool::qualified_type_name(), - value.clone_inner(), - )) + return Err(err.into_error(value.pos, bool::qualified_type_name())) } } } else { @@ -434,16 +431,12 @@ impl<'a, T> ContextBase<'a, T> { } else if directive.name.as_str() == "include" { if let Some(value) = directive.get_argument("if") { match InputValueType::parse( - &self.resolve_input_value(value.clone_inner(), value.position())?, + self.resolve_input_value(value.clone_inner(), value.position())?, ) { Ok(false) => return Ok(true), Ok(true) => {} Err(err) => { - return Err(err.into_error( - value.pos, - bool::qualified_type_name(), - value.clone_inner(), - )) + return Err(err.into_error(value.pos, bool::qualified_type_name())) } } } else { @@ -499,18 +492,18 @@ impl<'a> ContextBase<'a, &'a Positioned> { Some(value) => { let pos = value.position(); let value = self.resolve_input_value(value.into_inner(), pos)?; - match InputValueType::parse(&value) { + match InputValueType::parse(value) { Ok(res) => Ok(res), - Err(err) => Err(err.into_error(pos, T::qualified_type_name(), value)), + Err(err) => Err(err.into_error(pos, T::qualified_type_name())), } } None => { let value = default(); - match InputValueType::parse(&value) { + match InputValueType::parse(value) { Ok(res) => Ok(res), Err(err) => { // The default value has no valid location. - Err(err.into_error(Pos::default(), T::qualified_type_name(), value)) + Err(err.into_error(Pos::default(), T::qualified_type_name())) } } } diff --git a/src/error.rs b/src/error.rs index 084f172d..fe5d764e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -8,7 +8,7 @@ pub enum InputValueError { Custom(String), /// The type of input value does not match the expectation. - ExpectedType, + ExpectedType(Value), } impl From for InputValueError { @@ -19,14 +19,14 @@ impl From for InputValueError { impl InputValueError { #[allow(missing_docs)] - pub fn into_error(self, pos: Pos, expected_type: String, value: Value) -> Error { + pub fn into_error(self, pos: Pos, expected_type: String) -> Error { match self { InputValueError::Custom(reason) => Error::Query { pos, path: None, err: QueryError::ParseInputValue { reason }, }, - InputValueError::ExpectedType => Error::Query { + InputValueError::ExpectedType(value) => Error::Query { pos, path: None, err: QueryError::ExpectedInputType { diff --git a/src/http/into_query_builder.rs b/src/http/into_query_builder.rs index 5a293796..86ae6704 100644 --- a/src/http/into_query_builder.rs +++ b/src/http/into_query_builder.rs @@ -32,7 +32,6 @@ where let mut multipart = Multipart::parse( self.1, boundary.as_str(), - opts.temp_dir.as_deref(), opts.max_file_size, opts.max_num_files, ) @@ -59,14 +58,14 @@ where if let Some(name) = &part.name { if let Some(var_paths) = map.remove(name) { for var_path in var_paths { - if let (Some(filename), PartData::File(path)) = + if let (Some(filename), PartData::File(content)) = (&part.filename, &part.data) { builder.set_upload( &var_path, - &filename, - part.content_type.as_deref(), - path, + filename.clone(), + part.content_type.clone(), + content.try_clone().unwrap(), ); } } @@ -78,10 +77,6 @@ where return Err(ParseRequestError::MissingFiles); } - if let Some(temp_dir) = multipart.temp_dir { - builder.set_files_holder(temp_dir); - } - Ok(builder) } else { let mut data = Vec::new(); diff --git a/src/http/multipart.rs b/src/http/multipart.rs index 5d65f9fb..0acdee24 100644 --- a/src/http/multipart.rs +++ b/src/http/multipart.rs @@ -5,16 +5,14 @@ use futures::{AsyncBufRead, AsyncRead}; use http::{header::HeaderName, HeaderMap, HeaderValue}; use itertools::Itertools; use std::fs::File; -use std::io::{Cursor, Read, Write}; -use std::path::{Path, PathBuf}; +use std::io::{Cursor, Read, Seek, SeekFrom, Write}; use std::str::FromStr; -use tempdir::TempDir; const MAX_HEADERS: usize = 16; pub enum PartData { Bytes(Vec), - File(PathBuf), + File(File), } pub struct Part { @@ -26,10 +24,10 @@ pub struct Part { } impl Part { - pub fn create_reader<'a>(&'a self) -> Result, std::io::Error> { - let reader: Box = match &self.data { + pub fn create_reader(self) -> Result, std::io::Error> { + let reader: Box = match self.data { PartData::Bytes(bytes) => Box::new(Cursor::new(bytes)), - PartData::File(path) => Box::new(File::open(path)?), + PartData::File(content) => Box::new(content), }; Ok(reader) } @@ -55,7 +53,6 @@ impl ContentDisposition { } pub struct Multipart { - pub temp_dir: Option, pub parts: Vec, } @@ -63,12 +60,10 @@ impl Multipart { pub async fn parse( reader: R, boundary: &str, - temp_dir_in: Option<&Path>, max_file_size: Option, max_num_files: Option, ) -> Result { let mut reader = BufReader::new(reader); - let mut temp_dir = None; let mut parts = Vec::new(); let boundary = format!("--{}", boundary); let max_num_files = max_num_files.unwrap_or(std::usize::MAX); @@ -79,17 +74,7 @@ impl Multipart { reader.except_token(boundary.as_bytes()).await?; reader.except_token(b"\r\n").await?; let headers = Self::parse_headers(&mut reader).await?; - parts.push( - Self::parse_body( - &mut reader, - &headers, - &mut temp_dir, - temp_dir_in, - max_file_size, - &boundary, - ) - .await?, - ); + parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?); Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?; // next parts @@ -100,21 +85,11 @@ impl Multipart { } let headers = Self::parse_headers(&mut reader).await?; - parts.push( - Self::parse_body( - &mut reader, - &headers, - &mut temp_dir, - temp_dir_in, - max_file_size, - &boundary, - ) - .await?, - ); + parts.push(Self::parse_body(&mut reader, &headers, max_file_size, &boundary).await?); Multipart::check_max_num_files(&mut parts, max_num_files, &mut current_num_files)?; } - Ok(Multipart { temp_dir, parts }) + Ok(Multipart { parts }) } fn check_max_num_files( @@ -171,8 +146,6 @@ impl Multipart { async fn parse_body( mut reader: R, headers: &HeaderMap, - temp_dir: &mut Option, - temp_dir_in: Option<&Path>, max_file_size: usize, boundary: &str, ) -> Result { @@ -193,17 +166,9 @@ impl Multipart { let mut state = ReadUntilState::default(); let mut total_size = 0; - let part_data = if let Some(filename) = &content_disposition.filename { - if temp_dir.is_none() { - if let Some(temp_dir_in) = temp_dir_in { - *temp_dir = Some(TempDir::new_in(temp_dir_in, "async-graphql")?); - } else { - *temp_dir = Some(TempDir::new("async-graphql")?); - } - } - let temp_dir = temp_dir.as_mut().unwrap(); - let path = temp_dir.path().join(filename); - let mut file = File::create(&path)?; + let part_data = if content_disposition.filename.is_some() { + // Create a temporary file. + let mut file = tempfile::tempfile()?; loop { let (size, found) = reader @@ -218,7 +183,8 @@ impl Multipart { break; } } - PartData::File(path) + file.seek(SeekFrom::Start(0))?; + PartData::File(file) } else { let mut body = Vec::new(); @@ -273,10 +239,9 @@ mod tests { Content-Type: text/plain; charset=utf-8\r\n\r\n\ data\ --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n"; - let multipart = - Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, None, None) - .await - .unwrap(); + let multipart = Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, None) + .await + .unwrap(); assert_eq!(multipart.parts.len(), 2); let part_1 = &multipart.parts[0]; @@ -307,35 +272,23 @@ mod tests { data\ --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n"; - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - Some(5), - None, - ) - .await - .is_ok()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(5), None,) + .await + .is_ok() + ); - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - Some(6), - None, - ) - .await - .is_ok()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(6), None,) + .await + .is_ok() + ); - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - Some(4), - None, - ) - .await - .is_err()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", Some(4), None,) + .await + .is_err() + ); } #[async_std::test] @@ -352,34 +305,22 @@ mod tests { data\ --abbc761f78ff4d7cb7573b5a23f96ef0--\r\n"; - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - None, - Some(1) - ) - .await - .is_err()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(1)) + .await + .is_err() + ); - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - None, - Some(2) - ) - .await - .is_ok()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(2)) + .await + .is_ok() + ); - assert!(Multipart::parse( - data, - "abbc761f78ff4d7cb7573b5a23f96ef0", - None, - None, - Some(3) - ) - .await - .is_ok()); + assert!( + Multipart::parse(data, "abbc761f78ff4d7cb7573b5a23f96ef0", None, Some(3)) + .await + .is_ok() + ); } } diff --git a/src/query.rs b/src/query.rs index 36a2a47e..8d4c8e91 100644 --- a/src/query.rs +++ b/src/query.rs @@ -14,9 +14,9 @@ use crate::{ use itertools::Itertools; use std::any::Any; use std::collections::HashMap; -use std::path::{Path, PathBuf}; +use std::fs::File; +use std::path::PathBuf; use std::sync::atomic::AtomicUsize; -use tempdir::TempDir; /// IntoQueryBuilder options #[derive(Default, Clone)] @@ -65,7 +65,6 @@ pub struct QueryBuilder { pub(crate) operation_name: Option, pub(crate) variables: Variables, pub(crate) ctx_data: Option, - pub(crate) files_holder: Option, } impl QueryBuilder { @@ -76,7 +75,6 @@ impl QueryBuilder { operation_name: None, variables: Default::default(), ctx_data: None, - files_holder: None, } } @@ -107,21 +105,16 @@ impl QueryBuilder { self } - /// Set file holder - pub fn set_files_holder(&mut self, files_holder: TempDir) { - self.files_holder = Some(files_holder); - } - /// Set uploaded file path pub fn set_upload( &mut self, var_path: &str, - filename: &str, - content_type: Option<&str>, - path: &Path, + filename: String, + content_type: Option, + content: File, ) { self.variables - .set_upload(var_path, filename, content_type, path); + .set_upload(var_path, filename, content_type, content); } /// Execute the query. diff --git a/src/scalars/any.rs b/src/scalars/any.rs index 9e308d5f..15d03dc2 100644 --- a/src/scalars/any.rs +++ b/src/scalars/any.rs @@ -18,8 +18,8 @@ impl ScalarType for Any { Some("The `_Any` scalar is used to pass representations of entities from external services into the root `_entities` field for execution.") } - fn parse(value: &Value) -> InputValueResult { - Ok(Self(value.clone())) + fn parse(value: Value) -> InputValueResult { + Ok(Self(value)) } fn is_valid(_value: &Value) -> bool { diff --git a/src/scalars/bool.rs b/src/scalars/bool.rs index d480bb05..0bb94e89 100644 --- a/src/scalars/bool.rs +++ b/src/scalars/bool.rs @@ -11,10 +11,17 @@ impl ScalarType for bool { Some("The `Boolean` scalar type represents `true` or `false`.") } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::Boolean(n) => Ok(*n), - _ => Err(InputValueError::ExpectedType), + Value::Boolean(n) => Ok(n), + _ => Err(InputValueError::ExpectedType(value)), + } + } + + fn is_valid(value: &Value) -> bool { + match value { + Value::Boolean(_) => true, + _ => false, } } diff --git a/src/scalars/bson.rs b/src/scalars/bson.rs index 86389067..27014abc 100644 --- a/src/scalars/bson.rs +++ b/src/scalars/bson.rs @@ -9,10 +9,10 @@ impl ScalarType for ObjectId { "ObjectId" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(ObjectId::with_string(&s)?), - _ => Err(InputValueError::ExpectedType), + _ => Err(InputValueError::ExpectedType(value)), } } @@ -27,7 +27,7 @@ impl ScalarType for UtcDateTime { "DateTime" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { DateTime::::parse(value).map(UtcDateTime::from) } diff --git a/src/scalars/chrono_tz.rs b/src/scalars/chrono_tz.rs index 118fd37c..7790040f 100644 --- a/src/scalars/chrono_tz.rs +++ b/src/scalars/chrono_tz.rs @@ -9,10 +9,10 @@ impl ScalarType for Tz { "TimeZone" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Tz::from_str(&s)?), - _ => Err(InputValueError::ExpectedType), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/datetime.rs b/src/scalars/datetime.rs index a290638d..d82463cd 100644 --- a/src/scalars/datetime.rs +++ b/src/scalars/datetime.rs @@ -11,10 +11,10 @@ impl ScalarType for DateTime { "DateTime" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Utc.datetime_from_str(&s, "%+")?), - _ => Err(InputValueError::ExpectedType), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/floats.rs b/src/scalars/floats.rs index 410dc15e..15531a90 100644 --- a/src/scalars/floats.rs +++ b/src/scalars/floats.rs @@ -14,11 +14,11 @@ macro_rules! impl_float_scalars { Some("The `Float` scalar type represents signed double-precision fractional values as specified by [IEEE 754](https://en.wikipedia.org/wiki/IEEE_floating_point).") } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::Int(n) => Ok(*n as Self), - Value::Float(n) => Ok(*n as Self), - _ => Err(InputValueError::ExpectedType) + Value::Int(n) => Ok(n as Self), + Value::Float(n) => Ok(n as Self), + _ => Err(InputValueError::ExpectedType(value)) } } diff --git a/src/scalars/id.rs b/src/scalars/id.rs index f1fa8bd8..911b2e5b 100644 --- a/src/scalars/id.rs +++ b/src/scalars/id.rs @@ -77,11 +77,11 @@ impl ScalarType for ID { "ID" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::Int(n) => Ok(ID(n.to_string())), - Value::String(s) => Ok(ID(s.clone())), - _ => Err(InputValueError::ExpectedType), + Value::String(s) => Ok(ID(s)), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/integers.rs b/src/scalars/integers.rs index 61751bd6..40b5b90d 100644 --- a/src/scalars/integers.rs +++ b/src/scalars/integers.rs @@ -14,10 +14,10 @@ macro_rules! impl_integer_scalars { Some("The `Int` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^31) and 2^31 - 1.") } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::Int(n) => Ok(*n as Self), - _ => Err(InputValueError::ExpectedType) + Value::Int(n) => Ok(n as Self), + _ => Err(InputValueError::ExpectedType(value)) } } @@ -51,11 +51,11 @@ macro_rules! impl_int64_scalars { Some("The `Int64` scalar type represents non-fractional signed whole numeric values. Int can represent values between -(2^64) and 2^64 - 1.") } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::Int(n) => Ok(*n as Self), + Value::Int(n) => Ok(n as Self), Value::String(s) => Ok(s.parse()?), - _ => Err(InputValueError::ExpectedType) + _ => Err(InputValueError::ExpectedType(value)) } } diff --git a/src/scalars/json.rs b/src/scalars/json.rs index 29f03ccf..b346f8d5 100644 --- a/src/scalars/json.rs +++ b/src/scalars/json.rs @@ -28,8 +28,8 @@ impl ScalarType for Json { "JSON" } - fn parse(value: &Value) -> InputValueResult { - Ok(serde_json::from_value(value.clone().into()).map(Json)?) + fn parse(value: Value) -> InputValueResult { + Ok(serde_json::from_value(value.into()).map(Json)?) } fn to_json(&self) -> Result { diff --git a/src/scalars/string.rs b/src/scalars/string.rs index 29400705..57895870 100644 --- a/src/scalars/string.rs +++ b/src/scalars/string.rs @@ -18,10 +18,10 @@ impl ScalarType for String { Some(STRING_DESC) } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::String(s) => Ok(s.clone()), - _ => Err(InputValueError::ExpectedType), + Value::String(s) => Ok(s), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/url.rs b/src/scalars/url.rs index 021a8ec9..5c5d3c53 100644 --- a/src/scalars/url.rs +++ b/src/scalars/url.rs @@ -8,10 +8,10 @@ impl ScalarType for Url { "Url" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::String(s) => Ok(Url::parse(s)?), - _ => Err(InputValueError::ExpectedType), + Value::String(s) => Ok(Url::parse(&s)?), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/scalars/uuid.rs b/src/scalars/uuid.rs index b33395f6..738738e5 100644 --- a/src/scalars/uuid.rs +++ b/src/scalars/uuid.rs @@ -8,10 +8,10 @@ impl ScalarType for Uuid { "UUID" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::String(s) => Ok(Uuid::parse_str(&s)?), - _ => Err(InputValueError::ExpectedType), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/types/connection/cursor.rs b/src/types/connection/cursor.rs index 27d40588..abd4351f 100644 --- a/src/types/connection/cursor.rs +++ b/src/types/connection/cursor.rs @@ -44,10 +44,10 @@ impl ScalarType for Cursor { "Cursor" } - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { - Value::String(s) => Ok(Cursor(s.into())), - _ => Err(InputValueError::ExpectedType), + Value::String(s) => Ok(Cursor(s)), + _ => Err(InputValueError::ExpectedType(value)), } } diff --git a/src/types/enum.rs b/src/types/enum.rs index f9058240..3a4ec915 100644 --- a/src/types/enum.rs +++ b/src/types/enum.rs @@ -11,11 +11,11 @@ pub struct EnumItem { pub trait EnumType: Type + Sized + Eq + Send + Copy + Sized + 'static { fn items() -> &'static [EnumItem]; - fn parse_enum(value: &Value) -> InputValueResult { + fn parse_enum(value: Value) -> InputValueResult { let value = match value { - Value::Enum(s) => s.as_str(), - Value::String(s) => s.as_str(), - _ => return Err(InputValueError::ExpectedType), + Value::Enum(s) => s, + Value::String(s) => s, + _ => return Err(InputValueError::ExpectedType(value)), }; let items = Self::items(); diff --git a/src/types/list.rs b/src/types/list.rs index 2d328a5c..dfbd1fb7 100644 --- a/src/types/list.rs +++ b/src/types/list.rs @@ -20,7 +20,7 @@ impl Type for Vec { } impl InputValueType for Vec { - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::List(values) => { let mut result = Vec::new(); diff --git a/src/types/optional.rs b/src/types/optional.rs index ba600d7b..3ddaa177 100644 --- a/src/types/optional.rs +++ b/src/types/optional.rs @@ -20,7 +20,7 @@ impl Type for Option { } impl InputValueType for Option { - fn parse(value: &Value) -> InputValueResult { + fn parse(value: Value) -> InputValueResult { match value { Value::Null => Ok(None), _ => Ok(Some(T::parse(value)?)), diff --git a/src/types/upload.rs b/src/types/upload.rs index dcccd863..c7d870a3 100644 --- a/src/types/upload.rs +++ b/src/types/upload.rs @@ -1,9 +1,7 @@ use crate::{registry, InputValueError, InputValueResult, InputValueType, Type, Value}; use async_graphql_parser::UploadValue; -use futures::AsyncRead; use std::borrow::Cow; use std::io::Read; -use std::path::PathBuf; /// Uploaded file /// @@ -26,7 +24,7 @@ use std::path::PathBuf; /// #[async_graphql::Object] /// impl MutationRoot { /// async fn upload(&self, file: Upload) -> bool { -/// println!("upload: filename={}", file.filename); +/// println!("upload: filename={}", file.filename()); /// true /// } /// } @@ -57,11 +55,10 @@ impl Upload { self.0.content_type.as_deref() } - /// Convert to an asynchronous stream - pub fn into_async_read(self) -> impl AsyncRead {} - - /// Convert to a synchronized stream - pub fn into_read(self) -> impl Read {} + /// Convert to a read + pub fn into_read(self) -> impl Read + Sync + Send + 'static { + self.0.content + } } impl<'a> Type for Upload { @@ -82,31 +79,11 @@ impl<'a> Type for Upload { } impl<'a> InputValueType for Upload { - fn parse(value: &Value) -> InputValueResult { - if let Value::String(s) = value { - if s.starts_with("file:") { - let s = &s[5..]; - if let Some(idx) = s.find('|') { - let name_and_type = &s[..idx]; - let path = &s[idx + 1..]; - if let Some(type_idx) = name_and_type.find(':') { - let name = &name_and_type[..type_idx]; - let mime_type = &name_and_type[type_idx + 1..]; - return Ok(Self { - filename: name.to_string(), - content_type: Some(mime_type.to_string()), - path: PathBuf::from(path), - }); - } else { - return Ok(Self { - filename: name_and_type.to_string(), - content_type: None, - path: PathBuf::from(path), - }); - } - } - } + fn parse(value: Value) -> InputValueResult { + if let Value::Upload(upload) = value { + Ok(Upload(upload)) + } else { + Err(InputValueError::ExpectedType(value)) } - Err(InputValueError::ExpectedType) } }