diff --git a/Cargo.toml b/Cargo.toml index 6f4dbfe2..08bc7c30 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,4 +82,5 @@ members = [ "integrations/rocket", "integrations/tide", "integrations/warp", + "integrations/axum", ] diff --git a/README.md b/README.md index 52364ed5..ea6d5352 100644 --- a/README.md +++ b/README.md @@ -100,6 +100,7 @@ cd examples && cargo run --bin [name] * Warp [async-graphql-warp](https://crates.io/crates/async-graphql-warp) * Tide [async-graphql-tide](https://crates.io/crates/async-graphql-tide) * Rocket [async-graphql-rocket](https://github.com/async-graphql/async-graphql/tree/master/integrations/rocket) +* Axum [async-graphql-axum](https://github.com/async-graphql/async-graphql/tree/master/integrations/axum) ## Blog Posts diff --git a/examples b/examples index 45b58bf1..849879c4 160000 --- a/examples +++ b/examples @@ -1 +1 @@ -Subproject commit 45b58bf1ce53ad76275d3d29c011e37de00e3bb7 +Subproject commit 849879c44cf75941e01d249c86cba6496e47a8ae diff --git a/integrations/actix-web/src/lib.rs b/integrations/actix-web/src/lib.rs index a67c7dc5..fc73c967 100644 --- a/integrations/actix-web/src/lib.rs +++ b/integrations/actix-web/src/lib.rs @@ -1,6 +1,7 @@ //! Async-graphql integration with Actix-web #![forbid(unsafe_code)] #![allow(clippy::upper_case_acronyms)] +#![warn(missing_docs)] mod subscription; diff --git a/integrations/actix-web/src/subscription.rs b/integrations/actix-web/src/subscription.rs index 2e3cd745..ad2713e6 100644 --- a/integrations/actix-web/src/subscription.rs +++ b/integrations/actix-web/src/subscription.rs @@ -11,7 +11,7 @@ use actix_http::{ws, Error}; use actix_web::web::Bytes; use actix_web::{HttpRequest, HttpResponse}; use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext}; -use async_graphql::http::{WebSocket, WebSocketProtocols, WsMessage}; +use async_graphql::http::{WebSocket, WebSocketProtocols, WsMessage, ALL_WEBSOCKET_PROTOCOLS}; use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; use futures_util::future::Ready; use futures_util::stream::Stream; @@ -96,7 +96,7 @@ where initializer: Some(initializer), continuation: Vec::new(), }, - &["graphql-transport-ws", "graphql-ws"], + &ALL_WEBSOCKET_PROTOCOLS, request, stream, ) diff --git a/integrations/axum/Cargo.toml b/integrations/axum/Cargo.toml new file mode 100644 index 00000000..04255532 --- /dev/null +++ b/integrations/axum/Cargo.toml @@ -0,0 +1,19 @@ +[package] +name = "async-graphql-axum" +version = "2.9.9" +edition = "2018" + +[dependencies] +async-graphql = { path = "../..", version = "2.9.8" } + +async-trait = "0.1.51" +#axum = { version = "0.1.2", features = ["ws", "headers"] } +axum = { git = "https://github.com/sunli829/axum.git", rev = "f43470a", features = ["ws", "headers"] } +bytes = "1.0.1" +headers = "0.3.4" +http = "0.2.4" +http-body = "0.4.2" +serde_json = "1.0.66" +serde_urlencoded = "0.7.0" +tokio-util = { version = "0.6.7", features = ["io", "compat"] } +futures-util = "0.3.13" diff --git a/integrations/axum/src/extract.rs b/integrations/axum/src/extract.rs new file mode 100644 index 00000000..34613787 --- /dev/null +++ b/integrations/axum/src/extract.rs @@ -0,0 +1,119 @@ +use std::fmt::Display; +use std::io::ErrorKind; + +use async_graphql::futures_util::TryStreamExt; +use async_graphql::http::MultipartOptions; +use async_graphql::ParseRequestError; +use axum::extract::{BodyStream, FromRequest, RequestParts}; +use bytes::Buf; +use http::Method; +use tokio_util::compat::TokioAsyncReadCompatExt; + +/// Extractor for GraphQL request. +pub struct GraphQLRequest(pub async_graphql::Request); + +impl GraphQLRequest { + /// Unwraps the value to `async_graphql::Request`. + #[must_use] + pub fn into_inner(self) -> async_graphql::Request { + self.0 + } +} + +/// Rejection response types. +pub mod rejection { + use async_graphql::ParseRequestError; + use axum::body::Body; + use axum::response::IntoResponse; + + /// Rejection used for [`GraphQLRequest`](GraphQLRequest). + pub struct GraphQLRejection(pub ParseRequestError); + + impl IntoResponse for GraphQLRejection { + fn into_response(self) -> http::Response { + todo!() + } + } + + impl From for GraphQLRejection { + fn from(err: ParseRequestError) -> Self { + GraphQLRejection(err) + } + } +} + +#[async_trait::async_trait] +impl FromRequest for GraphQLRequest +where + B: http_body::Body + Unpin + Send + Sync + 'static, + B::Data: Buf + Send, + B::Error: Display + Send, +{ + type Rejection = rejection::GraphQLRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + Ok(GraphQLRequest( + GraphQLBatchRequest::from_request(req) + .await? + .0 + .into_single()?, + )) + } +} + +/// Extractor for GraphQL batch request. +pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest); + +impl GraphQLBatchRequest { + /// Unwraps the value to `async_graphql::BatchRequest`. + #[must_use] + pub fn into_inner(self) -> async_graphql::BatchRequest { + self.0 + } +} + +#[async_trait::async_trait] +impl FromRequest for GraphQLBatchRequest +where + B: http_body::Body + Unpin + Send + Sync + 'static, + B::Data: Buf + Send, + B::Error: Display + Send, +{ + type Rejection = rejection::GraphQLRejection; + + async fn from_request(req: &mut RequestParts) -> Result { + if let (Some(&Method::GET), Some(uri)) = (req.method(), req.uri()) { + let res = serde_urlencoded::from_str(uri.query().unwrap_or_default()).map_err(|err| { + ParseRequestError::Io(std::io::Error::new( + ErrorKind::Other, + format!("failed to parse graphql request from uri query: {}", err), + )) + }); + Ok(Self(async_graphql::BatchRequest::Single(res?))) + } else { + let content_type = req + .headers() + .and_then(|headers| headers.get(http::header::CONTENT_TYPE)) + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string); + let body_stream = BodyStream::from_request(req) + .await + .map_err(|_| { + ParseRequestError::Io(std::io::Error::new( + ErrorKind::Other, + "body has been taken by another extractor".to_string(), + )) + })? + .map_err(|err| std::io::Error::new(ErrorKind::Other, err.to_string())); + let body_reader = tokio_util::io::StreamReader::new(body_stream).compat(); + Ok(Self( + async_graphql::http::receive_batch_body( + content_type, + body_reader, + MultipartOptions::default(), + ) + .await?, + )) + } + } +} diff --git a/integrations/axum/src/lib.rs b/integrations/axum/src/lib.rs new file mode 100644 index 00000000..ec6814e2 --- /dev/null +++ b/integrations/axum/src/lib.rs @@ -0,0 +1,13 @@ +//! Async-graphql integration with Axum +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +mod extract; +mod response; +mod subscription; + +pub use extract::{GraphQLBatchRequest, GraphQLRequest}; +pub use response::GraphQLResponse; +pub use subscription::{ + graphql_subscription, graphql_subscription_with_data, SecWebsocketProtocol, +}; diff --git a/integrations/axum/src/response.rs b/integrations/axum/src/response.rs new file mode 100644 index 00000000..c92549fd --- /dev/null +++ b/integrations/axum/src/response.rs @@ -0,0 +1,51 @@ +use std::convert::TryFrom; + +use axum::body::Body; +use axum::prelude::response::IntoResponse; +use headers::HeaderName; +use http::{HeaderValue, Response}; + +/// Responder for a GraphQL response. +/// +/// This contains a batch response, but since regular responses are a type of batch response it +/// works for both. +pub struct GraphQLResponse(pub async_graphql::BatchResponse); + +impl From for GraphQLResponse { + fn from(resp: async_graphql::Response) -> Self { + Self(resp.into()) + } +} + +impl From for GraphQLResponse { + fn from(resp: async_graphql::BatchResponse) -> Self { + Self(resp) + } +} + +impl IntoResponse for GraphQLResponse { + fn into_response(self) -> Response { + let mut resp = Response::new(serde_json::to_string(&self.0).unwrap().into()); + resp.headers_mut().insert( + http::header::CONTENT_TYPE, + HeaderValue::from_static("application/json"), + ); + if self.0.is_ok() { + if let Some(cache_control) = self.0.cache_control().value() { + if let Ok(value) = HeaderValue::from_str(&cache_control) { + resp.headers_mut() + .insert(http::header::CACHE_CONTROL, value); + } + } + } + for (name, value) in self.0.http_headers() { + if let (Ok(name), Ok(value)) = ( + HeaderName::try_from(name.as_bytes()), + HeaderValue::from_str(value), + ) { + resp.headers_mut().insert(name, value); + } + } + resp + } +} diff --git a/integrations/axum/src/subscription.rs b/integrations/axum/src/subscription.rs new file mode 100644 index 00000000..f34f14ac --- /dev/null +++ b/integrations/axum/src/subscription.rs @@ -0,0 +1,99 @@ +use std::future::Future; + +use async_graphql::http::{WebSocketProtocols, WsMessage}; +use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType}; +use axum::ws::{Message, WebSocket}; +use futures_util::{future, SinkExt, StreamExt}; +use headers::{Header, HeaderName, HeaderValue}; + +/// The Sec-Websocket-Protocol header. +#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash)] +pub struct SecWebsocketProtocol(pub WebSocketProtocols); + +impl Header for SecWebsocketProtocol { + fn name() -> &'static HeaderName { + &http::header::SEC_WEBSOCKET_PROTOCOL + } + + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + match values.next() { + Some(value) => { + let value = value.to_str().map_err(|_| headers::Error::invalid())?; + Ok(SecWebsocketProtocol( + value + .parse() + .ok() + .unwrap_or(WebSocketProtocols::SubscriptionsTransportWS), + )) + } + None => Err(headers::Error::invalid()), + } + } + + fn encode>(&self, values: &mut E) { + values.extend(std::iter::once(HeaderValue::from_static( + self.0.sec_websocket_protocol(), + ))) + } +} + +/// GraphQL subscription handler +pub async fn graphql_subscription( + websocket: WebSocket, + schema: Schema, + protocol: SecWebsocketProtocol, +) where + Query: ObjectType + Sync + Send + 'static, + Mutation: ObjectType + Sync + Send + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, +{ + graphql_subscription_with_data(websocket, schema, protocol, |_| async { + Ok(Default::default()) + }) + .await +} + +/// GraphQL subscription handler +/// +/// Specifies that a function converts the init payload to data. +pub async fn graphql_subscription_with_data( + websocket: WebSocket, + schema: Schema, + protocol: SecWebsocketProtocol, + initializer: F, +) where + Query: ObjectType + 'static, + Mutation: ObjectType + 'static, + Subscription: SubscriptionType + 'static, + F: FnOnce(serde_json::Value) -> R + Send + 'static, + R: Future> + Send + 'static, +{ + let (mut sink, stream) = websocket.split(); + let input = stream + .take_while(|res| future::ready(res.is_ok())) + .map(Result::unwrap) + .filter_map(|msg| { + if msg.is_text() || msg.is_binary() { + future::ready(Some(msg)) + } else { + future::ready(None) + } + }) + .map(Message::into_bytes); + + let mut stream = + async_graphql::http::WebSocket::with_data(schema, input, initializer, protocol.0).map( + |msg| match msg { + WsMessage::Text(text) => Message::text(text), + WsMessage::Close(code, status) => Message::close_with(code, status), + }, + ); + + while let Some(item) = stream.next().await { + let _ = sink.send(item).await; + } +} diff --git a/src/http/mod.rs b/src/http/mod.rs index b811b089..c0bb6284 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -12,7 +12,9 @@ use crate::{BatchRequest, ParseRequestError, Request}; pub use graphiql_source::graphiql_source; pub use multipart::MultipartOptions; pub use playground_source::{playground_source, GraphQLPlaygroundConfig}; -pub use websocket::{ClientMessage, Protocols as WebSocketProtocols, WebSocket, WsMessage}; +pub use websocket::{ + ClientMessage, Protocols as WebSocketProtocols, WebSocket, WsMessage, ALL_WEBSOCKET_PROTOCOLS, +}; /// Receive a GraphQL request from a content type and body. pub async fn receive_body( diff --git a/src/http/websocket.rs b/src/http/websocket.rs index 8fd8944a..357a5b81 100644 --- a/src/http/websocket.rs +++ b/src/http/websocket.rs @@ -17,6 +17,9 @@ use serde::{Deserialize, Serialize}; use crate::{Data, Error, ObjectType, Request, Response, Result, Schema, SubscriptionType}; +/// All known protocols based on WebSocket. +pub const ALL_WEBSOCKET_PROTOCOLS: [&str; 2] = ["graphql-transport-ws", "graphql-ws"]; + /// An enum representing the various forms of a WebSocket message. #[derive(Clone, Debug)] pub enum WsMessage { @@ -296,7 +299,7 @@ where } /// Specification of which GraphQL Over WebSockets protocol is being utilized -#[derive(Copy, Clone)] +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] pub enum Protocols { /// [subscriptions-transport-ws protocol](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md). SubscriptionsTransportWS, @@ -306,7 +309,7 @@ pub enum Protocols { impl Protocols { /// Returns the `Sec-WebSocket-Protocol` header value for the protocol - pub fn sec_websocket_protocol(&self) -> &str { + pub fn sec_websocket_protocol(&self) -> &'static str { match self { Protocols::SubscriptionsTransportWS => "graphql-ws", Protocols::GraphQLWS => "graphql-transport-ws",