From acc1a6f515b3001c65b738fcf21216998b15d094 Mon Sep 17 00:00:00 2001 From: Sunli Date: Thu, 17 Sep 2020 16:39:55 +0800 Subject: [PATCH] Add support for batch queries. --- README.md | 1 + integrations/actix-web/src/lib.rs | 4 +- src/http/batch.rs | 141 +++++++++++++++++++++++ src/http/mod.rs | 13 ++- src/http/multipart.rs | 7 +- src/lib.rs | 5 +- src/registry/cache_control.rs | 21 ++-- src/request.rs | 84 +++++++++++++- src/response.rs | 55 +++++++++ src/schema.rs | 17 ++- src/types/merged_object.rs | 4 +- src/validation/visitors/cache_control.rs | 4 +- tests/batch_request.rs | 33 ++++++ 13 files changed, 359 insertions(+), 30 deletions(-) create mode 100644 src/http/batch.rs create mode 100644 tests/batch_request.rs diff --git a/README.md b/README.md index a7c3aaa5..3b5faba8 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in * Limit query complexity/depth * Error Extensions * Apollo Federation +* Batch Queries ## Examples diff --git a/integrations/actix-web/src/lib.rs b/integrations/actix-web/src/lib.rs index a71144db..f9186f88 100644 --- a/integrations/actix-web/src/lib.rs +++ b/integrations/actix-web/src/lib.rs @@ -8,7 +8,7 @@ mod subscription; use actix_web::dev::{HttpResponseBuilder, Payload, PayloadStream}; use actix_web::http::StatusCode; use actix_web::{http, web, Error, FromRequest, HttpRequest, HttpResponse, Responder}; -use async_graphql::http::{receive_body, MultipartOptions}; +use async_graphql::http::MultipartOptions; use async_graphql::{ParseRequestError, Request, Response}; use futures::channel::mpsc; use futures::future::Ready; @@ -67,7 +67,7 @@ impl FromRequest for GQLRequest { Box::pin(async move { Ok(GQLRequest( - receive_body( + async_graphql::http::receive_body( content_type, rx.map_err(|err| io::Error::new(ErrorKind::Other, err)) .into_async_read(), diff --git a/src/http/batch.rs b/src/http/batch.rs new file mode 100644 index 00000000..156b09de --- /dev/null +++ b/src/http/batch.rs @@ -0,0 +1,141 @@ +#[cfg(feature = "multipart")] +use crate::http::{multipart::ReaderStream, MultipartOptions}; +use crate::{BatchRequest, ParseRequestError}; +use futures::{AsyncRead, AsyncReadExt}; +#[cfg(feature = "multipart")] +use multer::{Constraints, Multipart, SizeLimit}; +#[cfg(feature = "multipart")] +use std::{ + collections::HashMap, + io::{Seek, SeekFrom, Write}, +}; + +#[cfg(feature = "multipart")] +#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))] +/// Receive a GraphQL batch request from a content type and body. +pub async fn receive_batch_body( + content_type: Option>, + body: impl AsyncRead + Send + 'static, + opts: MultipartOptions, +) -> Result { + if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) { + receive_batch_multipart(body, boundary, opts).await + } else { + receive_batch_json(body).await + } +} + +/// Receive a GraphQL batch request from a body as JSON. +pub async fn receive_batch_json( + body: impl AsyncRead + Send + 'static, +) -> Result { + let mut data = Vec::new(); + futures::pin_mut!(body); + body.read_to_end(&mut data) + .await + .map_err(ParseRequestError::Io)?; + Ok(serde_json::from_slice::(&data).map_err(ParseRequestError::InvalidRequest)?) +} + +#[cfg(feature = "multipart")] +async fn receive_batch_multipart( + body: impl AsyncRead + Send + 'static, + boundary: impl Into, + opts: MultipartOptions, +) -> Result { + let mut multipart = Multipart::new_with_constraints( + ReaderStream::new(body), + boundary, + Constraints::new().size_limit({ + let mut limit = SizeLimit::new(); + if let (Some(max_file_size), Some(max_num_files)) = + (opts.max_file_size, opts.max_file_size) + { + limit = limit.whole_stream((max_file_size * max_num_files) as u64); + } + if let Some(max_file_size) = opts.max_file_size { + limit = limit.per_field(max_file_size as u64); + } + limit + }), + ); + + let mut batch_request = None; + let mut map = None; + let mut files = Vec::new(); + + while let Some(mut field) = multipart.next_field().await? { + match field.name() { + Some("operations") => { + let request_str = field.text().await?; + batch_request = Some( + serde_json::from_str::(&request_str) + .map_err(ParseRequestError::InvalidRequest)?, + ); + } + Some("map") => { + let map_str = field.text().await?; + map = Some( + serde_json::from_str::>>(&map_str) + .map_err(ParseRequestError::InvalidFilesMap)?, + ); + } + _ => { + if let Some(name) = field.name().map(ToString::to_string) { + if let Some(filename) = field.file_name().map(ToString::to_string) { + let content_type = field.content_type().map(|mime| mime.to_string()); + let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?; + while let Some(chunk) = field.chunk().await.unwrap() { + file.write(&chunk).map_err(ParseRequestError::Io)?; + } + file.seek(SeekFrom::Start(0))?; + files.push((name, filename, content_type, file)); + } + } + } + } + } + + let mut batch_request: BatchRequest = + batch_request.ok_or(ParseRequestError::MissingOperatorsPart)?; + let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?; + + for (name, filename, content_type, file) in files { + if let Some(var_paths) = map.remove(&name) { + for var_path in var_paths { + match &mut batch_request { + BatchRequest::Single(request) => { + request.set_upload( + &var_path, + filename.clone(), + content_type.clone(), + file.try_clone().unwrap(), + ); + } + BatchRequest::Batch(requests) => { + let mut s = var_path.splitn(2, '.'); + let idx = s.next().and_then(|idx| idx.parse::().ok()); + let path = s.next(); + + if let (Some(idx), Some(path)) = (idx, path) { + if let Some(request) = requests.get_mut(idx) { + request.set_upload( + path, + filename.clone(), + content_type.clone(), + file.try_clone().unwrap(), + ); + } + } + } + } + } + } + } + + if !map.is_empty() { + return Err(ParseRequestError::MissingFiles); + } + + Ok(batch_request) +} diff --git a/src/http/mod.rs b/src/http/mod.rs index 79a5c921..24b80322 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -1,14 +1,18 @@ //! A helper module that supports HTTP +mod batch; mod graphiql_source; #[cfg(feature = "multipart")] mod multipart; mod playground_source; mod websocket; +#[cfg(feature = "multipart")] +pub use batch::receive_batch_body; +pub use batch::receive_batch_json; pub use graphiql_source::graphiql_source; #[cfg(feature = "multipart")] -pub use multipart::{receive_multipart, MultipartOptions}; +pub use multipart::MultipartOptions; pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; pub use websocket::WebSocketStream; @@ -16,19 +20,16 @@ use crate::{ParseRequestError, Request}; use futures::io::AsyncRead; use futures::AsyncReadExt; -/// Receive a GraphQL request from a content type and body. -/// -/// If the content type is multipart it will use `receive_multipart`, otherwise it will use -/// `receive_json`. #[cfg(feature = "multipart")] #[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))] +/// Receive a GraphQL request from a content type and body. pub async fn receive_body( content_type: Option>, body: impl AsyncRead + Send + 'static, opts: MultipartOptions, ) -> Result { if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) { - receive_multipart(body, boundary, opts).await + multipart::receive_multipart(body, boundary, opts).await } else { receive_json(body).await } diff --git a/src/http/multipart.rs b/src/http/multipart.rs index 79fea22c..1664aae3 100644 --- a/src/http/multipart.rs +++ b/src/http/multipart.rs @@ -39,8 +39,7 @@ impl MultipartOptions { } /// Receive a multipart request. -#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))] -pub async fn receive_multipart( +pub(crate) async fn receive_multipart( body: impl AsyncRead + Send + 'static, boundary: impl Into, opts: MultipartOptions, @@ -122,7 +121,7 @@ pub async fn receive_multipart( } pin_project! { - struct ReaderStream { + pub(crate) struct ReaderStream { buf: [u8; 2048], #[pin] reader: T, @@ -130,7 +129,7 @@ pin_project! { } impl ReaderStream { - fn new(reader: T) -> Self { + pub(crate) fn new(reader: T) -> Self { Self { buf: [0; 2048], reader, diff --git a/src/lib.rs b/src/lib.rs index 1d773804..3e33e48d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -47,6 +47,7 @@ //! * Limit query complexity/depth //! * Error Extensions //! * Apollo Federation +//! * Batch Queries //! //! ## Integrations //! @@ -157,8 +158,8 @@ pub use error::{ pub use look_ahead::Lookahead; pub use parser::{types::ConstValue as Value, Pos, Positioned}; pub use registry::CacheControl; -pub use request::Request; -pub use response::Response; +pub use request::{BatchRequest, Request}; +pub use response::{BatchResponse, Response}; pub use schema::{Schema, SchemaBuilder, SchemaEnv}; pub use serde_json::Number; pub use types::*; diff --git a/src/registry/cache_control.rs b/src/registry/cache_control.rs index fe1d1111..ce7ed950 100644 --- a/src/registry/cache_control.rs +++ b/src/registry/cache_control.rs @@ -63,14 +63,17 @@ impl CacheControl { } impl CacheControl { - pub(crate) fn merge(&mut self, other: &CacheControl) { - self.public = self.public && other.public; - self.max_age = if self.max_age == 0 { - other.max_age - } else if other.max_age == 0 { - self.max_age - } else { - self.max_age.min(other.max_age) - }; + #[must_use] + pub(crate) fn merge(self, other: &CacheControl) -> CacheControl { + CacheControl { + public: self.public && other.public, + max_age: if self.max_age == 0 { + other.max_age + } else if other.max_age == 0 { + self.max_age + } else { + self.max_age.min(other.max_age) + }, + } } } diff --git a/src/request.rs b/src/request.rs index a5bfe3c4..33d12f42 100644 --- a/src/request.rs +++ b/src/request.rs @@ -1,6 +1,6 @@ use crate::parser::types::UploadValue; use crate::{Data, Value, Variables}; -use serde::Deserialize; +use serde::{Deserialize, Deserializer}; use std::any::Any; use std::fs::File; @@ -87,6 +87,47 @@ impl> From for Request { } } +/// Batch support for GraphQL requests, which is either a single query, or an array of queries +/// +/// **Reference:** +#[derive(Debug, Deserialize)] +#[serde(untagged)] +pub enum BatchRequest { + /// Single query + Single(Request), + + /// Non-empty array of queries + #[serde(deserialize_with = "deserialize_non_empty_vec")] + Batch(Vec), +} + +fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result, D::Error> +where + D: Deserializer<'de>, + T: Deserialize<'de>, +{ + use serde::de::Error as _; + + let v = Vec::::deserialize(deserializer)?; + if v.is_empty() { + Err(D::Error::invalid_length(0, &"a positive integer")) + } else { + Ok(v) + } +} + +impl From for BatchRequest { + fn from(r: Request) -> Self { + BatchRequest::Single(r) + } +} + +impl From> for BatchRequest { + fn from(r: Vec) -> Self { + BatchRequest::Batch(r) + } +} + #[cfg(test)] mod tests { use super::*; @@ -137,4 +178,45 @@ mod tests { assert!(request.operation_name.is_none()); assert_eq!(request.query, "{ a b c }"); } + + #[test] + fn test_batch_request_single() { + let request: BatchRequest = serde_json::from_value(json! ({ + "query": "{ a b c }" + })) + .unwrap(); + + if let BatchRequest::Single(request) = request { + assert!(request.variables.0.is_empty()); + assert!(request.operation_name.is_none()); + assert_eq!(request.query, "{ a b c }"); + } else { + unreachable!() + } + } + + #[test] + fn test_batch_request_batch() { + let request: BatchRequest = serde_json::from_value(json!([ + { + "query": "{ a b c }" + }, + { + "query": "{ d e }" + } + ])) + .unwrap(); + + if let BatchRequest::Batch(requests) = request { + assert!(requests[0].variables.0.is_empty()); + assert!(requests[0].operation_name.is_none()); + assert_eq!(requests[0].query, "{ a b c }"); + + assert!(requests[1].variables.0.is_empty()); + assert!(requests[1].operation_name.is_none()); + assert_eq!(requests[1].query, "{ d e }"); + } else { + unreachable!() + } + } } diff --git a/src/response.rs b/src/response.rs index c082e1b1..a252d3a4 100644 --- a/src/response.rs +++ b/src/response.rs @@ -1,4 +1,5 @@ use crate::{CacheControl, Error, Result}; +use serde::Serialize; /// Query response #[derive(Debug, Default)] @@ -88,3 +89,57 @@ impl From for Response { Self::from_error(err) } } + +/// Response for batchable queries +#[derive(Debug, Serialize)] +#[serde(untagged)] +pub enum BatchResponse { + /// Response for single queries + Single(Response), + + /// Response for batch queries + Batch(Vec), +} + +impl BatchResponse { + /// Get cache control value + pub fn cache_control(&self) -> CacheControl { + match self { + BatchResponse::Single(resp) => resp.cache_control, + BatchResponse::Batch(resp) => resp.iter().fold(CacheControl::default(), |acc, item| { + acc.merge(&item.cache_control) + }), + } + } + + /// Returns `true` if all responses are ok. + pub fn is_ok(&self) -> bool { + match self { + BatchResponse::Single(resp) => resp.is_ok(), + BatchResponse::Batch(resp) => resp.iter().all(Response::is_ok), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_batch_response_single() { + let resp = BatchResponse::Single(Response::new(serde_json::Value::Bool(true))); + assert_eq!(serde_json::to_string(&resp).unwrap(), r#"{"data":true}"#); + } + + #[test] + fn test_batch_response_batch() { + let resp = BatchResponse::Batch(vec![ + Response::new(serde_json::Value::Bool(true)), + Response::new(serde_json::Value::String("1".to_string())), + ]); + assert_eq!( + serde_json::to_string(&resp).unwrap(), + r#"[{"data":true},{"data":"1"}]"# + ); + } +} diff --git a/src/schema.rs b/src/schema.rs index 2019f0e0..98d5a645 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -9,8 +9,8 @@ use crate::subscription::collect_subscription_streams; use crate::types::QueryRoot; use crate::validation::{check_rules, CheckResult, ValidationMode}; use crate::{ - CacheControl, ContextBase, Error, Pos, QueryEnv, QueryError, Request, Response, Result, - SubscriptionType, Type, Variables, ID, + BatchRequest, BatchResponse, CacheControl, ContextBase, Error, Pos, QueryEnv, QueryError, + Request, Response, Result, SubscriptionType, Type, Variables, ID, }; use async_graphql_parser::types::ExecutableDocumentData; use futures::stream::{self, Stream, StreamExt}; @@ -434,6 +434,19 @@ where } } + /// Execute an GraphQL batch query. + pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse { + match batch_request { + BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await), + BatchRequest::Batch(requests) => BatchResponse::Batch( + futures::stream::iter(requests.into_iter()) + .then(|request| self.execute(request)) + .collect() + .await, + ), + } + } + pub(crate) fn execute_stream_with_ctx_data( &self, request: impl Into, diff --git a/src/types/merged_object.rs b/src/types/merged_object.rs index da9bfcf6..6d17cc65 100644 --- a/src/types/merged_object.rs +++ b/src/types/merged_object.rs @@ -41,7 +41,7 @@ impl Type for MergedObject { }) = registry.types.remove(&*A::type_name()) { fields.extend(a_fields); - cc.merge(&a_cc); + cc = cc.merge(&a_cc); } B::create_type_info(registry); @@ -52,7 +52,7 @@ impl Type for MergedObject { }) = registry.types.remove(&*B::type_name()) { fields.extend(b_fields); - cc.merge(&b_cc); + cc = cc.merge(&b_cc); } MetaType::Object { diff --git a/src/validation/visitors/cache_control.rs b/src/validation/visitors/cache_control.rs index 1f1a04e9..edd6d086 100644 --- a/src/validation/visitors/cache_control.rs +++ b/src/validation/visitors/cache_control.rs @@ -15,7 +15,7 @@ impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> { ) { if let Some(current_type) = ctx.current_type() { if let MetaType::Object { cache_control, .. } = current_type { - self.cache_control.merge(cache_control); + *self.cache_control = self.cache_control.merge(cache_control); } } } @@ -25,7 +25,7 @@ impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> { .parent_type() .and_then(|parent| parent.field_by_name(&field.node.name.node)) { - self.cache_control.merge(®istry_field.cache_control); + *self.cache_control = self.cache_control.merge(®istry_field.cache_control); } } } diff --git a/tests/batch_request.rs b/tests/batch_request.rs new file mode 100644 index 00000000..933aadc4 --- /dev/null +++ b/tests/batch_request.rs @@ -0,0 +1,33 @@ +use async_graphql::*; + +#[async_std::test] +pub async fn test_batch_request() { + struct Query; + + #[GQLObject] + impl Query { + async fn value(&self, a: i32, b: i32) -> i32 { + a + b + } + } + + let schema = Schema::new(Query, EmptyMutation, EmptySubscription); + let batch: BatchRequest = vec![ + Request::new("{ value(a: 10, b: 20) }"), + Request::new("{ value(a: 30, b: 40) }"), + Request::new("{ value1 }"), + ] + .into(); + let resp = schema.execute_batch(batch).await; + assert_eq!( + serde_json::to_value(&resp).unwrap(), + serde_json::json!([ + {"data": { "value": 30 }}, + {"data": { "value": 70 }}, + {"errors": [{ + "message": r#"Unknown field "value1" on type "Query". Did you mean "value"?"#, + "locations": [{"line": 1, "column": 3}] + }]}, + ]) + ); +}