Add context data for subscription
This commit is contained in:
sunli 2020-04-23 10:26:16 +08:00
parent 6dfa6a2614
commit 09624cab24
13 changed files with 236 additions and 68 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql"
version = "1.9.10"
version = "1.9.11"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "The GraphQL server library implemented by rust"
@ -18,7 +18,7 @@ default = ["bson", "uuid", "url", "chrono-tz", "validators"]
validators = ["regex"]
[dependencies]
async-graphql-derive = { path = "async-graphql-derive", version = "1.9.10" }
async-graphql-derive = { path = "async-graphql-derive", version = "1.9.11" }
graphql-parser = "=0.2.3"
anyhow = "1.0.26"
thiserror = "1.0.11"

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql-actix-web"
version = "1.0.10"
version = "1.1.0"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "async-graphql for actix-web"
@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"]
categories = ["network-programming", "asynchronous"]
[dependencies]
async-graphql = { path = "..", version = "1.9.10" }
async-graphql = { path = "..", version = "1.9.11" }
actix-web = "2.0.0"
actix-web-actors = "2.0.0"
actix = "0.9.0"

View File

@ -2,10 +2,12 @@ use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport};
use async_graphql::{Data, ObjectType, Schema, SubscriptionType, WebSocketTransport};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::SinkExt;
use std::any::Any;
use std::sync::Arc;
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
@ -16,6 +18,7 @@ pub struct WSSubscription<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
sink: Option<mpsc::Sender<Bytes>>,
data: Data,
}
impl<Query, Mutation, Subscription> WSSubscription<Query, Mutation, Subscription>
@ -30,9 +33,18 @@ where
schema: schema.clone(),
hb: Instant::now(),
sink: None,
data: Default::default(),
}
}
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
///
/// **This data is only valid for this subscription**
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
self.data.insert(data);
self
}
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
@ -54,7 +66,10 @@ where
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let schema = self.schema.clone();
let (sink, stream) = schema.subscription_connection(WebSocketTransport::default());
let (sink, stream) = schema.subscription_connection(
WebSocketTransport::default(),
Some(Arc::new(std::mem::take(&mut self.data))),
);
ctx.add_stream(stream);
self.sink = Some(sink);
}

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql-derive"
version = "1.9.10"
version = "1.9.11"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "Macros for async-graphql"

View File

