Websocket transport creates context data from the connect_init.payload property
This commit is contained in:
parent
423603e212
commit
998e9b7b85
|
@ -2,12 +2,10 @@ use actix::{
|
|||
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
|
||||
};
|
||||
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
|
||||
use async_graphql::{Data, ObjectType, Schema, SubscriptionType, WebSocketTransport};
|
||||
use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport};
|
||||
use bytes::Bytes;
|
||||
use futures::channel::mpsc;
|
||||
use futures::SinkExt;
|
||||
use std::any::Any;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
|
||||
|
@ -18,7 +16,6 @@ pub struct WSSubscription<Query, Mutation, Subscription> {
|
|||
schema: Schema<Query, Mutation, Subscription>,
|
||||
hb: Instant,
|
||||
sink: Option<mpsc::Sender<Bytes>>,
|
||||
data: Data,
|
||||
}
|
||||
|
||||
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
|
||||
|
@ -33,18 +30,9 @@ where
|
|||
schema: schema.clone(),
|
||||
hb: Instant::now(),
|
||||
sink: None,
|
||||
data: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
|
||||
///
|
||||
/// **This data is only valid for this subscription**
|
||||
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
|
||||
self.data.insert(data);
|
||||
self
|
||||
}
|
||||
|
||||
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
|
||||
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
|
||||
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
|
||||
|
@ -66,10 +54,7 @@ where
|
|||
fn started(&mut self, ctx: &mut Self::Context) {
|
||||
self.hb(ctx);
|
||||
let schema = self.schema.clone();
|
||||
let (sink, stream) = schema.subscription_connection(
|
||||
WebSocketTransport::default(),
|
||||
Some(Arc::new(std::mem::take(&mut self.data))),
|
||||
);
|
||||
let (sink, stream) = schema.subscription_connection(WebSocketTransport::default());
|
||||
ctx.add_stream(stream);
|
||||
self.sink = Some(sink);
|
||||
}
|
||||
|
|
|
@ -6,8 +6,8 @@
|
|||
|
||||
use async_graphql::http::StreamBody;
|
||||
use async_graphql::{
|
||||
Data, IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema,
|
||||
SubscriptionType, WebSocketTransport,
|
||||
IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, SubscriptionType,
|
||||
WebSocketTransport,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures::select;
|
||||
|
@ -145,30 +145,27 @@ where
|
|||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
/// let filter = async_graphql_warp::graphql_subscription(schema, None);
|
||||
/// let filter = async_graphql_warp::graphql_subscription(schema);
|
||||
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
|
||||
/// }
|
||||
/// ```
|
||||
pub fn graphql_subscription<Query, Mutation, Subscription>(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
ctx_data: Option<Data>,
|
||||
) -> BoxedFilter<(impl Reply,)>
|
||||
where
|
||||
Query: ObjectType + Sync + Send + 'static,
|
||||
Mutation: ObjectType + Sync + Send + 'static,
|
||||
Subscription: SubscriptionType + Send + Sync + 'static,
|
||||
{
|
||||
let ctx_data = ctx_data.map(Arc::new);
|
||||
warp::any()
|
||||
.and(warp::ws())
|
||||
.and(warp::any().map(move || schema.clone()))
|
||||
.and(warp::any().map(move || ctx_data.clone()))
|
||||
.map(
|
||||
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, ctx_data: Option<Arc<Data>>| {
|
||||
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>| {
|
||||
ws.on_upgrade(move |websocket| {
|
||||
let (mut tx, rx) = websocket.split();
|
||||
let (mut stx, srx) =
|
||||
schema.subscription_connection(WebSocketTransport::default(),ctx_data.clone());
|
||||
schema.subscription_connection(WebSocketTransport::default());
|
||||
|
||||
let mut rx = rx.fuse();
|
||||
let mut srx = srx.fuse();
|
||||
|
|
|
@ -329,11 +329,16 @@ impl<'a, T> ContextBase<'a, T> {
|
|||
|
||||
/// Gets the global data defined in the `Context` or `Schema`.
|
||||
pub fn data<D: Any + Send + Sync>(&self) -> &D {
|
||||
self.data_opt::<D>()
|
||||
.expect("The specified data type does not exist.")
|
||||
}
|
||||
|
||||
/// Gets the global data defined in the `Context` or `Schema`, returns `None` if the specified type data does not exist.
|
||||
pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&D> {
|
||||
self.ctx_data
|
||||
.and_then(|ctx_data| ctx_data.0.get(&TypeId::of::<D>()))
|
||||
.or_else(|| self.data.0.get(&TypeId::of::<D>()))
|
||||
.and_then(|d| d.downcast_ref::<D>())
|
||||
.expect("The specified data type does not exist.")
|
||||
}
|
||||
|
||||
fn var_value(&self, name: &str, pos: Pos) -> Result<Value> {
|
||||
|
|
|
@ -289,11 +289,10 @@ where
|
|||
pub fn subscription_connection<T: SubscriptionTransport>(
|
||||
&self,
|
||||
transport: T,
|
||||
ctx_data: Option<Arc<Data>>,
|
||||
) -> (
|
||||
mpsc::Sender<Bytes>,
|
||||
SubscriptionStream<Query, Mutation, Subscription, T>,
|
||||
) {
|
||||
create_connection(self.clone(), transport, ctx_data.unwrap_or_default())
|
||||
create_connection(self.clone(), transport)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::context::Data;
|
||||
use crate::{ObjectType, Schema, SubscriptionType};
|
||||
use bytes::Bytes;
|
||||
use futures::channel::mpsc;
|
||||
|
@ -7,7 +6,6 @@ use futures::Stream;
|
|||
use slab::Slab;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Use to hold all subscription stream for the `SubscriptionConnection`
|
||||
pub struct SubscriptionStreams {
|
||||
|
@ -44,7 +42,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
|
|||
schema: &Schema<Query, Mutation, Subscription>,
|
||||
streams: &mut SubscriptionStreams,
|
||||
data: Bytes,
|
||||
ctx_data: Arc<Data>,
|
||||
) -> std::result::Result<Option<Bytes>, Self::Error>
|
||||
where
|
||||
Query: ObjectType + Sync + Send + 'static,
|
||||
|
@ -58,7 +55,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
|
|||
pub fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
transport: T,
|
||||
ctx_data: Arc<Data>,
|
||||
) -> (
|
||||
mpsc::Sender<Bytes>,
|
||||
SubscriptionStream<Query, Mutation, Subscription, T>,
|
||||
|
@ -73,7 +69,6 @@ where
|
|||
tx_bytes,
|
||||
SubscriptionStream {
|
||||
schema,
|
||||
ctx_data,
|
||||
transport,
|
||||
streams: SubscriptionStreams {
|
||||
streams: Default::default(),
|
||||
|
@ -84,19 +79,22 @@ where
|
|||
)
|
||||
}
|
||||
|
||||
type HandleRequestBoxFut<T> = Pin<
|
||||
Box<
|
||||
dyn Future<Output = std::result::Result<Option<Bytes>, <T as SubscriptionTransport>::Error>>
|
||||
+ Send
|
||||
+ 'static,
|
||||
>,
|
||||
>;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[allow(clippy::type_complexity)]
|
||||
pub struct SubscriptionStream<Query, Mutation, Subscription, T: SubscriptionTransport> {
|
||||
schema: Schema<Query, Mutation, Subscription>,
|
||||
ctx_data: Arc<Data>,
|
||||
transport: T,
|
||||
streams: SubscriptionStreams,
|
||||
rx_bytes: mpsc::Receiver<Bytes>,
|
||||
handle_request_fut: Option<
|
||||
Pin<
|
||||
Box<dyn Future<Output = std::result::Result<Option<Bytes>, T::Error>> + Send + 'static>,
|
||||
>,
|
||||
>,
|
||||
handle_request_fut: Option<HandleRequestBoxFut<T>>,
|
||||
}
|
||||
|
||||
impl<Query, Mutation, Subscription, T> Stream
|
||||
|
@ -133,13 +131,11 @@ where
|
|||
let transport = &mut this.transport as *mut T;
|
||||
let schema = &this.schema as *const Schema<Query, Mutation, Subscription>;
|
||||
let streams = &mut this.streams as *mut SubscriptionStreams;
|
||||
let ctx_data = this.ctx_data.clone();
|
||||
unsafe {
|
||||
this.handle_request_fut = Some(Box::pin((*transport).handle_request(
|
||||
&*schema,
|
||||
&mut *streams,
|
||||
data,
|
||||
ctx_data.clone(),
|
||||
)));
|
||||
}
|
||||
continue;
|
||||
|
|
|
@ -25,6 +25,20 @@ struct OperationMessage {
|
|||
pub struct WebSocketTransport {
|
||||
id_to_sid: HashMap<String, usize>,
|
||||
sid_to_id: HashMap<usize, String>,
|
||||
data: Arc<Data>,
|
||||
init_with_payload: Option<Box<dyn Fn(serde_json::Value) -> Data + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl WebSocketTransport {
|
||||
/// Creates a websocket transport and sets the function that converts the `payload` of the `connect_init` message to `Data`.
|
||||
pub fn new<F: Fn(serde_json::Value) -> Data + Send + Sync + 'static>(
|
||||
init_with_payload: F,
|
||||
) -> Self {
|
||||
WebSocketTransport {
|
||||
init_with_payload: Some(Box::new(init_with_payload)),
|
||||
..WebSocketTransport::default()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
|
@ -36,7 +50,6 @@ impl SubscriptionTransport for WebSocketTransport {
|
|||
schema: &Schema<Query, Mutation, Subscription>,
|
||||
streams: &mut SubscriptionStreams,
|
||||
data: Bytes,
|
||||
ctx_data: Arc<Data>,
|
||||
) -> std::result::Result<Option<Bytes>, Self::Error>
|
||||
where
|
||||
Query: ObjectType + Sync + Send + 'static,
|
||||
|
@ -45,15 +58,22 @@ impl SubscriptionTransport for WebSocketTransport {
|
|||
{
|
||||
match serde_json::from_slice::<OperationMessage>(&data) {
|
||||
Ok(msg) => match msg.ty.as_str() {
|
||||
"connection_init" => Ok(Some(
|
||||
serde_json::to_vec(&OperationMessage {
|
||||
ty: "connection_ack".to_string(),
|
||||
id: None,
|
||||
payload: None,
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)),
|
||||
"connection_init" => {
|
||||
if let Some(payload) = msg.payload {
|
||||
if let Some(init_with_payload) = &self.init_with_payload {
|
||||
self.data = Arc::new(init_with_payload(payload));
|
||||
}
|
||||
}
|
||||
Ok(Some(
|
||||
serde_json::to_vec(&OperationMessage {
|
||||
ty: "connection_ack".to_string(),
|
||||
id: None,
|
||||
payload: None,
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
"start" => {
|
||||
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
|
||||
if let Ok(request) = serde_json::from_value::<GQLRequest>(payload) {
|
||||
|
@ -67,7 +87,7 @@ impl SubscriptionTransport for WebSocketTransport {
|
|||
&request.query,
|
||||
request.operation_name.as_deref(),
|
||||
variables,
|
||||
Some(ctx_data),
|
||||
Some(self.data.clone()),
|
||||
)
|
||||
.await
|
||||
{
|
||||
|
@ -106,7 +126,7 @@ impl SubscriptionTransport for WebSocketTransport {
|
|||
Ok(None)
|
||||
}
|
||||
"connection_terminate" => Err("connection_terminate".to_string()),
|
||||
_ => Err("unknown op".to_string()),
|
||||
_ => Err("Unknown op".to_string()),
|
||||
},
|
||||
Err(err) => Err(err.to_string()),
|
||||
}
|
||||
|
|
|
@ -270,100 +270,98 @@ pub async fn test_subscription_ws_transport() {
|
|||
#[Subscription]
|
||||
impl SubscriptionRoot {
|
||||
#[field]
|
||||
async fn values(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
|
||||
let step = *ctx.data::<i32>();
|
||||
futures::stream::iter((0..10).map(move |n| n * step))
|
||||
async fn values(&self) -> impl Stream<Item = i32> {
|
||||
futures::stream::iter(0..10)
|
||||
}
|
||||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
let (mut sink, mut stream) = schema.subscription_connection(
|
||||
WebSocketTransport::default(),
|
||||
Some(Arc::new({
|
||||
let mut data = Data::default();
|
||||
data.insert(5);
|
||||
data
|
||||
})),
|
||||
);
|
||||
|
||||
sink.send(
|
||||
serde_json::to_vec(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
}))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
})),
|
||||
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
||||
sink.send(
|
||||
serde_json::to_vec(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i * 5 } },
|
||||
})),
|
||||
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
pub async fn test_subscription_ws_transport_with_token() {
|
||||
struct Token(String);
|
||||
|
||||
struct QueryRoot;
|
||||
|
||||
#[Object]
|
||||
impl QueryRoot {}
|
||||
|
||||
struct SubscriptionRoot;
|
||||
|
||||
#[Subscription]
|
||||
impl SubscriptionRoot {
|
||||
#[field]
|
||||
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
|
||||
if ctx.data::<Token>().0 != "123456" {
|
||||
return Err("forbidden".into());
|
||||
}
|
||||
Ok(futures::stream::iter(0..10))
|
||||
}
|
||||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
let (mut sink, mut stream) = schema.subscription_connection(
|
||||
WebSocketTransport::default(),
|
||||
Some(Arc::new({
|
||||
let mut data = Data::default();
|
||||
data.insert(Token("123456".to_string()));
|
||||
data
|
||||
})),
|
||||
);
|
||||
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
|
||||
|
||||
sink.send(
|
||||
serde_json::to_vec(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
"payload": { "token": "123456" }
|
||||
}))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "connection_ack",
|
||||
})),
|
||||
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
|
||||
sink.send(
|
||||
serde_json::to_vec(&serde_json::json!({
|
||||
"type": "start",
|
||||
"id": "1",
|
||||
"payload": {
|
||||
"query": "subscription { values }"
|
||||
},
|
||||
}))
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
for i in 0..10 {
|
||||
assert_eq!(
|
||||
Some(serde_json::json!({
|
||||
"type": "data",
|
||||
"id": "1",
|
||||
"payload": { "data": { "values": i } },
|
||||
})),
|
||||
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[async_std::test]
|
||||
pub async fn test_subscription_ws_transport_with_token() {
|
||||
struct Token(String);
|
||||
|
||||
struct QueryRoot;
|
||||
|
||||
#[Object]
|
||||
impl QueryRoot {}
|
||||
|
||||
struct SubscriptionRoot;
|
||||
|
||||
#[Subscription]
|
||||
impl SubscriptionRoot {
|
||||
#[field]
|
||||
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
|
||||
if ctx.data::<Token>().0 != "123456" {
|
||||
return Err("forbidden".into());
|
||||
}
|
||||
Ok(futures::stream::iter(0..10))
|
||||
}
|
||||
}
|
||||
|
||||
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
|
||||
|
||||
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| {
|
||||
#[derive(serde_derive::Deserialize)]
|
||||
struct Payload {
|
||||
token: String,
|
||||
}
|
||||
|
||||
let payload: Payload = serde_json::from_value(value).unwrap();
|
||||
let mut data = Data::default();
|
||||
data.insert(Token(payload.token));
|
||||
data
|
||||
}));
|
||||
|
||||
sink.send(
|
||||
serde_json::to_vec(&serde_json::json!({
|
||||
"type": "connection_init",
|
||||
"payload": { "token": "123456" }
|
||||
}))
|
||||
.unwrap()
|
||||
.into(),
|
||||
|
|
Loading…
Reference in New Issue
Block a user