Websocket transport creates context data from the connect_init.payload property

This commit is contained in:
sunli 2020-04-23 14:52:22 +08:00
parent 423603e212
commit 998e9b7b85
7 changed files with 139 additions and 139 deletions

View File

@ -2,12 +2,10 @@ use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::{Data, ObjectType, Schema, SubscriptionType, WebSocketTransport};
use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::SinkExt;
use std::any::Any;
use std::sync::Arc;
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -18,7 +16,6 @@ pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
sink: Option<mpsc::Sender<Bytes>>,
data: Data,
}
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
@ -33,18 +30,9 @@ where
schema: schema.clone(),
hb: Instant::now(),
sink: None,
data: Default::default(),
}
}
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
///
/// **This data is only valid for this subscription**
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
self.data.insert(data);
self
}
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
@ -66,10 +54,7 @@ where
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let schema = self.schema.clone();
let (sink, stream) = schema.subscription_connection(
WebSocketTransport::default(),
Some(Arc::new(std::mem::take(&mut self.data))),
);
let (sink, stream) = schema.subscription_connection(WebSocketTransport::default());
ctx.add_stream(stream);
self.sink = Some(sink);
}

View File

@ -6,8 +6,8 @@
use async_graphql::http::StreamBody;
use async_graphql::{
Data, IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema,
SubscriptionType, WebSocketTransport,
IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, SubscriptionType,
WebSocketTransport,
};
use bytes::Bytes;
use futures::select;
@ -145,30 +145,27 @@ where
/// #[tokio::main]
/// async fn main() {
/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
/// let filter = async_graphql_warp::graphql_subscription(schema, None);
/// let filter = async_graphql_warp::graphql_subscription(schema);
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// }
/// ```
pub fn graphql_subscription<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
ctx_data: Option<Data>,
) -> BoxedFilter<(impl Reply,)>
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
let ctx_data = ctx_data.map(Arc::new);
warp::any()
.and(warp::ws())
.and(warp::any().map(move || schema.clone()))
.and(warp::any().map(move || ctx_data.clone()))
.map(
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, ctx_data: Option<Arc<Data>>| {
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>| {
ws.on_upgrade(move |websocket| {
let (mut tx, rx) = websocket.split();
let (mut stx, srx) =
schema.subscription_connection(WebSocketTransport::default(),ctx_data.clone());
schema.subscription_connection(WebSocketTransport::default());
let mut rx = rx.fuse();
let mut srx = srx.fuse();

View File

@ -329,11 +329,16 @@ impl<'a, T> ContextBase<'a, T> {
/// Gets the global data defined in the `Context` or `Schema`.
pub fn data<D: Any + Send + Sync>(&self) -> &D {
self.data_opt::<D>()
.expect("The specified data type does not exist.")
}
/// Gets the global data defined in the `Context` or `Schema`, returns `None` if the specified type data does not exist.
pub fn data_opt<D: Any + Send + Sync>(&self) -> Option<&D> {
self.ctx_data
.and_then(|ctx_data| ctx_data.0.get(&TypeId::of::<D>()))
.or_else(|| self.data.0.get(&TypeId::of::<D>()))
.and_then(|d| d.downcast_ref::<D>())
.expect("The specified data type does not exist.")
}
fn var_value(&self, name: &str, pos: Pos) -> Result<Value> {

View File

@ -289,11 +289,10 @@ where
pub fn subscription_connection<T: SubscriptionTransport>(
&self,
transport: T,
ctx_data: Option<Arc<Data>>,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
) {
create_connection(self.clone(), transport, ctx_data.unwrap_or_default())
create_connection(self.clone(), transport)
}
}

View File

@ -1,4 +1,3 @@
use crate::context::Data;
use crate::{ObjectType, Schema, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
@ -7,7 +6,6 @@ use futures::Stream;
use slab::Slab;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Use to hold all subscription stream for the `SubscriptionConnection`
pub struct SubscriptionStreams {
@ -44,7 +42,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
ctx_data: Arc<Data>,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
Query: ObjectType + Sync + Send + 'static,
@ -58,7 +55,6 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
pub fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
schema: Schema<Query, Mutation, Subscription>,
transport: T,
ctx_data: Arc<Data>,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
@ -73,7 +69,6 @@ where
tx_bytes,
SubscriptionStream {
schema,
ctx_data,
transport,
streams: SubscriptionStreams {
streams: Default::default(),
@ -84,19 +79,22 @@ where
)
}
type HandleRequestBoxFut<T> = Pin<
Box<
dyn Future<Output = std::result::Result<Option<Bytes>, <T as SubscriptionTransport>::Error>>
+ Send
+ 'static,
>,
>;
#[allow(missing_docs)]
#[allow(clippy::type_complexity)]
pub struct SubscriptionStream<Query, Mutation, Subscription, T: SubscriptionTransport> {
schema: Schema<Query, Mutation, Subscription>,
ctx_data: Arc<Data>,
transport: T,
streams: SubscriptionStreams,
rx_bytes: mpsc::Receiver<Bytes>,
handle_request_fut: Option<
Pin<
Box<dyn Future<Output = std::result::Result<Option<Bytes>, T::Error>> + Send + 'static>,
>,
>,
handle_request_fut: Option<HandleRequestBoxFut<T>>,
}
impl<Query, Mutation, Subscription, T> Stream
@ -133,13 +131,11 @@ where
let transport = &mut this.transport as *mut T;
let schema = &this.schema as *const Schema<Query, Mutation, Subscription>;
let streams = &mut this.streams as *mut SubscriptionStreams;
let ctx_data = this.ctx_data.clone();
unsafe {
this.handle_request_fut = Some(Box::pin((*transport).handle_request(
&*schema,
&mut *streams,
data,
ctx_data.clone(),
)));
}
continue;

View File

@ -25,6 +25,20 @@ struct OperationMessage {
pub struct WebSocketTransport {
id_to_sid: HashMap<String, usize>,
sid_to_id: HashMap<usize, String>,
data: Arc<Data>,
init_with_payload: Option<Box<dyn Fn(serde_json::Value) -> Data + Send + Sync>>,
}
impl WebSocketTransport {
/// Creates a websocket transport and sets the function that converts the `payload` of the `connect_init` message to `Data`.
pub fn new<F: Fn(serde_json::Value) -> Data + Send + Sync + 'static>(
init_with_payload: F,
) -> Self {
WebSocketTransport {
init_with_payload: Some(Box::new(init_with_payload)),
..WebSocketTransport::default()
}
}
}
#[async_trait::async_trait]
@ -36,7 +50,6 @@ impl SubscriptionTransport for WebSocketTransport {
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
ctx_data: Arc<Data>,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
Query: ObjectType + Sync + Send + 'static,
@ -45,15 +58,22 @@ impl SubscriptionTransport for WebSocketTransport {
{
match serde_json::from_slice::<OperationMessage>(&data) {
Ok(msg) => match msg.ty.as_str() {
"connection_init" => Ok(Some(
serde_json::to_vec(&OperationMessage {
ty: "connection_ack".to_string(),
id: None,
payload: None,
})
.unwrap()
.into(),
)),
"connection_init" => {
if let Some(payload) = msg.payload {
if let Some(init_with_payload) = &self.init_with_payload {
self.data = Arc::new(init_with_payload(payload));
}
}
Ok(Some(
serde_json::to_vec(&OperationMessage {
ty: "connection_ack".to_string(),
id: None,
payload: None,
})
.unwrap()
.into(),
))
}
"start" => {
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
if let Ok(request) = serde_json::from_value::<GQLRequest>(payload) {
@ -67,7 +87,7 @@ impl SubscriptionTransport for WebSocketTransport {
&request.query,
request.operation_name.as_deref(),
variables,
Some(ctx_data),
Some(self.data.clone()),
)
.await
{
@ -106,7 +126,7 @@ impl SubscriptionTransport for WebSocketTransport {
Ok(None)
}
"connection_terminate" => Err("connection_terminate".to_string()),
_ => Err("unknown op".to_string()),
_ => Err("Unknown op".to_string()),
},
Err(err) => Err(err.to_string()),
}

View File

@ -270,100 +270,98 @@ pub async fn test_subscription_ws_transport() {
#[Subscription]
impl SubscriptionRoot {
#[field]
async fn values(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let step = *ctx.data::<i32>();
futures::stream::iter((0..10).map(move |n| n * step))
async fn values(&self) -> impl Stream<Item = i32> {
futures::stream::iter(0..10)
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(
WebSocketTransport::default(),
Some(Arc::new({
let mut data = Data::default();
data.insert(5);
data
})),
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i * 5 } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
#[field]
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
if ctx.data::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(
WebSocketTransport::default(),
Some(Arc::new({
let mut data = Data::default();
data.insert(Token("123456".to_string()));
data
})),
);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
#[field]
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
if ctx.data::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| {
#[derive(serde_derive::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
data
}));
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),