Report subscription stream errors to the client.

This commit is contained in:
sunli 2020-05-03 16:02:46 +08:00
parent b62b8e34c8
commit 220cd1e775
9 changed files with 388 additions and 216 deletions

View File

@ -249,36 +249,37 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
create_stream.push(quote! {
if ctx.name.as_str() == #field_name {
let field_name = ctx.result_name().to_string();
use #crate_name::futures::stream::{StreamExt, TryStreamExt};
let field_name = std::sync::Arc::new(ctx.result_name().to_string());
#(#get_params)*
let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone());
let schema = schema.clone();
let pos = ctx.position;
let environment = environment.clone();
let stream = #crate_name::futures::stream::StreamExt::then(#create_field_stream, move |msg| {
let environment = environment.clone();
let field_selection_set = field_selection_set.clone();
let schema = schema.clone();
async move {
let resolve_id = std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = environment.create_context(
&schema,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name("time"),
}),
&*field_selection_set,
&resolve_id,
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await
let stream = #create_field_stream.then({
let field_name = field_name.clone();
move |msg| {
let environment = environment.clone();
let field_selection_set = field_selection_set.clone();
let schema = schema.clone();
let field_name = field_name.clone();
async move {
let resolve_id = std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = environment.create_context(
&schema,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name(&field_name),
}),
&*field_selection_set,
&resolve_id,
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await
}
}
}).
filter_map(move |res| {
let res = res.ok().map(|value| {
#crate_name::serde_json::json!({ &field_name: value })
});
async move { res }
});
map_ok(move |value| #crate_name::serde_json::json!({ field_name.as_str(): value }));
return Ok(Box::pin(stream));
}
});
@ -329,7 +330,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
ctx: &#crate_name::Context<'_>,
schema: &#crate_name::Schema<Query, Mutation, Self>,
environment: std::sync::Arc<#crate_name::Environment>,
) -> #crate_name::Result<std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::serde_json::Value> + Send>>>
) -> #crate_name::Result<std::pin::Pin<Box<dyn #crate_name::futures::Stream<Item = #crate_name::Result<#crate_name::serde_json::Value>> + Send>>>
where
Query: #crate_name::ObjectType + Send + Sync + 'static,
Mutation: #crate_name::ObjectType + Send + Sync + 'static,

View File

