Re-enable the test on websocket.

This commit is contained in:
Sunli 2020-09-11 16:41:56 +08:00
parent c129079f69
commit a4fdf6a38a
3 changed files with 333 additions and 297 deletions

View File

@ -453,8 +453,11 @@ where
resp
}
/// Execute an GraphQL subscription.
pub fn execute_stream(&self, request: impl Into<Request>) -> impl Stream<Item = Response> {
pub(crate) fn execute_stream_with_ctx_data(
&self,
request: impl Into<Request>,
ctx_data: Arc<Data>,
) -> impl Stream<Item = Response> {
let schema = self.clone();
async_stream::stream! {
let request = request.into();
@ -480,7 +483,7 @@ where
extensions,
request.variables,
document,
Arc::new(request.ctx_data),
ctx_data,
);
let ctx = env.create_context(
@ -507,4 +510,11 @@ where
}
}
}
/// Execute an GraphQL subscription.
pub fn execute_stream(&self, request: impl Into<Request>) -> impl Stream<Item = Response> {
let mut request = request.into();
let ctx_data = std::mem::replace(&mut request.ctx_data, Default::default());
self.execute_stream_with_ctx_data(request, Arc::new(ctx_data))
}
}

View File

@ -1,6 +1,6 @@
//! WebSocket transport for subscription
use crate::{http, Data, FieldResult, ObjectType, Response, Schema, SubscriptionType};
use crate::{http, Data, FieldResult, ObjectType, Request, Response, Schema, SubscriptionType};
use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::{Future, Stream, StreamExt};
@ -181,8 +181,11 @@ where
payload: None,
},
);
break;
}
Poll::Pending => {
break;
}
Poll::Pending => break,
}
}
}
@ -191,6 +194,10 @@ where
ctx.streams.remove(&id);
}
}
if !ctx.send_buf.is_empty() {
continue;
}
}
return Poll::Pending;
@ -227,7 +234,10 @@ where
"start" => {
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
if let Ok(request) = serde_json::from_value::<http::GQLRequest>(payload) {
let stream = schema.execute_stream(request).boxed();
let request = Request::from(request);
let stream = schema
.execute_stream_with_ctx_data(request, ctx.ctx_data.clone())
.boxed();
ctx.streams.insert(id, stream);
}
}

View File

