Add support for batch queries.

Sunli 2020-09-17 16:39:55 +08:00
parent 52a534df0c
commit acc1a6f515
13 changed files with 359 additions and 30 deletions

View File

@@ -51,6 +51,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in
* Limit query complexity/depth
* Error Extensions
* Apollo Federation
* Batch Queries
## Examples

View File

@@ -8,7 +8,7 @@ mod subscription;
use actix_web::dev::{HttpResponseBuilder, Payload, PayloadStream};
use actix_web::http::StatusCode;
use actix_web::{http, web, Error, FromRequest, HttpRequest, HttpResponse, Responder};
use async_graphql::http::{receive_body, MultipartOptions};
use async_graphql::http::MultipartOptions;
use async_graphql::{ParseRequestError, Request, Response};
use futures::channel::mpsc;
use futures::future::Ready;
@@ -67,7 +67,7 @@ impl FromRequest for GQLRequest {
Box::pin(async move {
Ok(GQLRequest(
receive_body(
async_graphql::http::receive_body(
content_type,
rx.map_err(|err| io::Error::new(ErrorKind::Other, err))
.into_async_read(),

src/http/batch.rs Normal file
View File

@@ -0,0 +1,141 @@
#[cfg(feature = "multipart")]
use crate::http::{multipart::ReaderStream, MultipartOptions};
use crate::{BatchRequest, ParseRequestError};
use futures::{AsyncRead, AsyncReadExt};
#[cfg(feature = "multipart")]
use multer::{Constraints, Multipart, SizeLimit};
#[cfg(feature = "multipart")]
use std::{
collections::HashMap,
io::{Seek, SeekFrom, Write},
};
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
/// Receive a GraphQL batch request from a content type and body.
pub async fn receive_batch_body(
content_type: Option<impl AsRef<str>>,
body: impl AsyncRead + Send + 'static,
opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
receive_batch_multipart(body, boundary, opts).await
} else {
receive_batch_json(body).await
}
}
/// Receive a GraphQL batch request from a body as JSON.
pub async fn receive_batch_json(
body: impl AsyncRead + Send + 'static,
) -> Result<BatchRequest, ParseRequestError> {
let mut data = Vec::new();
futures::pin_mut!(body);
body.read_to_end(&mut data)
.await
.map_err(ParseRequestError::Io)?;
serde_json::from_slice::<BatchRequest>(&data).map_err(ParseRequestError::InvalidRequest)
}
#[cfg(feature = "multipart")]
async fn receive_batch_multipart(
body: impl AsyncRead + Send + 'static,
boundary: impl Into<String>,
opts: MultipartOptions,
) -> Result<BatchRequest, ParseRequestError> {
let mut multipart = Multipart::new_with_constraints(
ReaderStream::new(body),
boundary,
Constraints::new().size_limit({
let mut limit = SizeLimit::new();
if let (Some(max_file_size), Some(max_num_files)) =
(opts.max_file_size, opts.max_num_files)
{
limit = limit.whole_stream((max_file_size * max_num_files) as u64);
}
if let Some(max_file_size) = opts.max_file_size {
limit = limit.per_field(max_file_size as u64);
}
limit
}),
);
let mut batch_request = None;
let mut map = None;
let mut files = Vec::new();
while let Some(mut field) = multipart.next_field().await? {
match field.name() {
Some("operations") => {
let request_str = field.text().await?;
batch_request = Some(
serde_json::from_str::<BatchRequest>(&request_str)
.map_err(ParseRequestError::InvalidRequest)?,
);
}
Some("map") => {
let map_str = field.text().await?;
map = Some(
serde_json::from_str::<HashMap<String, Vec<String>>>(&map_str)
.map_err(ParseRequestError::InvalidFilesMap)?,
);
}
_ => {
if let Some(name) = field.name().map(ToString::to_string) {
if let Some(filename) = field.file_name().map(ToString::to_string) {
let content_type = field.content_type().map(|mime| mime.to_string());
let mut file = tempfile::tempfile().map_err(ParseRequestError::Io)?;
while let Some(chunk) = field.chunk().await? {
file.write_all(&chunk).map_err(ParseRequestError::Io)?;
}
file.seek(SeekFrom::Start(0)).map_err(ParseRequestError::Io)?;
files.push((name, filename, content_type, file));
}
}
}
}
}
let mut batch_request: BatchRequest =
batch_request.ok_or(ParseRequestError::MissingOperatorsPart)?;
let map = map.as_mut().ok_or(ParseRequestError::MissingMapPart)?;
for (name, filename, content_type, file) in files {
if let Some(var_paths) = map.remove(&name) {
for var_path in var_paths {
match &mut batch_request {
BatchRequest::Single(request) => {
request.set_upload(
&var_path,
filename.clone(),
content_type.clone(),
file.try_clone().map_err(ParseRequestError::Io)?,
);
}
BatchRequest::Batch(requests) => {
let mut s = var_path.splitn(2, '.');
let idx = s.next().and_then(|idx| idx.parse::<usize>().ok());
let path = s.next();
if let (Some(idx), Some(path)) = (idx, path) {
if let Some(request) = requests.get_mut(idx) {
request.set_upload(
path,
filename.clone(),
content_type.clone(),
file.try_clone().map_err(ParseRequestError::Io)?,
);
}
}
}
}
}
}
}
if !map.is_empty() {
return Err(ParseRequestError::MissingFiles);
}
Ok(batch_request)
}
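A minimal usage sketch (not part of this diff, and assuming the `multipart` feature plus a `Default` impl for `MultipartOptions`): hand the raw HTTP body to the new batch entry point. `content_type` and `body` stand in for whatever the web framework provides.

use async_graphql::http::{receive_batch_body, MultipartOptions};
use async_graphql::{BatchRequest, ParseRequestError};
use futures::AsyncRead;

// Multipart bodies take the upload path; everything else is parsed as
// JSON (a single object or a non-empty array).
async fn parse_incoming(
    content_type: Option<String>,
    body: impl AsyncRead + Send + 'static,
) -> Result<BatchRequest, ParseRequestError> {
    // For batched uploads, keys in the multipart `map` part address files
    // as "<request index>.<variable path>", e.g. "0.variables.file".
    receive_batch_body(content_type, body, MultipartOptions::default()).await
}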

View File

@@ -1,14 +1,18 @@
//! A helper module that supports HTTP
mod batch;
mod graphiql_source;
#[cfg(feature = "multipart")]
mod multipart;
mod playground_source;
mod websocket;
#[cfg(feature = "multipart")]
pub use batch::receive_batch_body;
pub use batch::receive_batch_json;
pub use graphiql_source::graphiql_source;
#[cfg(feature = "multipart")]
pub use multipart::{receive_multipart, MultipartOptions};
pub use multipart::MultipartOptions;
pub use playground_source::{playground_source, GraphQLPlaygroundConfig};
pub use websocket::WebSocketStream;
@@ -16,19 +20,16 @@ use crate::{ParseRequestError, Request};
use futures::io::AsyncRead;
use futures::AsyncReadExt;
/// Receive a GraphQL request from a content type and body.
///
/// If the content type is multipart it will use `receive_multipart`, otherwise it will use
/// `receive_json`.
#[cfg(feature = "multipart")]
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
/// Receive a GraphQL request from a content type and body.
pub async fn receive_body(
content_type: Option<impl AsRef<str>>,
body: impl AsyncRead + Send + 'static,
opts: MultipartOptions,
) -> Result<Request, ParseRequestError> {
if let Some(Ok(boundary)) = content_type.map(multer::parse_boundary) {
receive_multipart(body, boundary, opts).await
multipart::receive_multipart(body, boundary, opts).await
} else {
receive_json(body).await
}
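Another sketch (not in the diff): `receive_batch_json` is exported even without the `multipart` feature and accepts any `AsyncRead`, so a static byte slice is enough to exercise the batch path.

async fn smoke() {
    let request = async_graphql::http::receive_batch_json(
        &br#"[{"query": "{ a }"}, {"query": "{ b }"}]"#[..],
    )
    .await
    .unwrap();
    // An array body deserializes into the Batch variant.
    assert!(matches!(request, async_graphql::BatchRequest::Batch(_)));
}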

View File

@@ -39,8 +39,7 @@ impl MultipartOptions {
}
/// Receive a multipart request.
#[cfg_attr(feature = "nightly", doc(cfg(feature = "multipart")))]
pub async fn receive_multipart(
pub(crate) async fn receive_multipart(
body: impl AsyncRead + Send + 'static,
boundary: impl Into<String>,
opts: MultipartOptions,
@@ -122,7 +121,7 @@ pub async fn receive_multipart(
}
pin_project! {
struct ReaderStream<T> {
pub(crate) struct ReaderStream<T> {
buf: [u8; 2048],
#[pin]
reader: T,
@@ -130,7 +129,7 @@ pin_project! {
}
impl<T> ReaderStream<T> {
fn new(reader: T) -> Self {
pub(crate) fn new(reader: T) -> Self {
Self {
buf: [0; 2048],
reader,

View File

@@ -47,6 +47,7 @@
//! * Limit query complexity/depth
//! * Error Extensions
//! * Apollo Federation
//! * Batch Queries
//!
//! ## Integrations
//!
@@ -157,8 +158,8 @@ pub use error::{
pub use look_ahead::Lookahead;
pub use parser::{types::ConstValue as Value, Pos, Positioned};
pub use registry::CacheControl;
pub use request::Request;
pub use response::Response;
pub use request::{BatchRequest, Request};
pub use response::{BatchResponse, Response};
pub use schema::{Schema, SchemaBuilder, SchemaEnv};
pub use serde_json::Number;
pub use types::*;

View File

@@ -63,14 +63,17 @@ impl CacheControl {
}
impl CacheControl {
pub(crate) fn merge(&mut self, other: &CacheControl) {
self.public = self.public && other.public;
self.max_age = if self.max_age == 0 {
other.max_age
} else if other.max_age == 0 {
self.max_age
} else {
self.max_age.min(other.max_age)
};
#[must_use]
pub(crate) fn merge(self, other: &CacheControl) -> CacheControl {
CacheControl {
public: self.public && other.public,
max_age: if self.max_age == 0 {
other.max_age
} else if other.max_age == 0 {
self.max_age
} else {
self.max_age.min(other.max_age)
},
}
}
}
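Because `merge` is now by-value and `#[must_use]`, call sites reassign the result (see the MergedObject and cache-control visitor changes below). A crate-internal sketch of the folding rules, assuming `CacheControl`'s two public fields:

#[test]
fn merge_folds_policies() {
    // `public` survives only if both sides are public; the shortest
    // non-zero `max_age` wins.
    let merged = CacheControl { public: true, max_age: 60 }
        .merge(&CacheControl { public: false, max_age: 30 });
    assert!(!merged.public);
    assert_eq!(merged.max_age, 30);
}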

View File

@@ -1,6 +1,6 @@
use crate::parser::types::UploadValue;
use crate::{Data, Value, Variables};
use serde::Deserialize;
use serde::{Deserialize, Deserializer};
use std::any::Any;
use std::fs::File;
@@ -87,6 +87,47 @@ impl<T: Into<String>> From<T> for Request {
}
}
/// Batch support for GraphQL requests, which is either a single query or an array of queries
///
/// **Reference:** <https://www.apollographql.com/blog/batching-client-graphql-queries-a685f5bcd41b/>
#[derive(Debug, Deserialize)]
#[serde(untagged)]
pub enum BatchRequest {
/// Single query
Single(Request),
/// Non-empty array of queries
#[serde(deserialize_with = "deserialize_non_empty_vec")]
Batch(Vec<Request>),
}
fn deserialize_non_empty_vec<'de, D, T>(deserializer: D) -> std::result::Result<Vec<T>, D::Error>
where
D: Deserializer<'de>,
T: Deserialize<'de>,
{
use serde::de::Error as _;
let v = Vec::<T>::deserialize(deserializer)?;
if v.is_empty() {
Err(D::Error::invalid_length(0, &"a non-empty sequence"))
} else {
Ok(v)
}
}
impl From<Request> for BatchRequest {
fn from(r: Request) -> Self {
BatchRequest::Single(r)
}
}
impl From<Vec<Request>> for BatchRequest {
fn from(r: Vec<Request>) -> Self {
BatchRequest::Batch(r)
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -137,4 +178,45 @@ mod tests {
assert!(request.operation_name.is_none());
assert_eq!(request.query, "{ a b c }");
}
#[test]
fn test_batch_request_single() {
let request: BatchRequest = serde_json::from_value(json!({
"query": "{ a b c }"
}))
.unwrap();
if let BatchRequest::Single(request) = request {
assert!(request.variables.0.is_empty());
assert!(request.operation_name.is_none());
assert_eq!(request.query, "{ a b c }");
} else {
unreachable!()
}
}
#[test]
fn test_batch_request_batch() {
let request: BatchRequest = serde_json::from_value(json!([
{
"query": "{ a b c }"
},
{
"query": "{ d e }"
}
]))
.unwrap();
if let BatchRequest::Batch(requests) = request {
assert!(requests[0].variables.0.is_empty());
assert!(requests[0].operation_name.is_none());
assert_eq!(requests[0].query, "{ a b c }");
assert!(requests[1].variables.0.is_empty());
assert!(requests[1].operation_name.is_none());
assert_eq!(requests[1].query, "{ d e }");
} else {
unreachable!()
}
}
}
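A sketch (not in the diff) of what the untagged representation gives callers: an object parses as `Single`, an array as `Batch`, and `deserialize_non_empty_vec` rejects the empty array.

#[test]
fn untagged_shapes() {
    use async_graphql::BatchRequest;
    let single: BatchRequest = serde_json::from_str(r#"{"query": "{ a }"}"#).unwrap();
    assert!(matches!(single, BatchRequest::Single(_)));
    let batch: BatchRequest = serde_json::from_str(r#"[{"query": "{ a }"}]"#).unwrap();
    assert!(matches!(batch, BatchRequest::Batch(_)));
    // The non-empty constraint turns [] into a deserialization error.
    assert!(serde_json::from_str::<BatchRequest>("[]").is_err());
}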

View File

@@ -1,4 +1,5 @@
use crate::{CacheControl, Error, Result};
use serde::Serialize;
/// Query response
#[derive(Debug, Default)]
@@ -88,3 +89,57 @@ impl From<Error> for Response {
Self::from_error(err)
}
}
/// Response for batchable queries
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum BatchResponse {
/// Response for single queries
Single(Response),
/// Response for batch queries
Batch(Vec<Response>),
}
impl BatchResponse {
/// Get cache control value
pub fn cache_control(&self) -> CacheControl {
match self {
BatchResponse::Single(resp) => resp.cache_control,
BatchResponse::Batch(resp) => resp.iter().fold(CacheControl::default(), |acc, item| {
acc.merge(&item.cache_control)
}),
}
}
/// Returns `true` if all responses are ok.
pub fn is_ok(&self) -> bool {
match self {
BatchResponse::Single(resp) => resp.is_ok(),
BatchResponse::Batch(resp) => resp.iter().all(Response::is_ok),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_batch_response_single() {
let resp = BatchResponse::Single(Response::new(serde_json::Value::Bool(true)));
assert_eq!(serde_json::to_string(&resp).unwrap(), r#"{"data":true}"#);
}
#[test]
fn test_batch_response_batch() {
let resp = BatchResponse::Batch(vec![
Response::new(serde_json::Value::Bool(true)),
Response::new(serde_json::Value::String("1".to_string())),
]);
assert_eq!(
serde_json::to_string(&resp).unwrap(),
r#"[{"data":true},{"data":"1"}]"#
);
}
}
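A sketch (not in the diff) of the folded cache control, assuming `CacheControl::default()` is `public: true, max_age: 0` as the merge rules above imply:

#[test]
fn batch_cache_control_folds() {
    use async_graphql::{BatchResponse, Response};
    let resp = BatchResponse::Batch(vec![
        Response::new(serde_json::Value::Bool(true)),
        Response::new(serde_json::Value::Null),
    ]);
    // Folding two defaults keeps max_age at 0, i.e. nothing cacheable.
    assert_eq!(resp.cache_control().max_age, 0);
    assert!(resp.is_ok());
}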

View File

@@ -9,8 +9,8 @@ use crate::subscription::collect_subscription_streams;
use crate::types::QueryRoot;
use crate::validation::{check_rules, CheckResult, ValidationMode};
use crate::{
CacheControl, ContextBase, Error, Pos, QueryEnv, QueryError, Request, Response, Result,
SubscriptionType, Type, Variables, ID,
BatchRequest, BatchResponse, CacheControl, ContextBase, Error, Pos, QueryEnv, QueryError,
Request, Response, Result, SubscriptionType, Type, Variables, ID,
};
use async_graphql_parser::types::ExecutableDocumentData;
use futures::stream::{self, Stream, StreamExt};
@@ -434,6 +434,19 @@
}
}
/// Execute a GraphQL batch query.
pub async fn execute_batch(&self, batch_request: BatchRequest) -> BatchResponse {
match batch_request {
BatchRequest::Single(request) => BatchResponse::Single(self.execute(request).await),
BatchRequest::Batch(requests) => BatchResponse::Batch(
stream::iter(requests)
.then(|request| self.execute(request))
.collect()
.await,
),
}
}
pub(crate) fn execute_stream_with_ctx_data(
&self,
request: impl Into<Request>,
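A sketch of driving the new `execute_batch` (mirroring the integration test added at the bottom of this commit); the `.then` chain runs the requests sequentially, so responses come back in request order.

use async_graphql::*;

struct Query;

#[GQLObject]
impl Query {
    async fn a(&self) -> i32 {
        1
    }
}

async fn run() -> BatchResponse {
    let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
    // Vec<Request> converts into BatchRequest::Batch via the From impl.
    schema
        .execute_batch(vec![Request::new("{ a }"), Request::new("{ a }")].into())
        .await
}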

View File

@@ -41,7 +41,7 @@ impl<A: Type, B: Type> Type for MergedObject<A, B> {
}) = registry.types.remove(&*A::type_name())
{
fields.extend(a_fields);
cc.merge(&a_cc);
cc = cc.merge(&a_cc);
}
B::create_type_info(registry);
@@ -52,7 +52,7 @@ impl<A: Type, B: Type> Type for MergedObject<A, B> {
}) = registry.types.remove(&*B::type_name())
{
fields.extend(b_fields);
cc.merge(&b_cc);
cc = cc.merge(&b_cc);
}
MetaType::Object {

View File

@@ -15,7 +15,7 @@ impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> {
) {
if let Some(current_type) = ctx.current_type() {
if let MetaType::Object { cache_control, .. } = current_type {
self.cache_control.merge(cache_control);
*self.cache_control = self.cache_control.merge(cache_control);
}
}
}
@@ -25,7 +25,7 @@ impl<'ctx, 'a> Visitor<'ctx> for CacheControlCalculate<'a> {
.parent_type()
.and_then(|parent| parent.field_by_name(&field.node.name.node))
{
self.cache_control.merge(&registry_field.cache_control);
*self.cache_control = self.cache_control.merge(&registry_field.cache_control);
}
}
}

tests/batch_request.rs Normal file
View File

@@ -0,0 +1,33 @@
use async_graphql::*;
#[async_std::test]
pub async fn test_batch_request() {
struct Query;
#[GQLObject]
impl Query {
async fn value(&self, a: i32, b: i32) -> i32 {
a + b
}
}
let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
let batch: BatchRequest = vec![
Request::new("{ value(a: 10, b: 20) }"),
Request::new("{ value(a: 30, b: 40) }"),
Request::new("{ value1 }"),
]
.into();
let resp = schema.execute_batch(batch).await;
assert_eq!(
serde_json::to_value(&resp).unwrap(),
serde_json::json!([
{"data": { "value": 30 }},
{"data": { "value": 70 }},
{"errors": [{
"message": r#"Unknown field "value1" on type "Query". Did you mean "value"?"#,
"locations": [{"line": 1, "column": 3}]
}]},
])
);
}