@ -241,7 +241,7 @@ where
operation_name: Option<&str>,
variables: Variables,
ctx_data: Option<Arc<Data>>,
) -> Result<impl Stream<Item = serde_json::Value> + Send> {
) -> Result<impl Stream<Item = Result<serde_json::Value>> + Send> {
let document = parse_query(source).map_err(Into::<Error>::into)?;
check_rules(&self.0.registry, &document, self.0.validation_mode)?;

View File

@ -1,4 +1,4 @@
use crate::{ObjectType, Schema, SubscriptionType};
use crate::{ObjectType, Result, Schema, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::task::{AtomicWaker, Context, Poll};
@ -9,12 +9,12 @@ use std::pin::Pin;
/// Use to hold all subscription stream for the `SubscriptionConnection`
pub struct SubscriptionStreams {
streams: Slab<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>,
streams: Slab<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send>>>,
}
#[allow(missing_docs)]
impl SubscriptionStreams {
pub fn add<S: Stream<Item = serde_json::Value> + Send + 'static>(
pub fn add<S: Stream<Item = Result<serde_json::Value>> + Send + 'static>(
&mut self,
stream: S,
) -> usize {
@ -49,7 +49,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
Subscription: SubscriptionType + Sync + Send + 'static;
/// When a response message is generated, you can convert the message to the format you want here.
fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option<Bytes>;
fn handle_response(&mut self, id: usize, res: Result<serde_json::Value>) -> Option<Bytes>;
}
pub fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
@ -154,8 +154,11 @@ where
for (id, incoming_stream) in &mut this.streams.streams {
match incoming_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(value)) => {
if let Some(bytes) = this.transport.handle_response(id, value) {
Poll::Ready(Some(res)) => {
if res.is_err() {
closed.push(id);
}
if let Some(bytes) = this.transport.handle_response(id, res) {
return Poll::Ready(Some(bytes));
}
}

View File

@ -20,7 +20,7 @@ pub trait SubscriptionType: Type {
ctx: &Context<'_>,
schema: &Schema<Query, Mutation, Self>,
environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>
) -> Result<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
@ -33,7 +33,7 @@ pub fn create_subscription_stream<'a, Query, Mutation, Subscription>(
schema: &'a Schema<Query, Mutation, Subscription>,
environment: Arc<Environment>,
ctx: &'a ContextSelectionSet<'_>,
streams: &'a mut Vec<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>,
streams: &'a mut Vec<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send>>>,
) -> BoxCreateStreamFuture<'a>
where
Query: ObjectType + Send + Sync + 'static,

View File

@ -1,7 +1,7 @@
use crate::context::Data;
use crate::http::{GQLError, GQLRequest, GQLResponse};
use crate::{
FieldError, FieldResult, ObjectType, QueryResponse, Schema, SubscriptionStreams,
FieldError, FieldResult, ObjectType, QueryResponse, Result, Schema, SubscriptionStreams,
SubscriptionTransport, SubscriptionType, Variables,
};
use bytes::Bytes;
@ -132,24 +132,35 @@ impl SubscriptionTransport for WebSocketTransport {
}
}
fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option<Bytes> {
fn handle_response(&mut self, id: usize, res: Result<serde_json::Value>) -> Option<Bytes> {
if let Some(id) = self.sid_to_id.get(&id) {
Some(
serde_json::to_vec(&OperationMessage {
ty: "data".to_string(),
id: Some(id.clone()),
payload: Some(
serde_json::to_value(GQLResponse(Ok(QueryResponse {
data: value,
extensions: None,
cache_control: Default::default(),
})))
.unwrap(),
),
})
.unwrap()
.into(),
)
match res {
Ok(value) => Some(
serde_json::to_vec(&OperationMessage {
ty: "data".to_string(),
id: Some(id.clone()),
payload: Some(
serde_json::to_value(GQLResponse(Ok(QueryResponse {
data: value,
extensions: None,
cache_control: Default::default(),
})))
.unwrap(),
),
})
.unwrap()
.into(),
),
Err(err) => Some(
serde_json::to_vec(&OperationMessage {
ty: "error".to_string(),
id: Some(id.to_string()),
payload: Some(serde_json::to_value(GQLError(&err)).unwrap()),
})
.unwrap()
.into(),
),
}
} else {
None
}

View File

@ -42,7 +42,7 @@ impl SubscriptionType for EmptySubscription {
_ctx: &Context<'_>,
_schema: &Schema<Query, Mutation, Self>,
_environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>
) -> Result<Pin<Box<dyn Stream<Item = Result<serde_json::Value>> + Send>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,

View File

@ -204,9 +204,9 @@ pub async fn test_guard() {
.collect::<Vec<_>>()
.await,
vec![
serde_json::json! ({"values1": 1}),
serde_json::json! ({"values1": 2}),
serde_json::json! ({"values1": 3})
Ok(serde_json::json! ({"values1": 1})),
Ok(serde_json::json! ({"values1": 2})),
Ok(serde_json::json! ({"values1": 3}))
]
);
@ -227,9 +227,9 @@ pub async fn test_guard() {
.collect::<Vec<_>>()
.await,
vec![
serde_json::json! ({"values2": 1}),
serde_json::json! ({"values2": 2}),
serde_json::json! ({"values2": 3})
Ok(serde_json::json! ({"values2": 1})),
Ok(serde_json::json! ({"values2": 2})),
Ok(serde_json::json! ({"values2": 3}))
]
);
@ -250,9 +250,9 @@ pub async fn test_guard() {
.collect::<Vec<_>>()
.await,
vec![
serde_json::json! ({"values3": 1}),
serde_json::json! ({"values3": 2}),
serde_json::json! ({"values3": 3})
Ok(serde_json::json! ({"values3": 1})),
Ok(serde_json::json! ({"values3": 2})),
Ok(serde_json::json! ({"values3": 3}))
]
);

View File

@ -1,5 +1,5 @@
use async_graphql::*;
use futures::{SinkExt, Stream, StreamExt};
use futures::{Stream, StreamExt};
use std::sync::Arc;
#[async_std::test]
@ -42,7 +42,7 @@ pub async fn test_subscription() {
.unwrap();
for i in 10..20 {
assert_eq!(
Some(serde_json::json!({ "values": i })),
Some(Ok(serde_json::json!({ "values": i }))),
stream.next().await
);
}
@ -61,7 +61,7 @@ pub async fn test_subscription() {
.unwrap();
for i in 10..20 {
assert_eq!(
Some(serde_json::json!({ "events": {"a": i, "b": i * 10} })),
Some(Ok(serde_json::json!({ "events": {"a": i, "b": i * 10} }))),
stream.next().await
);
}
@ -128,20 +128,20 @@ pub async fn test_simple_broker() {
assert_eq!(
stream1.next().await,
Some(serde_json::json!({ "events1": {"value": 10} }))
Some(Ok(serde_json::json!({ "events1": {"value": 10} })))
);
assert_eq!(
stream1.next().await,
Some(serde_json::json!({ "events1": {"value": 15} }))
Some(Ok(serde_json::json!({ "events1": {"value": 15} })))
);
assert_eq!(
stream2.next().await,
Some(serde_json::json!({ "events2": {"value": 88} }))
Some(Ok(serde_json::json!({ "events2": {"value": 88} })))
);
assert_eq!(
stream2.next().await,
Some(serde_json::json!({ "events2": {"value": 99} }))
Some(Ok(serde_json::json!({ "events2": {"value": 99} })))
);
}
@ -192,11 +192,11 @@ pub async fn test_subscription_with_ctx_data() {
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({ "values": 100 })),
Some(Ok(serde_json::json!({ "values": 100 }))),
stream.next().await
);
assert_eq!(
Some(serde_json::json!({ "objects": { "value": 100 } })),
Some(Ok(serde_json::json!({ "objects": { "value": 100 } }))),
stream.next().await
);
assert!(stream.next().await.is_none());
@ -241,7 +241,7 @@ pub async fn test_subscription_with_token() {
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({ "values": 100 })),
Some(Ok(serde_json::json!({ "values": 100 }))),
stream.next().await
);
assert!(stream.next().await.is_none());
@ -264,148 +264,6 @@ pub async fn test_subscription_with_token() {
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport() {
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self) -> impl Stream<Item = i32> {
futures::stream::iter(0..10)
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
if ctx.data::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| {
#[derive(serde_derive::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}));
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_inline_fragment() {
struct QueryRoot;
@ -449,7 +307,7 @@ pub async fn test_subscription_inline_fragment() {
.unwrap();
for i in 10..20 {
assert_eq!(
Some(serde_json::json!({ "events": {"a": i, "b": i * 10} })),
Some(Ok(serde_json::json!({ "events": {"a": i, "b": i * 10} }))),
stream.next().await
);
}
@ -504,7 +362,7 @@ pub async fn test_subscription_fragment() {
.unwrap();
for i in 10..20 {
assert_eq!(
Some(serde_json::json!({ "events": {"a": i, "b": i * 10} })),
Some(Ok(serde_json::json!({ "events": {"a": i, "b": i * 10} }))),
stream.next().await
);
}
@ -560,9 +418,72 @@ pub async fn test_subscription_fragment2() {
.unwrap();
for i in 10..20 {
assert_eq!(
Some(serde_json::json!({ "events": {"a": i, "b": i * 10} })),
Some(Ok(serde_json::json!({ "events": {"a": i, "b": i * 10} }))),
stream.next().await
);
}
assert!(stream.next().await.is_none());
}
#[async_std::test]
pub async fn test_subscription_error() {
struct QueryRoot;
struct Event {
value: i32,
}
#[Object]
impl Event {
async fn value(&self) -> FieldResult<i32> {
if self.value < 5 {
Ok(self.value)
} else {
Err("TestError".into())
}
}
}
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = Event> {
futures::stream::iter((0..10).map(|n| Event { value: n }))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let mut stream = schema
.create_subscription_stream(
"subscription { events { value } }",
None,
Default::default(),
None,
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
Some(Ok(serde_json::json!({ "events": { "value": i } }))),
stream.next().await
);
}
assert_eq!(
stream.next().await,
Some(Err(Error::Query {
pos: Pos {
line: 1,
column: 25
},
path: Some(serde_json::json!(["events", "value"])),
err: QueryError::FieldError {
err: "TestError".to_string(),
extended_error: None,
},
}))
);
}

View File

@ -0,0 +1,236 @@
use async_graphql::*;
use futures::{SinkExt, Stream, StreamExt};
#[async_std::test]
pub async fn test_subscription_ws_transport() {
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self) -> impl Stream<Item = i32> {
futures::stream::iter(0..10)
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::default());
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport_with_token() {
struct Token(String);
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn values(&self, ctx: &Context<'_>) -> FieldResult<impl Stream<Item = i32>> {
if ctx.data::<Token>().0 != "123456" {
return Err("forbidden".into());
}
Ok(futures::stream::iter(0..10))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(WebSocketTransport::new(|value| {
#[derive(serde_derive::Deserialize)]
struct Payload {
token: String,
}
let payload: Payload = serde_json::from_value(value).unwrap();
let mut data = Data::default();
data.insert(Token(payload.token));
Ok(data)
}));
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
"payload": { "token": "123456" }
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport_error() {
struct QueryRoot;
struct Event {
value: i32,
}
#[Object]
impl Event {
async fn value(&self) -> FieldResult<i32> {
if self.value < 5 {
Ok(self.value)
} else {
Err("TestError".into())
}
}
}
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
async fn events(&self) -> impl Stream<Item = Event> {
futures::stream::iter((0..10).map(|n| Event { value: n }))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) =
schema.subscription_connection(WebSocketTransport::new(|_| Ok(Data::default())));
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init"
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { events { value } }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0i32..5 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "events": { "value": i } } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
assert_eq!(
Some(serde_json::json!({
"type": "error",
"id": "1",
"payload": [{
"message": "TestError",
"locations": [{"line": 1, "column": 25}],
"path": ["events", "value"],
}],
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}