@ -1,291 +1,307 @@
// use async_graphql::*;
// use futures::{SinkExt, Stream, StreamExt};
//
// #[async_std::test]
// pub async fn test_subscription_ws_transport() {
// struct QueryRoot;
//
// #[Object]
// impl QueryRoot {}
//
// struct SubscriptionRoot;
//
// #[Subscription]
// impl SubscriptionRoot {
// async fn values(&self) -> impl Stream<Item = i32> {
// futures::stream::iter(0..10)
// }
// }
//
// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "connection_init",
// "payload": { "token": "123456" }
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "connection_ack",
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "start",
// "id": "1",
// "payload": {
// "query": "subscription { values }"
// },
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// for i in 0..10 {
// assert_eq!(
// Some(serde_json::json!({
// "type": "data",
// "id": "1",
// "payload": { "data": { "values": i } },
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
// }
// }
//
// #[async_std::test]
// pub async fn test_subscription_ws_transport_with_token() {
// struct Token(String);
//
// struct QueryRoot;
//
// #[Object]
// impl QueryRoot {}
//
// struct SubscriptionRoot;
//
// #[Subscription]
// impl SubscriptionRoot {
// async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
// if ctx.data_unchecked::<Token>().0 != "123456" {
// return Err("forbidden".into());
// }
// Ok(futures::stream::iter(0..10))
// }
// }
//
// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
//
// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| {
// #[derive(serde::Deserialize)]
// struct Payload {
// token: String,
// }
//
// let payload: Payload = serde_json::from_value(value).unwrap();
// let mut data = Data::default();
// data.insert(Token(payload.token));
// Ok(data)
// }));
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "connection_init",
// "payload": { "token": "123456" }
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "connection_ack",
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "start",
// "id": "1",
// "payload": {
// "query": "subscription { values }"
// },
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// for i in 0..10 {
// assert_eq!(
// Some(serde_json::json!({
// "type": "data",
// "id": "1",
// "payload": { "data": { "values": i } },
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
// }
// }
//
// #[async_std::test]
// pub async fn test_subscription_ws_transport_error() {
// struct QueryRoot;
//
// struct Event {
// value: i32,
// }
//
// #[Object]
// impl Event {
// async fn value(&self) -> FieldResult<i32> {
// if self.value < 5 {
// Ok(self.value)
// } else {
// Err("TestError".into())
// }
// }
// }
//
// #[Object]
// impl QueryRoot {}
//
// struct SubscriptionRoot;
//
// #[Subscription]
// impl SubscriptionRoot {
// async fn events(&self) -> impl Stream<Item = Event> {
// futures::stream::iter((0..10).map(|n| Event { value: n }))
// }
// }
//
// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
//
// let (mut sink, mut stream) =
// schema.subscription_connection(WebSocketTransport::new(|_| Ok(Data::default())));
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "connection_init"
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "connection_ack",
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "start",
// "id": "1",
// "payload": {
// "query": "subscription { events { value } }"
// },
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// for i in 0i32..5 {
// assert_eq!(
// Some(serde_json::json!({
// "type": "data",
// "id": "1",
// "payload": { "data": { "events": { "value": i } } },
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
// }
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "error",
// "id": "1",
// "payload": [{
// "message": "TestError",
// "locations": [{"line": 1, "column": 25}],
// "path": ["events", "value"],
// }],
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
// }
//
// #[async_std::test]
// pub async fn test_query_over_websocket() {
// struct QueryRoot;
//
// #[Object]
// impl QueryRoot {
// async fn value(&self) -> i32 {
// 999
// }
// }
//
// let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
// let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "connection_init",
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "connection_ack",
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
//
// sink.send(
// serde_json::to_vec(&serde_json::json!({
// "type": "start",
// "id": "1",
// "payload": {
// "query": "query { value }"
// },
// }))
// .unwrap(),
// )
// .await
// .unwrap();
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "data",
// "id": "1",
// "payload": { "data": { "value": 999 } },
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
//
// assert_eq!(
// Some(serde_json::json!({
// "type": "complete",
// "id": "1",
// })),
// serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
// );
// }
use async_graphql::*;
use futures::{SinkExt, Stream, StreamExt};
#[async_std::test]
pub async fn test_subscription_ws_transport() {
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self) -> impl Stream<Item = i32> {
futures::stream::iter(0..10)
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = transports::websocket::create(&schema);
futures::pin_mut!(stream);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(serde_json::json!({
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
if ctx.data_unchecked::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = transports::websocket::create_with_initializer(&schema, |value| {
#[derive(serde::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
});
futures::pin_mut!(stream);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(serde_json::json!({
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
#[async_std::test]
pub async fn test_subscription_ws_transport_error() {
struct QueryRoot;
struct Event {
value: i32,
}
#[Object]
impl Event {
async fn value(&self) -> FieldResult<i32> {
if self.value < 5 {
Ok(self.value)
} else {
Err("TestError".into())
}
}
}
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = Event> {
futures::stream::iter((0..10).map(|n| Event { value: n }))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, stream) = transports::websocket::create(&schema);
futures::pin_mut!(stream);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init"
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap(),
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(serde_json::json!({
"type": "error",
"id": "1",
"payload": [{
"message": "TestError",
"locations": [{"line": 1, "column": 25}],
"path": ["events", "value"],
}],
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
#[async_std::test]
pub async fn test_query_over_websocket() {
struct QueryRoot;
#[Object]
impl QueryRoot {
async fn value(&self) -> i32 {
999
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let (mut sink, stream) = transports::websocket::create(&schema);
futures::pin_mut!(stream);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "query { value }"
},
}))
.unwrap(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "value": 999 } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
assert_eq!(
Some(serde_json::json!({
"type": "complete",
"id": "1",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}