diff --git a/Cargo.toml b/Cargo.toml
index 6f4dbfe2..08bc7c30 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -82,4 +82,5 @@ members = [
"integrations/rocket",
"integrations/tide",
"integrations/warp",
+ "integrations/axum",
]
diff --git a/README.md b/README.md
index 52364ed5..ea6d5352 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,7 @@ cd examples && cargo run --bin [name]
* Warp [async-graphql-warp](https://crates.io/crates/async-graphql-warp)
* Tide [async-graphql-tide](https://crates.io/crates/async-graphql-tide)
* Rocket [async-graphql-rocket](https://github.com/async-graphql/async-graphql/tree/master/integrations/rocket)
+* Axum [async-graphql-axum](https://github.com/async-graphql/async-graphql/tree/master/integrations/axum)
## Blog Posts
diff --git a/examples b/examples
index 45b58bf1..849879c4 160000
--- a/examples
+++ b/examples
@@ -1 +1 @@
-Subproject commit 45b58bf1ce53ad76275d3d29c011e37de00e3bb7
+Subproject commit 849879c44cf75941e01d249c86cba6496e47a8ae
diff --git a/integrations/actix-web/src/lib.rs b/integrations/actix-web/src/lib.rs
index a67c7dc5..fc73c967 100644
--- a/integrations/actix-web/src/lib.rs
+++ b/integrations/actix-web/src/lib.rs
@@ -1,6 +1,7 @@
//! Async-graphql integration with Actix-web
#![forbid(unsafe_code)]
#![allow(clippy::upper_case_acronyms)]
+#![warn(missing_docs)]
mod subscription;
diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs
index 2e3cd745..ad2713e6 100644
--- a/integrations/actix-web/src/subscription.rs
+++ b/integrations/actix-web/src/subscription.rs
@@ -11,7 +11,7 @@ use actix_http::{ws, Error};
use actix_web::web::Bytes;
use actix_web::{HttpRequest, HttpResponse};
use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext};
-use async_graphql::http::{WebSocket, WebSocketProtocols, WsMessage};
+use async_graphql::http::{WebSocket, WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS};
use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
use futures_util::future::Ready;
use futures_util::stream::Stream;
@@ -96,7 +96,7 @@ where
initializer: Some(initializer),
continuation: Vec::new(),
},
- &["graphql-transport-ws", "graphql-ws"],
+ &ALL_WEBSOCKET_PROTOCOLS,
request,
stream,
)
diff --git a/integrations/axum/Cargo.toml b/integrations/axum/Cargo.toml
new file mode 100644
index 00000000..04255532
--- /dev/null
+++ b/integrations/axum/Cargo.toml
@@ -0,0 +1,19 @@
+[package]
+name = "async-graphql-axum"
+version = "2.9.9"
+edition = "2018"
+
+[dependencies]
+async-graphql = { path = "../..", version = "2.9.8" }
+
+async-trait = "0.1.51"
+#axum = { version = "0.1.2", features = ["ws", "headers"] }
+axum = { git = "https://github.com/sunli829/axum.git", rev = "f43470a", features = ["ws", "headers"] }
+bytes = "1.0.1"
+headers = "0.3.4"
+http = "0.2.4"
+http-body = "0.4.2"
+serde_json = "1.0.66"
+serde_urlencoded = "0.7.0"
+tokio-util = { version = "0.6.7", features = ["io", "compat"] }
+futures-util = "0.3.13"
diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs
new file mode 100644
index 00000000..34613787
--- /dev/null
+++ b/integrations/axum/src/extract.rs
@@ -0,0 +1,119 @@
+use std::fmt::Display;
+use std::io::ErrorKind;
+
+use async_graphql::futures_util::TryStreamExt;
+use async_graphql::http::MultipartOptions;
+use async_graphql::ParseRequestError;
+use axum::extract::{BodyStream, FromRequest, RequestParts};
+use bytes::Buf;
+use http::Method;
+use tokio_util::compat::TokioAsyncReadCompatExt;
+
+/// Extractor for GraphQL request.
+pub struct GraphQLRequest(pub async_graphql::Request);
+
+impl GraphQLRequest {
+ /// Unwraps the value to `async_graphql::Request`.
+ #[must_use]
+ pub fn into_inner(self) -> async_graphql::Request {
+ self.0
+ }
+}
+
+/// Rejection response types.
+pub mod rejection {
+ use async_graphql::ParseRequestError;
+ use axum::body::Body;
+ use axum::response::IntoResponse;
+
+ /// Rejection used for [`GraphQLRequest`](GraphQLRequest).
+ pub struct GraphQLRejection(pub ParseRequestError);
+
+ impl IntoResponse for GraphQLRejection {
+ fn into_response(self) -> http::Response
{
+ todo!()
+ }
+ }
+
+ impl From for GraphQLRejection {
+ fn from(err: ParseRequestError) -> Self {
+ GraphQLRejection(err)
+ }
+ }
+}
+
+#[async_trait::async_trait]
+impl FromRequest for GraphQLRequest
+where
+ B: http_body::Body + Unpin + Send + Sync + 'static,
+ B::Data: Buf + Send,
+ B::Error: Display + Send,
+{
+ type Rejection = rejection::GraphQLRejection;
+
+ async fn from_request(req: &mut RequestParts) -> Result {
+ Ok(GraphQLRequest(
+ GraphQLBatchRequest::from_request(req)
+ .await?
+ .0
+ .into_single()?,
+ ))
+ }
+}
+
+/// Extractor for GraphQL batch request.
+pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
+
+impl GraphQLBatchRequest {
+ /// Unwraps the value to `async_graphql::BatchRequest`.
+ #[must_use]
+ pub fn into_inner(self) -> async_graphql::BatchRequest {
+ self.0
+ }
+}
+
+#[async_trait::async_trait]
+impl FromRequest for GraphQLBatchRequest
+where
+ B: http_body::Body + Unpin + Send + Sync + 'static,
+ B::Data: Buf + Send,
+ B::Error: Display + Send,
+{
+ type Rejection = rejection::GraphQLRejection;
+
+ async fn from_request(req: &mut RequestParts) -> Result {
+ if let (Some(&Method::GET), Some(uri)) = (req.method(), req.uri()) {
+ let res = serde_urlencoded::from_str(uri.query().unwrap_or_default()).map_err(|err| {
+ ParseRequestError::Io(std::io::Error::new(
+ ErrorKind::Other,
+ format!("failed to parse graphql request from uri query: {}", err),
+ ))
+ });
+ Ok(Self(async_graphql::BatchRequest::Single(res?)))
+ } else {
+ let content_type = req
+ .headers()
+ .and_then(|headers| headers.get(http::header::CONTENT_TYPE))
+ .and_then(|value| value.to_str().ok())
+ .map(ToString::to_string);
+ let body_stream = BodyStream::from_request(req)
+ .await
+ .map_err(|_| {
+ ParseRequestError::Io(std::io::Error::new(
+ ErrorKind::Other,
+ "body has been taken by another extractor".to_string(),
+ ))
+ })?
+ .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string()));
+ let body_reader = tokio_util::io::StreamReader::new(body_stream).compat();
+ Ok(Self(
+ async_graphql::http::receive_batch_body(
+ content_type,
+ body_reader,
+ MultipartOptions::default(),
+ )
+ .await?,
+ ))
+ }
+ }
+}
diff --git a/integrations/axum/src/lib.rs b/integrations/axum/src/lib.rs
new file mode 100644
index 00000000..ec6814e2
--- /dev/null
+++ b/integrations/axum/src/lib.rs
@@ -0,0 +1,13 @@
+//! Async-graphql integration with Axum
+#![forbid(unsafe_code)]
+#![warn(missing_docs)]
+
+mod extract;
+mod response;
+mod subscription;
+
+pub use extract::{GraphQLBatchRequest, GraphQLRequest};
+pub use response::GraphQLResponse;
+pub use subscription::{
+ graphql_subscription, graphql_subscription_with_data, SecWebsocketProtocol,
+};
diff --git a/integrations/axum/src/response.rs b/integrations/axum/src/response.rs
new file mode 100644
index 00000000..c92549fd
--- /dev/null
+++ b/integrations/axum/src/response.rs
@@ -0,0 +1,51 @@
+use std::convert::TryFrom;
+
+use axum::body::Body;
+use axum::prelude::response::IntoResponse;
+use headers::HeaderName;
+use http::{HeaderValue, Response};
+
+/// Responder for a GraphQL response.
+///
+/// This contains a batch response, but since regular responses are a type of batch response it
+/// works for both.
+pub struct GraphQLResponse(pub async_graphql::BatchResponse);
+
+impl From for GraphQLResponse {
+ fn from(resp: async_graphql::Response) -> Self {
+ Self(resp.into())
+ }
+}
+
+impl From for GraphQLResponse {
+ fn from(resp: async_graphql::BatchResponse) -> Self {
+ Self(resp)
+ }
+}
+
+impl IntoResponse for GraphQLResponse {
+ fn into_response(self) -> Response {
+ let mut resp = Response::new(serde_json::to_string(&self.0).unwrap().into());
+ resp.headers_mut().insert(
+ http::header::CONTENT_TYPE,
+ HeaderValue::from_static("application/json"),
+ );
+ if self.0.is_ok() {
+ if let Some(cache_control) = self.0.cache_control().value() {
+ if let Ok(value) = HeaderValue::from_str(&cache_control) {
+ resp.headers_mut()
+ .insert(http::header::CACHE_CONTROL, value);
+ }
+ }
+ }
+ for (name, value) in self.0.http_headers() {
+ if let (Ok(name), Ok(value)) = (
+ HeaderName::try_from(name.as_bytes()),
+ HeaderValue::from_str(value),
+ ) {
+ resp.headers_mut().insert(name, value);
+ }
+ }
+ resp
+ }
+}
diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs
new file mode 100644
index 00000000..f34f14ac
--- /dev/null
+++ b/integrations/axum/src/subscription.rs
@@ -0,0 +1,99 @@
+use std::future::Future;
+
+use async_graphql::http::{WebSocketProtocols, WsMessage};
+use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
+use axum::ws::{Message, WebSocket};
+use futures_util::{future, SinkExt, StreamExt};
+use headers::{Header, HeaderName, HeaderValue};
+
+/// The Sec-Websocket-Protocol header.
+#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)]
+pub struct SecWebsocketProtocol(pub WebSocketProtocols);
+
+impl Header for SecWebsocketProtocol {
+ fn name() -> &'static HeaderName {
+ &http::header::SEC_WEBSOCKET_PROTOCOL
+ }
+
+ fn decode<'i, I>(values: &mut I) -> Result
+ where
+ Self: Sized,
+ I: Iterator- ,
+ {
+ match values.next() {
+ Some(value) => {
+ let value = value.to_str().map_err(|_| headers::Error::invalid())?;
+ Ok(SecWebsocketProtocol(
+ value
+ .parse()
+ .ok()
+ .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS),
+ ))
+ }
+ None => Err(headers::Error::invalid()),
+ }
+ }
+
+ fn encode>(&self, values: &mut E) {
+ values.extend(std::iter::once(HeaderValue::from_static(
+ self.0.sec_websocket_protocol(),
+ )))
+ }
+}
+
+/// GraphQL subscription handler
+pub async fn graphql_subscription(
+ websocket: WebSocket,
+ schema: Schema,
+ protocol: SecWebsocketProtocol,
+) where
+ Query: ObjectType + Sync + Send + 'static,
+ Mutation: ObjectType + Sync + Send + 'static,
+ Subscription: SubscriptionType + Send + Sync + 'static,
+{
+ graphql_subscription_with_data(websocket, schema, protocol, |_| async {
+ Ok(Default::default())
+ })
+ .await
+}
+
+/// GraphQL subscription handler
+///
+/// Specifies that a function converts the init payload to data.
+pub async fn graphql_subscription_with_data(
+ websocket: WebSocket,
+ schema: Schema,
+ protocol: SecWebsocketProtocol,
+ initializer: F,
+) where
+ Query: ObjectType + 'static,
+ Mutation: ObjectType + 'static,
+ Subscription: SubscriptionType + 'static,
+ F: FnOnce(serde_json::Value) -> R + Send + 'static,
+ R: Future