Update async-graphql-actix-web

This commit is contained in:
Sunli 2021-11-12 21:24:24 +08:00
parent 111923ac0f
commit 32ba455fb2
4 changed files with 131 additions and 125 deletions

View File

@ -22,6 +22,7 @@ futures-util = { version = "0.3.17", default-features = false }
serde_json = "1.0.64" serde_json = "1.0.64"
serde_urlencoded = "0.7.0" serde_urlencoded = "0.7.0"
futures-channel = "0.3.13" futures-channel = "0.3.13"
thiserror = "1.0.30"
[dev-dependencies] [dev-dependencies]
actix-rt = "2.2.0" actix-rt = "2.2.0"

View File

@ -3,6 +3,8 @@
#![allow(clippy::upper_case_acronyms)] #![allow(clippy::upper_case_acronyms)]
#![warn(missing_docs)] #![warn(missing_docs)]
mod subscription;
use std::future::Future; use std::future::Future;
use std::io::{self, ErrorKind}; use std::io::{self, ErrorKind};
use std::pin::Pin; use std::pin::Pin;
@ -16,16 +18,14 @@ use futures_util::{StreamExt, TryStreamExt};
use async_graphql::http::MultipartOptions; use async_graphql::http::MultipartOptions;
use async_graphql::ParseRequestError; use async_graphql::ParseRequestError;
pub use subscription::WSSubscription; pub use subscription::GraphQLSubscription;
mod subscription;
/// Extractor for GraphQL request. /// Extractor for GraphQL request.
/// ///
/// `async_graphql::http::MultipartOptions` allows to configure extraction process. /// `async_graphql::http::MultipartOptions` allows to configure extraction process.
pub struct Request(pub async_graphql::Request); pub struct GraphQLRequest(pub async_graphql::Request);
impl Request { impl GraphQLRequest {
/// Unwraps the value to `async_graphql::Request`. /// Unwraps the value to `async_graphql::Request`.
#[must_use] #[must_use]
pub fn into_inner(self) -> async_graphql::Request { pub fn into_inner(self) -> async_graphql::Request {
@ -34,14 +34,14 @@ impl Request {
} }
type BatchToRequestMapper = type BatchToRequestMapper =
fn(<<BatchRequest as FromRequest>::Future as Future>::Output) -> Result<Request>; fn(<<GraphQLBatchRequest as FromRequest>::Future as Future>::Output) -> Result<GraphQLRequest>;
impl FromRequest for Request { impl FromRequest for GraphQLRequest {
type Error = Error; type Error = Error;
type Future = future::Map<<BatchRequest as FromRequest>::Future, BatchToRequestMapper>; type Future = future::Map<<GraphQLBatchRequest as FromRequest>::Future, BatchToRequestMapper>;
fn from_request(req: &HttpRequest, payload: &mut Payload<PayloadStream>) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload<PayloadStream>) -> Self::Future {
BatchRequest::from_request(req, payload).map(|res| { GraphQLBatchRequest::from_request(req, payload).map(|res| {
Ok(Self( Ok(Self(
res?.0 res?.0
.into_single() .into_single()
@ -54,9 +54,9 @@ impl FromRequest for Request {
/// Extractor for GraphQL batch request. /// Extractor for GraphQL batch request.
/// ///
/// `async_graphql::http::MultipartOptions` allows to configure extraction process. /// `async_graphql::http::MultipartOptions` allows to configure extraction process.
pub struct BatchRequest(pub async_graphql::BatchRequest); pub struct GraphQLBatchRequest(pub async_graphql::BatchRequest);
impl BatchRequest { impl GraphQLBatchRequest {
/// Unwraps the value to `async_graphql::BatchRequest`. /// Unwraps the value to `async_graphql::BatchRequest`.
#[must_use] #[must_use]
pub fn into_inner(self) -> async_graphql::BatchRequest { pub fn into_inner(self) -> async_graphql::BatchRequest {
@ -64,9 +64,9 @@ impl BatchRequest {
} }
} }
impl FromRequest for BatchRequest { impl FromRequest for GraphQLBatchRequest {
type Error = Error; type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<BatchRequest>>>>; type Future = Pin<Box<dyn Future<Output = Result<GraphQLBatchRequest>>>>;
fn from_request(req: &HttpRequest, payload: &mut Payload<PayloadStream>) -> Self::Future { fn from_request(req: &HttpRequest, payload: &mut Payload<PayloadStream>) -> Self::Future {
let config = req let config = req
@ -97,7 +97,7 @@ impl FromRequest for BatchRequest {
}); });
Box::pin(async move { Box::pin(async move {
Ok(BatchRequest( Ok(GraphQLBatchRequest(
async_graphql::http::receive_batch_body( async_graphql::http::receive_batch_body(
content_type, content_type,
rx.map_err(|e| match e { rx.map_err(|e| match e {
@ -146,21 +146,21 @@ impl FromRequest for BatchRequest {
/// ///
/// This contains a batch response, but since regular responses are a type of batch response it /// This contains a batch response, but since regular responses are a type of batch response it
/// works for both. /// works for both.
pub struct Response(pub async_graphql::BatchResponse); pub struct GraphQLResponse(pub async_graphql::BatchResponse);
impl From<async_graphql::Response> for Response { impl From<async_graphql::Response> for GraphQLResponse {
fn from(resp: async_graphql::Response) -> Self { fn from(resp: async_graphql::Response) -> Self {
Self(resp.into()) Self(resp.into())
} }
} }
impl From<async_graphql::BatchResponse> for Response { impl From<async_graphql::BatchResponse> for GraphQLResponse {
fn from(resp: async_graphql::BatchResponse) -> Self { fn from(resp: async_graphql::BatchResponse) -> Self {
Self(resp) Self(resp)
} }
} }
impl Responder for Response { impl Responder for GraphQLResponse {
fn respond_to(self, _req: &HttpRequest) -> HttpResponse { fn respond_to(self, _req: &HttpRequest) -> HttpResponse {
let mut res = HttpResponse::build(StatusCode::OK); let mut res = HttpResponse::build(StatusCode::OK);
res.content_type("application/json"); res.content_type("application/json");
@ -174,4 +174,4 @@ impl Responder for Response {
} }
res.body(serde_json::to_string(&self.0).unwrap()) res.body(serde_json::to_string(&self.0).unwrap())
} }
} }

View File

@ -7,9 +7,9 @@ use actix::{
}; };
use actix::{ActorFutureExt, ActorStreamExt}; use actix::{ActorFutureExt, ActorStreamExt};
use actix_http::error::PayloadError; use actix_http::error::PayloadError;
use actix_http::{ws, Payload, StatusCode}; use actix_http::ws;
use actix_web::web::Bytes; use actix_web::web::Bytes;
use actix_web::{http, Error, FromRequest, HttpRequest, HttpResponse}; use actix_web::{Error, HttpRequest, HttpResponse};
use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext}; use actix_web_actors::ws::{CloseReason, Message, ProtocolError, WebsocketContext};
use futures_util::future::Ready; use futures_util::future::Ready;
use futures_util::stream::Stream; use futures_util::stream::Stream;
@ -20,86 +20,80 @@ use async_graphql::{Data, ObjectType, Result, Schema, SubscriptionType};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5); const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10); const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
/// A GraphQL protocol extractor. #[derive(thiserror::Error, Debug)]
/// #[error("failed to parse graphql protocol")]
/// It extract GraphQL protocol from `SEC_WEBSOCKET_PROTOCOL` header. pub struct ParseGraphQLProtocolError;
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct GraphQLProtocol(WebSocketProtocols);
#[async_trait::async_trait] type DefaultOnConnInitType = fn(serde_json::Value) -> Ready<async_graphql::Result<Data>>;
impl FromRequest for GraphQLProtocol {
type Error = Error;
type Future = Ready<Self>;
fn from_request(req: &HttpRequest, payload: &mut Payload) -> Self::Future { fn default_on_connection_init(_: serde_json::Value) -> Ready<async_graphql::Result<Data>> {
Ready( futures_util::future::ready(Ok(Data::default()))
req.headers()
.get(http::header::SEC_WEBSOCKET_PROTOCOL)
.and_then(|value| value.to_str().ok())
.and_then(|protocols| {
protocols
.split(',')
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
})
.map(Self)
.ok_or_else(|| Error::),
)
}
} }
/// Actor for subscription via websocket /// A builder for websocket subscription actor.
pub struct WSSubscription<Query, Mutation, Subscription, F> { pub struct GraphQLSubscription<Query, Mutation, Subscription, OnInit> {
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
protocol: WebSocketProtocols, data: Data,
last_heartbeat: Instant, on_connection_init: OnInit,
messages: Option<async_channel::Sender<Vec<u8>>>,
initializer: Option<F>,
continuation: Vec<u8>,
} }
impl<Query, Mutation, Subscription> impl<Query, Mutation, Subscription>
WSSubscription<Query, Mutation, Subscription, fn(serde_json::Value) -> Ready<Result<Data>>> GraphQLSubscription<Query, Mutation, Subscription, DefaultOnConnInitType>
where where
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
{ {
/// Start an actor for subscription connection via websocket. /// Create a GraphQL subscription builder.
pub fn start<T>( pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
schema: Schema<Query, Mutation, Subscription>, Self {
request: &HttpRequest, schema,
stream: T, data: Default::default(),
) -> Result<HttpResponse, actix_web::error::Error> on_connection_init: default_on_connection_init,
where }
T: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
Self::start_with_initializer(schema, request, stream, |_| {
futures_util::future::ready(Ok(Default::default()))
})
} }
} }
impl<Query, Mutation, Subscription, F, R> WSSubscription<Query, Mutation, Subscription, F> impl<Query, Mutation, Subscription, OnInit, OnInitFut>
GraphQLSubscription<Query, Mutation, Subscription, OnInit>
where where
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, OnInit: Fn(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static, OnInitFut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
{ {
/// Start an actor for subscription connection via websocket with an initialization function. /// Specify the initial subscription context data, usually you can get something from the
pub fn start_with_initializer<T>( /// incoming request to create it.
schema: Schema<Query, Mutation, Subscription>, pub fn with_data(self, data: Data) -> Self {
request: &HttpRequest, Self { data, ..self }
stream: T, }
initializer: F,
) -> Result<HttpResponse, actix_web::error::Error> /// Specify a callback function to be called when the connection is initialized.
///
/// You can get something from the payload of [`GQL_CONNECTION_INIT` message](https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md#gql_connection_init) to create [`Data`].
/// The data returned by this callback function will be merged with the data specified by [`with_data`].
pub fn on_connection_init<OnConnInit2, Fut>(
self,
callback: OnConnInit2,
) -> GraphQLSubscription<Query, Mutation, Subscription, OnConnInit2>
where where
T: Stream<Item = Result<Bytes, PayloadError>> + 'static, OnConnInit2: Fn(serde_json::Value) -> Fut + Unpin + Send + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, Fut: Future<Output = async_graphql::Result<Data>> + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static,
{ {
let protocol = match request GraphQLSubscription {
schema: self.schema,
data: self.data,
on_connection_init: callback,
}
}
/// Start the subscription actor.
pub fn start<S>(self, request: &HttpRequest, stream: S) -> Result<HttpResponse, Error>
where
S: Stream<Item = Result<Bytes, PayloadError>> + 'static,
{
let protocol = request
.headers() .headers()
.get("sec-websocket-protocol") .get("sec-websocket-protocol")
.and_then(|value| value.to_str().ok()) .and_then(|value| value.to_str().ok())
@ -107,29 +101,42 @@ where
protocols protocols
.split(',') .split(',')
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
}) { })
Some(protocol) => protocol, .ok_or_else(|| actix_web::error::ErrorBadRequest(ParseGraphQLProtocolError))?;
None => {
// default to the prior standard let actor = GraphQLSubscriptionActor {
WebSocketProtocols::SubscriptionsTransportWS schema: self.schema,
} data: Some(self.data),
protocol,
last_heartbeat: Instant::now(),
messages: None,
on_connection_init: Some(self.on_connection_init),
continuation: Vec::new(),
}; };
actix_web_actors::ws::start_with_protocols( actix_web_actors::ws::start_with_protocols(actor, &ALL_WEBSOCKET_PROTOCOLS, request, stream)
Self {
schema,
protocol,
last_heartbeat: Instant::now(),
messages: None,
initializer: Some(initializer),
continuation: Vec::new(),
},
&ALL_WEBSOCKET_PROTOCOLS,
request,
stream,
)
} }
}
struct GraphQLSubscriptionActor<Query, Mutation, Subscription, OnInit> {
schema: Schema<Query, Mutation, Subscription>,
data: Option<Data>,
protocol: WebSocketProtocols,
last_heartbeat: Instant,
messages: Option<async_channel::Sender<Vec<u8>>>,
on_connection_init: Option<OnInit>,
continuation: Vec<u8>,
}
impl<Query, Mutation, Subscription, OnInit, OnInitFut>
GraphQLSubscriptionActor<Query, Mutation, Subscription, OnInit>
where
Query: ObjectType + 'static,
Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static,
OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
OnInitFut: Future<Output = Result<Data>> + Send + 'static,
{
fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) { fn send_heartbeats(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| { ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT { if Instant::now().duration_since(act.last_heartbeat) > CLIENT_TIMEOUT {
@ -140,13 +147,14 @@ where
} }
} }
impl<Query, Mutation, Subscription, F, R> Actor for WSSubscription<Query, Mutation, Subscription, F> impl<Query, Mutation, Subscription, OnInit, OnInitFut> Actor
for GraphQLSubscriptionActor<Query, Mutation, Subscription, OnInit>
where where
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static, OnInitFut: Future<Output = Result<Data>> + Send + 'static,
{ {
type Context = WebsocketContext<Self>; type Context = WebsocketContext<Self>;
@ -155,35 +163,32 @@ where
let (tx, rx) = async_channel::unbounded(); let (tx, rx) = async_channel::unbounded();
WebSocket::with_data( WebSocket::new(self.schema.clone(), rx, self.protocol)
self.schema.clone(), .connection_data(self.data.take().unwrap())
rx, .on_connection_init(self.on_connection_init.take().unwrap())
self.initializer.take().unwrap(), .into_actor(self)
self.protocol, .map(|response, _act, ctx| match response {
) WsMessage::Text(text) => ctx.text(text),
.into_actor(self) WsMessage::Close(code, msg) => ctx.close(Some(CloseReason {
.map(|response, _act, ctx| match response { code: code.into(),
WsMessage::Text(text) => ctx.text(text), description: Some(msg),
WsMessage::Close(code, msg) => ctx.close(Some(CloseReason { })),
code: code.into(), })
description: Some(msg), .finish()
})), .spawn(ctx);
})
.finish()
.spawn(ctx);
self.messages = Some(tx); self.messages = Some(tx);
} }
} }
impl<Query, Mutation, Subscription, F, R> StreamHandler<Result<Message, ProtocolError>> impl<Query, Mutation, Subscription, OnInit, OnInitFut> StreamHandler<Result<Message, ProtocolError>>
for WSSubscription<Query, Mutation, Subscription, F> for GraphQLSubscriptionActor<Query, Mutation, Subscription, OnInit>
where where
Query: ObjectType + 'static, Query: ObjectType + 'static,
Mutation: ObjectType + 'static, Mutation: ObjectType + 'static,
Subscription: SubscriptionType + 'static, Subscription: SubscriptionType + 'static,
F: FnOnce(serde_json::Value) -> R + Unpin + Send + 'static, OnInit: FnOnce(serde_json::Value) -> OnInitFut + Unpin + Send + 'static,
R: Future<Output = Result<Data>> + Send + 'static, OnInitFut: Future<Output = Result<Data>> + Send + 'static,
{ {
fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) { fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
let msg = match msg { let msg = match msg {

View File

@ -3,7 +3,7 @@ use async_graphql::http::{playground_source, GraphQLPlaygroundConfig};
use async_graphql::{ use async_graphql::{
Context, EmptyMutation, EmptySubscription, Object, ObjectType, Schema, SubscriptionType, Context, EmptyMutation, EmptySubscription, Object, ObjectType, Schema, SubscriptionType,
}; };
use async_graphql_actix_web::{Request, Response}; use async_graphql_actix_web::{GraphQLRequest, GraphQLResponse};
use async_mutex::Mutex; use async_mutex::Mutex;
pub async fn gql_playgound() -> HttpResponse { pub async fn gql_playgound() -> HttpResponse {
@ -69,16 +69,16 @@ pub async fn gql_handle_schema<
S: SubscriptionType + 'static, S: SubscriptionType + 'static,
>( >(
schema: web::Data<Schema<Q, M, S>>, schema: web::Data<Schema<Q, M, S>>,
req: Request, req: GraphQLRequest,
) -> Response { ) -> GraphQLResponse {
schema.execute(req.into_inner()).await.into() schema.execute(req.into_inner()).await.into()
} }
pub async fn gql_handle_schema_with_header<T: ObjectType + 'static>( pub async fn gql_handle_schema_with_header<T: ObjectType + 'static>(
schema: actix_web::web::Data<Schema<T, EmptyMutation, EmptySubscription>>, schema: actix_web::web::Data<Schema<T, EmptyMutation, EmptySubscription>>,
req: HttpRequest, req: HttpRequest,
gql_request: Request, gql_request: GraphQLRequest,
) -> Response { ) -> GraphQLResponse {
let name = req let name = req
.headers() .headers()
.get("Name") .get("Name")