diff --git a/Cargo.toml b/Cargo.toml index 6158882a..f8d4b8b0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -44,7 +44,6 @@ http = "0.2.3" multer = "2.0.0" tempfile = "3.2.0" bytes = { version = "1.0.1", features = ["serde"] } - # Feature optional dependencies bson = { version = "2.0.0-beta.1", optional = true, features = ["chrono-0_4"] } chrono = { version = "0.4.19", optional = true } @@ -57,7 +56,7 @@ opentelemetry = { version = "0.13.0", optional = true } url = { version = "2.2.1", optional = true } uuid = { version = "0.8.2", optional = true, features = ["v4", "serde"] } rust_decimal = { version = "1.14.3", optional = true } - +mime = "0.3.15" # Non-feature optional dependencies blocking = { version = "1.0.2", optional = true } lru = { version = "0.6.5", optional = true } diff --git a/src/error.rs b/src/error.rs index 5d4bd242..fbc3e2f3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -6,7 +6,6 @@ use serde::{Deserialize, Serialize}; use thiserror::Error; use crate::{parser, InputType, Pos, Value}; -use serde::de::Error as SerdeError; /// Extensions to the error. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] @@ -259,6 +258,12 @@ impl From for ParseRequestError { } } +impl From for ParseRequestError { + fn from(e: mime::FromStrError) -> Self { + Self::InvalidRequest(Box::new(e)) + } +} + /// An error which can be extended into a `Error`. pub trait ErrorExtensions: Sized { /// Convert the error to a `Error`. diff --git a/src/http/mod.rs b/src/http/mod.rs index 454dda99..de2e8fc6 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -4,7 +4,7 @@ mod graphiql_source; mod multipart; mod playground_source; mod websocket; - +use mime::{self}; use futures_util::io::{AsyncRead, AsyncReadExt}; use crate::{BatchRequest, ParseRequestError, Request}; @@ -31,12 +31,27 @@ pub async fn receive_batch_body( body: impl AsyncRead + Send, opts: MultipartOptions, ) -> Result { - let content_type = content_type.as_ref().map(AsRef::as_ref); + // if no content-type header is set, we default to json + let content_type = content_type.as_ref().map(AsRef::as_ref).unwrap_or("application/json"); + + let content_type: mime::Mime = content_type.parse()?; + match (content_type.type_(), content_type.subtype()) { + // application/json -> try json + (mime::APPLICATION, mime::JSON) => receive_batch_json(body).await, + // cbor is in application/octet-stream. + // TODO: wait for mime to add application/cbor and match against that too + (mime::OCTET_STREAM, _) | (mime::APPLICATION, mime::OCTET_STREAM) => receive_batch_cbor(body).await, + // try to use multipart + (mime::MULTIPART, _) => { + if let Some(boundary) = content_type.get_param("boundary") { + multipart::receive_batch_multipart(body, boundary.to_string(), opts).await + } else { + Err(ParseRequestError::InvalidMultipart(multer::Error::NoBoundary)) + } + } - if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) { - multipart::receive_batch_multipart(body, boundary, opts).await - } else { - receive_batch_cbor(body).await + // default to json and try that + _ => receive_batch_json(body).await } } @@ -52,7 +67,6 @@ pub async fn receive_batch_json(body: impl AsyncRead) -> Result(&data) .map_err(|e| ParseRequestError::InvalidRequest(Box::new(e)))?) }