@ -3,7 +3,9 @@ use crate::utils::{build_value_repr, check_reserved_name, get_crate_name};
use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait};
use syn::{
Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait, TypeReference,
};
pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<TokenStream> {
let crate_name = get_crate_name(object_args.internal);
@ -61,20 +63,25 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
));
}
match &method.sig.inputs[0] {
FnArg::Receiver(_) => {}
_ => {
return Err(Error::new_spanned(
&method.sig.inputs[0],
"The first argument must be self receiver",
));
}
}
let mut arg_ctx = false;
let mut args = Vec::new();
for arg in method.sig.inputs.iter_mut().skip(1) {
if let FnArg::Typed(pat) = arg {
for (idx, arg) in method.sig.inputs.iter_mut().enumerate() {
if let FnArg::Receiver(receiver) = arg {
if idx != 0 {
return Err(Error::new_spanned(
receiver,
"The self receiver must be the first parameter.",
));
}
} else if let FnArg::Typed(pat) = arg {
if idx == 0 {
return Err(Error::new_spanned(
pat,
"The self receiver must be the first parameter.",
));
}
match (&*pat.pat, &*pat.ty) {
(Pat::Ident(arg_ident), Type::Path(arg_ty)) => {
args.push((
@ -84,6 +91,19 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
));
pat.attrs.clear();
}
(_, Type::Reference(TypeReference { elem, .. })) => {
if let Type::Path(path) = elem.as_ref() {
if idx != 1
|| path.path.segments.last().unwrap().ident != "Context"
{
return Err(Error::new_spanned(
arg,
"The Context must be the second argument.",
));
}
arg_ctx = true;
}
}
_ => {
return Err(Error::new_spanned(arg, "Incorrect argument type"));
}
@ -182,6 +202,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
});
});
let ctx_param = if arg_ctx {
quote! { &ctx, }
} else {
quote! {}
};
create_stream.push(quote! {
if ctx.name.as_str() == #field_name {
let field_name = ctx.result_name().to_string();
@ -190,21 +216,20 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
let schema = schema.clone();
let pos = ctx.position;
let environment = environment.clone();
let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params),*).await.fuse(), move |msg| {
let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#ctx_param #(#use_params),*).await.fuse(), move |msg| {
let environment = environment.clone();
let field_selection_set = field_selection_set.clone();
let schema = schema.clone();
async move {
let resolve_id = std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = environment.create_context(
&*field_selection_set,
&schema,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name("time"),
}),
&*field_selection_set,
&resolve_id,
schema.registry(),
schema.data(),
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await
}

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql-warp"
version = "1.0.11"
version = "1.1.0"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "async-graphql for warp"
@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"]
categories = ["network-programming", "asynchronous"]
[dependencies]
async-graphql = { path = "..", version = "1.9.10" }
async-graphql = { path = "..", version = "1.9.11" }
warp = "0.2.2"
futures = "0.3.0"
bytes = "0.5.4"

View File

@ -6,8 +6,8 @@
use async_graphql::http::StreamBody;
use async_graphql::{
IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema, SubscriptionType,
WebSocketTransport,
Data, IntoQueryBuilder, IntoQueryBuilderOpts, ObjectType, QueryBuilder, Schema,
SubscriptionType, WebSocketTransport,
};
use bytes::Bytes;
use futures::select;
@ -145,27 +145,30 @@ where
/// #[tokio::main]
/// async fn main() {
/// let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
/// let filter = async_graphql_warp::graphql_subscription(schema);
/// let filter = async_graphql_warp::graphql_subscription(schema, None);
/// warp::serve(filter).run(([0, 0, 0, 0], 8000)).await;
/// }
/// ```
pub fn graphql_subscription<Query, Mutation, Subscription>(
schema: Schema<Query, Mutation, Subscription>,
ctx_data: Option<Data>,
) -> BoxedFilter<(impl Reply,)>
where
Query: ObjectType + Sync + Send + 'static,
Mutation: ObjectType + Sync + Send + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
let ctx_data = ctx_data.map(Arc::new);
warp::any()
.and(warp::ws())
.and(warp::any().map(move || schema.clone()))
.and(warp::any().map(move || ctx_data.clone()))
.map(
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>| {
|ws: warp::ws::Ws, schema: Schema<Query, Mutation, Subscription>, ctx_data: Option<Arc<Data>>| {
ws.on_upgrade(move |websocket| {
let (mut tx, rx) = websocket.split();
let (mut stx, srx) =
schema.subscription_connection(WebSocketTransport::default());
schema.subscription_connection(WebSocketTransport::default(),ctx_data.clone());
let mut rx = rx.fuse();
let mut srx = srx.fuse();

View File

@ -1,6 +1,6 @@
use crate::extensions::BoxExtension;
use crate::registry::Registry;
use crate::{InputValueType, Pos, QueryError, Result, Type};
use crate::{InputValueType, Pos, QueryError, Result, Schema, Type};
use graphql_parser::query::{
Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition,
};
@ -9,6 +9,7 @@ use std::collections::{BTreeMap, HashMap};
use std::ops::{Deref, DerefMut};
use std::path::Path;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
/// Variables of query
#[derive(Debug, Clone)]
@ -128,9 +129,11 @@ fn json_value_to_gql_value(value: serde_json::Value) -> Value {
}
#[derive(Default)]
/// Schema/Context data
pub struct Data(BTreeMap<TypeId, Box<dyn Any + Sync + Send>>);
impl Data {
#[allow(missing_docs)]
pub fn insert<D: Any + Send + Sync>(&mut self, data: D) {
self.0.insert(TypeId::of::<D>(), Box::new(data));
}
@ -247,17 +250,17 @@ pub struct Environment {
pub variables: Variables,
pub variable_definitions: Vec<VariableDefinition>,
pub fragments: HashMap<String, FragmentDefinition>,
pub ctx_data: Arc<Data>,
}
impl Environment {
#[doc(hidden)]
pub fn create_context<'a, T>(
pub fn create_context<'a, T, Query, Mutation, Subscription>(
&'a self,
item: T,
schema: &'a Schema<Query, Mutation, Subscription>,
path_node: Option<QueryPathNode<'a>>,
item: T,
resolve_id: &'a AtomicUsize,
registry: &'a Registry,
data: &'a Data,
) -> ContextBase<'a, T> {
ContextBase {
path_node,
@ -266,24 +269,15 @@ impl Environment {
item,
variables: &self.variables,
variable_definitions: &self.variable_definitions,
registry,
data,
ctx_data: None,
registry: &schema.0.registry,
data: &schema.0.data,
ctx_data: Some(&self.ctx_data),
fragments: &self.fragments,
}
}
}
impl<'a, T> ContextBase<'a, T> {
#[doc(hidden)]
pub fn create_environment(&self) -> Environment {
Environment {
variables: self.variables.clone(),
variable_definitions: self.variable_definitions.to_vec(),
fragments: self.fragments.clone(),
}
}
#[doc(hidden)]
pub fn get_resolve_id(&self) -> usize {
self.resolve_id

View File

@ -106,7 +106,9 @@ pub use serde_json;
pub mod http;
pub use base::{Scalar, Type};
pub use context::{Context, ContextBase, Environment, QueryPathNode, QueryPathSegment, Variables};
pub use context::{
Context, ContextBase, Data, Environment, QueryPathNode, QueryPathSegment, Variables,
};
pub use error::{
Error, ErrorExtensions, FieldError, FieldResult, ParseRequestError, QueryError, ResultExt,
};

View File

@ -7,8 +7,8 @@ use crate::subscription::{create_connection, create_subscription_stream, Subscri
use crate::types::QueryRoot;
use crate::validation::{check_rules, ValidationMode};
use crate::{
ContextSelectionSet, Error, ObjectType, Pos, QueryError, QueryResponse, Result,
SubscriptionStream, SubscriptionType, Type, Variables,
Environment, Error, ObjectType, Pos, QueryError, QueryResponse, Result, SubscriptionStream,
SubscriptionType, Type, Variables,
};
use bytes::Bytes;
use futures::channel::mpsc;
@ -240,6 +240,7 @@ where
source: &str,
operation_name: Option<&str>,
variables: Variables,
ctx_data: Option<Arc<Data>>,
) -> Result<impl Stream<Item = serde_json::Value> + Send> {
let document = parse_query(source).map_err(Into::<Error>::into)?;
check_rules(&self.0.registry, &document, self.0.validation_mode)?;
@ -272,22 +273,15 @@ where
})?;
let resolve_id = AtomicUsize::default();
let ctx = ContextSelectionSet {
path_node: None,
extensions: &[],
item: &subscription.selection_set,
resolve_id: &resolve_id,
variables: &variables,
variable_definitions: &subscription.variable_definitions,
registry: &self.0.registry,
data: &Default::default(),
ctx_data: None,
fragments: &fragments,
};
let environment = Arc::new(Environment {
variables,
variable_definitions: subscription.variable_definitions,
fragments,
ctx_data: ctx_data.unwrap_or_default(),
});
let ctx = environment.create_context(self, None, &subscription.selection_set, &resolve_id);
let mut streams = Vec::new();
create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)
.await?;
create_subscription_stream(self, environment.clone(), &ctx, &mut streams).await?;
Ok(futures::stream::select_all(streams))
}
@ -295,10 +289,11 @@ where
pub fn subscription_connection<T: SubscriptionTransport>(
&self,
transport: T,
ctx_data: Option<Arc<Data>>,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
) {
create_connection(self.clone(), transport)
create_connection(self.clone(), transport, ctx_data.unwrap_or_default())
}
}

View File

@ -1,3 +1,4 @@
use crate::context::Data;
use crate::{ObjectType, Schema, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
@ -6,6 +7,7 @@ use futures::Stream;
use slab::Slab;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
/// Use to hold all subscription stream for the `SubscriptionConnection`
pub struct SubscriptionStreams {
@ -42,6 +44,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
ctx_data: Arc<Data>,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
Query: ObjectType + Sync + Send + 'static,
@ -55,6 +58,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
pub fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
schema: Schema<Query, Mutation, Subscription>,
transport: T,
ctx_data: Arc<Data>,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
@ -69,6 +73,7 @@ where
tx_bytes,
SubscriptionStream {
schema,
ctx_data,
transport,
streams: SubscriptionStreams {
streams: Default::default(),
@ -83,6 +88,7 @@ where
#[allow(clippy::type_complexity)]
pub struct SubscriptionStream<Query, Mutation, Subscription, T: SubscriptionTransport> {
schema: Schema<Query, Mutation, Subscription>,
ctx_data: Arc<Data>,
transport: T,
streams: SubscriptionStreams,
rx_bytes: mpsc::Receiver<Bytes>,
@ -127,11 +133,13 @@ where
let transport = &mut this.transport as *mut T;
let schema = &this.schema as *const Schema<Query, Mutation, Subscription>;
let streams = &mut this.streams as *mut SubscriptionStreams;
let ctx_data = this.ctx_data.clone();
unsafe {
this.handle_request_fut = Some(Box::pin((*transport).handle_request(
&*schema,
&mut *streams,
data,
ctx_data.clone(),
)));
}
continue;

View File

@ -1,3 +1,4 @@
use crate::context::Data;
use crate::http::{GQLError, GQLRequest, GQLResponse};
use crate::{
ObjectType, QueryResponse, Schema, SubscriptionStreams, SubscriptionTransport,
@ -5,12 +6,17 @@ use crate::{
};
use bytes::Bytes;
use std::collections::HashMap;
use std::sync::Arc;
#[derive(Serialize, Deserialize)]
struct OperationMessage {
#[serde(rename = "type")]
ty: String,
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
payload: Option<serde_json::Value>,
}
@ -30,6 +36,7 @@ impl SubscriptionTransport for WebSocketTransport {
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
ctx_data: Arc<Data>,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
Query: ObjectType + Sync + Send + 'static,
@ -60,6 +67,7 @@ impl SubscriptionTransport for WebSocketTransport {
&request.query,
request.operation_name.as_deref(),
variables,
Some(ctx_data),
)
.await
{

View File

@ -1,5 +1,6 @@
use async_graphql::*;
use futures::{Stream, StreamExt};
use futures::{SinkExt, Stream, StreamExt};
use std::sync::Arc;
#[async_std::test]
pub async fn test_subscription() {
@ -40,6 +41,7 @@ pub async fn test_subscription() {
"subscription { values(start: 10, end: 20) }",
None,
Default::default(),
None,
)
.await
.unwrap();
@ -58,6 +60,7 @@ pub async fn test_subscription() {
"subscription { events(start: 10, end: 20) { a b } }",
None,
Default::default(),
None,
)
.await
.unwrap();
@ -113,6 +116,7 @@ pub async fn test_simple_broker() {
"subscription { events1 { value } }",
None,
Default::default(),
None,
)
.await
.unwrap();
@ -121,6 +125,7 @@ pub async fn test_simple_broker() {
"subscription { events2 { value } }",
None,
Default::default(),
None,
)
.await
.unwrap();
@ -148,3 +153,116 @@ pub async fn test_simple_broker() {
Some(serde_json::json!({ "events2": {"value": 99} }))
);
}
#[async_std::test]
pub async fn test_subscription_with_ctx_data() {
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
#[field]
async fn values(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let value = *ctx.data::<i32>();
futures::stream::once(async move { value })
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
{
let mut stream = schema
.create_subscription_stream(
"subscription { values }",
None,
Default::default(),
Some(Arc::new({
let mut data = Data::default();
data.insert(100i32);
data
})),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({ "values": 100 })),
stream.next().await
);
assert!(stream.next().await.is_none());
}
}
#[async_std::test]
pub async fn test_subscription_ws_transport() {
struct QueryRoot;
#[Object]
impl QueryRoot {}
struct SubscriptionRoot;
#[Subscription]
impl SubscriptionRoot {
#[field]
async fn values(&self, ctx: &Context<'_>) -> impl Stream<Item = i32> {
let step = *ctx.data::<i32>();
futures::stream::iter((0..10).map(move |n| n * step))
}
}
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let (mut sink, mut stream) = schema.subscription_connection(
WebSocketTransport::default(),
Some(Arc::new({
let mut data = Data::default();
data.insert(5);
data
})),
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "connection_init",
}))
.unwrap()
.into(),
)
.await
.unwrap();
assert_eq!(
Some(serde_json::json!({
"type": "connection_ack",
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
sink.send(
serde_json::to_vec(&serde_json::json!({
"type": "start",
"id": "1",
"payload": {
"query": "subscription { values }"
},
}))
.unwrap()
.into(),
)
.await
.unwrap();
for i in 0..10 {
assert_eq!(
Some(serde_json::json!({
"type": "data",
"id": "1",
"payload": { "data": { "values": i * 5 } },
})),
serde_json::from_slice(&stream.next().await.unwrap()).unwrap()
);
}
}