The subscription field now returns a stream

This commit is contained in:
sunli 2020-04-06 13:49:39 +08:00
parent 4225a03de0
commit b88ac80084
15 changed files with 331 additions and 570 deletions

View File

@ -1,109 +1,15 @@
use actix::clock::Duration;
use actix_web::{web, App, HttpServer};
use async_graphql::{publish, Context, FieldResult, Schema, ID};
use futures::lock::Mutex;
use slab::Slab;
use std::sync::Arc;
#[derive(Clone)]
struct Book {
id: ID,
name: String,
author: String,
}
#[async_graphql::Object]
impl Book {
#[field]
async fn id(&self) -> &str {
&self.id
}
#[field]
async fn name(&self) -> &str {
&self.name
}
#[field]
async fn author(&self) -> &str {
&self.author
}
}
type Storage = Arc<Mutex<Slab<Book>>>;
use async_graphql::{EmptyMutation, Schema};
use futures::{Stream, StreamExt};
struct QueryRoot;
#[async_graphql::Object(cache_control(max_age = 5))]
#[async_graphql::Object]
impl QueryRoot {
#[field]
async fn books(&self, ctx: &Context<'_>) -> Vec<Book> {
let books = ctx.data::<Storage>().lock().await;
books.iter().map(|(_, book)| book).cloned().collect()
}
}
struct MutationRoot;
#[async_graphql::Object]
impl MutationRoot {
#[field]
async fn create_book(&self, ctx: &Context<'_>, name: String, author: String) -> ID {
let mut books = ctx.data::<Storage>().lock().await;
let entry = books.vacant_entry();
let id: ID = entry.key().into();
let book = Book {
id: id.clone(),
name,
author,
};
entry.insert(book);
publish(BookChanged {
mutation_type: MutationType::Created,
id: id.clone(),
})
.await;
id
}
#[field]
async fn delete_book(&self, ctx: &Context<'_>, id: ID) -> FieldResult<bool> {
let mut books = ctx.data::<Storage>().lock().await;
let id = id.parse::<usize>()?;
if books.contains(id) {
books.remove(id);
publish(BookChanged {
mutation_type: MutationType::Deleted,
id: id.into(),
})
.await;
Ok(true)
} else {
Ok(false)
}
}
}
#[async_graphql::Enum]
enum MutationType {
Created,
Deleted,
}
struct BookChanged {
mutation_type: MutationType,
id: ID,
}
#[async_graphql::Object]
impl BookChanged {
#[field]
async fn mutation_type(&self) -> &MutationType {
&self.mutation_type
}
#[field]
async fn id(&self) -> &ID {
&self.id
async fn value(&self) -> i32 {
0
}
}
@ -112,20 +18,19 @@ struct SubscriptionRoot;
#[async_graphql::Subscription]
impl SubscriptionRoot {
#[field]
fn books(&self, changed: &BookChanged, mutation_type: Option<MutationType>) -> bool {
if let Some(mutation_type) = mutation_type {
return changed.mutation_type == mutation_type;
}
true
fn interval(&self, n: i32) -> impl Stream<Item = i32> {
let mut value = 0;
actix_rt::time::interval(Duration::from_secs(1)).map(move |_| {
value += n;
value
})
}
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
.data(Storage::default())
.finish();
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let handler = async_graphql_actix_web::HandlerBuilder::new(schema)
.enable_ui("http://localhost:8000", Some("ws://localhost:8000"))
.enable_subscription()

View File

@ -11,10 +11,7 @@ use actix_web::web::{BytesMut, Payload};
use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder};
use actix_web_actors::ws;
use async_graphql::http::{GQLRequest, GQLResponse};
use async_graphql::{
ObjectType, QueryBuilder, Schema, SubscriptionConnectionBuilder, SubscriptionType,
WebSocketTransport,
};
use async_graphql::{ObjectType, QueryBuilder, Schema, SubscriptionType};
use bytes::Bytes;
use futures::StreamExt;
use mime::Mime;
@ -30,13 +27,6 @@ type BoxOnRequestFn<Query, Mutation, Subscription> = Arc<
) -> QueryBuilder<Query, Mutation, Subscription>,
>;
type BoxOnConnectFn<Query, Mutation, Subscription> = Arc<
dyn Fn(
&HttpRequest,
SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
) -> SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
>;
/// Actix-web handler builder
pub struct HandlerBuilder<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
@ -45,7 +35,6 @@ pub struct HandlerBuilder<Query, Mutation, Subscription> {
enable_subscription: bool,
enable_ui: Option<(String, Option<String>)>,
on_request: Option<BoxOnRequestFn<Query, Mutation, Subscription>>,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
}
impl<Query, Mutation, Subscription> HandlerBuilder<Query, Mutation, Subscription>
@ -63,7 +52,6 @@ where
enable_subscription: false,
enable_ui: None,
on_request: None,
on_connect: None,
}
}
@ -122,24 +110,6 @@ where
}
}
/// When there is a new subscription connection, you can use this closure to append your own data to the `SubscriptionConnectionBuilder`.
pub fn on_connect<
F: Fn(
&HttpRequest,
SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
)
-> SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>
+ 'static,
>(
self,
f: F,
) -> Self {
Self {
on_connect: Some(Arc::new(f)),
..self
}
}
/// Create an HTTP handler.
pub fn build(
self,
@ -155,13 +125,11 @@ where
let enable_ui = self.enable_ui;
let enable_subscription = self.enable_subscription;
let on_request = self.on_request;
let on_connect = self.on_connect;
move |req: HttpRequest, payload: Payload| {
let schema = schema.clone();
let enable_ui = enable_ui.clone();
let on_request = on_request.clone();
let on_connect = on_connect.clone();
Box::pin(async move {
if req.method() == Method::GET {
@ -170,11 +138,7 @@ where
if let Ok(s) = s.to_str() {
if s.to_ascii_lowercase().contains("websocket") {
return ws::start_with_protocols(
WsSession::new(
schema.clone(),
req.clone(),
on_connect.clone(),
),
WsSession::new(schema.clone()),
&["graphql-ws"],
&req,
payload,

View File

@ -1,8 +1,6 @@
use crate::BoxOnConnectFn;
use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
};
use actix_web::HttpRequest;
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport};
use bytes::Bytes;
@ -11,11 +9,9 @@ use futures::SinkExt;
use std::time::{Duration, Instant};
pub struct WsSession<Query, Mutation, Subscription> {
req: HttpRequest,
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
sink: Option<mpsc::Sender<Bytes>>,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
}
impl<Query, Mutation, Subscription> WsSession<Query, Mutation, Subscription>
@ -24,17 +20,11 @@ where
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
pub fn new(
schema: Schema<Query, Mutation, Subscription>,
req: HttpRequest,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
) -> Self {
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self {
req,
schema,
hb: Instant::now(),
sink: None,
on_connect,
}
}
@ -58,24 +48,9 @@ where
fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx);
let schema = self.schema.clone();
let on_connect = self.on_connect.clone();
let req = self.req.clone();
async move {
let mut builder = schema
.clone()
.subscription_connection(WebSocketTransport::default());
if let Some(on_connect) = on_connect {
builder = on_connect(&req, builder);
}
builder.build().await
}
.into_actor(self)
.then(|(sink, stream), actor, ctx| {
actor.sink = Some(sink);
ctx.add_stream(stream);
async {}.into_actor(actor)
})
.wait(ctx);
let (sink, stream) = schema.subscription_connection(WebSocketTransport::default());
ctx.add_stream(stream);
self.sink = Some(sink);
}
}

View File

@ -3,7 +3,7 @@ use crate::utils::{build_value_repr, check_reserved_name, get_crate_name};
use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type};
use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type, TypeImplTrait};
pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<TokenStream> {
let crate_name = get_crate_name(object_args.internal);
@ -32,8 +32,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
let mut create_types = Vec::new();
let mut filters = Vec::new();
let mut create_stream = Vec::new();
let mut schema_fields = Vec::new();
for item in &mut item_impl.items {
@ -55,32 +54,10 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
if method.sig.inputs.len() < 2 {
return Err(Error::new_spanned(
&method.sig.inputs,
"The filter function needs at least two arguments",
));
}
if method.sig.asyncness.is_some() {
return Err(Error::new_spanned(
&method.sig.inputs,
"The filter function must be synchronous",
));
}
let mut res_typ_ok = false;
if let ReturnType::Type(_, res_ty) = &method.sig.output {
if let Type::Path(p) = res_ty.as_ref() {
if p.path.is_ident("bool") {
res_typ_ok = true;
}
}
}
if !res_typ_ok {
return Err(Error::new_spanned(
&method.sig.output,
"The filter function must return a boolean value",
&method.sig.asyncness,
"The subscription stream function must be synchronous",
));
}
@ -94,23 +71,9 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}
}
let ty = if let FnArg::Typed(ty) = &method.sig.inputs[1] {
match ty.ty.as_ref() {
Type::Reference(r) => r.elem.as_ref().clone(),
_ => {
return Err(Error::new_spanned(ty, "Incorrect object type"));
}
}
} else {
return Err(Error::new_spanned(
&method.sig.inputs[1],
"Incorrect object type",
));
};
let mut args = Vec::new();
for arg in method.sig.inputs.iter_mut().skip(2) {
for arg in method.sig.inputs.iter_mut().skip(1) {
if let FnArg::Typed(pat) = arg {
match (&*pat.pat, &*pat.ty) {
(Pat::Ident(arg_ident), Type::Path(arg_ty)) => {
@ -181,10 +144,26 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
};
get_params.push(quote! {
let #ident: #ty = ctx_field.param_value(#name, field.position, #default)?;
let #ident: #ty = ctx.param_value(#name, ctx.position, #default)?;
});
}
let stream_ty = match &method.sig.output {
ReturnType::Default => {
return Err(Error::new_spanned(
&method.sig.output,
"Must be return a stream type",
))
}
ReturnType::Type(_, ty) => {
if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = ty.as_ref() {
quote! { #bounds }
} else {
quote! { #ty }
}
}
};
schema_fields.push(quote! {
fields.insert(#field_name.to_string(), #crate_name::registry::Field {
name: #field_name.to_string(),
@ -194,30 +173,46 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#(#schema_args)*
args
},
ty: <#ty as #crate_name::Type>::create_type_info(registry),
ty: <#stream_ty as #crate_name::futures::stream::Stream>::Item::create_type_info(registry),
deprecation: #field_deprecation,
cache_control: Default::default(),
});
});
create_types.push(quote! {
if field.name.as_str() == #field_name {
types.insert(std::any::TypeId::of::<#ty>(), field.clone());
return Ok(());
}
});
filters.push(quote! {
if let Some(msg) = msg.downcast_ref::<#ty>() {
create_stream.push(quote! {
if ctx.name.as_str() == #field_name {
let field_name = ctx.result_name().to_string();
#(#get_params)*
if self.#ident(msg, #(#use_params)*) {
let ctx_selection_set = ctx_field.with_selection_set(&field.selection_set);
let value =
#crate_name::OutputValueType::resolve(msg, &ctx_selection_set, field.position).await?;
let mut res = #crate_name::serde_json::Map::new();
res.insert(ctx_field.result_name().to_string(), value);
return Ok(Some(res.into()));
}
let field_selection_set = std::sync::Arc::new(ctx.selection_set.clone());
let schema = schema.clone();
let pos = ctx.position;
let environment = environment.clone();
let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params)*).fuse(), move |msg| {
let environment = environment.clone();
let field_selection_set = field_selection_set.clone();
let schema = schema.clone();
async move {
let resolve_id = std::sync::atomic::AtomicUsize::default();
let ctx_selection_set = environment.create_context(
&*field_selection_set,
Some(#crate_name::QueryPathNode {
parent: None,
segment: #crate_name::QueryPathSegment::Name("time"),
}),
&resolve_id,
schema.registry(),
schema.data(),
);
#crate_name::OutputValueType::resolve(&msg, &ctx_selection_set, pos).await
}
}).
filter_map(move |res| {
let res = res.ok().map(|value| {
#crate_name::serde_json::json!({ &field_name: value })
});
async move { res }
});
return Ok(Box::pin(stream));
}
});
@ -234,6 +229,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
std::borrow::Cow::Borrowed(#gql_typename)
}
#[allow(bare_trait_objects)]
fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String {
registry.create_type::<Self, _>(|registry| #crate_name::registry::Type::Object {
name: #gql_typename.to_string(),
@ -250,26 +246,24 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
#[#crate_name::async_trait::async_trait]
impl #crate_name::SubscriptionType for SubscriptionRoot {
fn create_type(field: &#crate_name::graphql_parser::query::Field, types: &mut std::collections::HashMap<std::any::TypeId, #crate_name::graphql_parser::query::Field>) -> #crate_name::Result<()> {
#(#create_types)*
Err(#crate_name::QueryError::FieldNotFound {
field_name: field.name.clone(),
object: #gql_typename.to_string(),
}.into_error(field.position))
}
async fn resolve(
#[allow(unused_variables)]
#[allow(bare_trait_objects)]
fn create_field_stream<Query, Mutation>(
&self,
ctx: &#crate_name::ContextBase<'_, ()>,
types: &std::collections::HashMap<std::any::TypeId, #crate_name::graphql_parser::query::Field>,
msg: &(dyn std::any::Any + Send + Sync),
) -> #crate_name::Result<Option<#crate_name::serde_json::Value>> {
let tid = msg.type_id();
if let Some(field) = types.get(&tid) {
let ctx_field = ctx.with_field(field);
#(#filters)*
}
Ok(None)
ctx: &#crate_name::Context<'_>,
schema: &#crate_name::Schema<Query, Mutation, Self>,
environment: std::sync::Arc<#crate_name::Environment>,
) -> #crate_name::Result<std::pin::Pin<Box<dyn futures::Stream<Item = #crate_name::serde_json::Value>>>>
where
Query: #crate_name::ObjectType + Send + Sync + 'static,
Mutation: #crate_name::ObjectType + Send + Sync + 'static,
Self: Send + Sync + 'static + Sized,
{
#(#create_stream)*
Err(#crate_name::QueryError::FieldNotFound {
field_name: ctx.name.clone(),
object: #gql_typename.to_string(),
}.into_error(ctx.position))
}
}
};

View File

@ -242,7 +242,48 @@ impl<'a, T> Deref for ContextBase<'a, T> {
}
}
#[doc(hidden)]
pub struct Environment {
pub variables: Variables,
pub variable_definitions: Vec<VariableDefinition>,
pub fragments: HashMap<String, FragmentDefinition>,
}
impl Environment {
#[doc(hidden)]
pub fn create_context<'a, T>(
&'a self,
item: T,
path_node: Option<QueryPathNode<'a>>,
resolve_id: &'a AtomicUsize,
registry: &'a Registry,
data: &'a Data,
) -> ContextBase<'a, T> {
ContextBase {
path_node,
resolve_id,
extensions: &[],
item,
variables: &self.variables,
variable_definitions: &self.variable_definitions,
registry,
data,
ctx_data: None,
fragments: &self.fragments,
}
}
}
impl<'a, T> ContextBase<'a, T> {
#[doc(hidden)]
pub fn create_environment(&self) -> Environment {
Environment {
variables: self.variables.clone(),
variable_definitions: self.variable_definitions.to_vec(),
fragments: self.fragments.clone(),
}
}
#[doc(hidden)]
pub fn get_resolve_id(&self) -> usize {
self.resolve_id

View File

@ -76,7 +76,6 @@ extern crate serde_derive;
mod base;
mod context;
mod error;
pub mod extensions;
mod model;
mod mutation_resolver;
mod query;
@ -87,7 +86,7 @@ mod subscription;
mod types;
mod validation;
/// Input value validators
pub mod extensions;
pub mod validators;
#[doc(hidden)]
@ -95,6 +94,8 @@ pub use anyhow;
#[doc(hidden)]
pub use async_trait;
#[doc(hidden)]
pub use futures;
#[doc(hidden)]
pub use graphql_parser;
#[doc(hidden)]
pub use serde_json;
@ -102,17 +103,16 @@ pub use serde_json;
pub mod http;
pub use base::{Scalar, Type};
pub use context::{Context, QueryPathSegment, Variables};
pub use context::{Context, Environment, QueryPathNode, QueryPathSegment, Variables};
pub use error::{Error, ErrorExtensions, FieldError, FieldResult, QueryError, ResultExt};
pub use graphql_parser::query::Value;
pub use graphql_parser::Pos;
pub use query::{QueryBuilder, QueryResponse};
pub use registry::CacheControl;
pub use scalars::ID;
pub use schema::{publish, Schema};
pub use schema::Schema;
pub use subscription::{
SubscriptionConnectionBuilder, SubscriptionStream, SubscriptionStub, SubscriptionStubs,
SubscriptionTransport, WebSocketTransport,
SubscriptionStream, SubscriptionStreams, SubscriptionTransport, WebSocketTransport,
};
pub use types::{
Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation,

View File

@ -3,32 +3,25 @@ use crate::extensions::{BoxExtension, Extension};
use crate::model::__DirectiveLocation;
use crate::query::QueryBuilder;
use crate::registry::{Directive, InputValue, Registry};
use crate::subscription::{SubscriptionConnectionBuilder, SubscriptionStub, SubscriptionTransport};
use crate::subscription::{create_connection, create_subscription_stream, SubscriptionTransport};
use crate::types::QueryRoot;
use crate::validation::{check_rules, CheckResult};
use crate::{
ContextSelectionSet, Error, ObjectType, Pos, QueryError, QueryResponse, Result,
SubscriptionType, Type, Variables,
SubscriptionStream, SubscriptionType, Type, Variables,
};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::lock::Mutex;
use futures::{SinkExt, TryFutureExt};
use futures::{Stream, TryFutureExt};
use graphql_parser::parse_query;
use graphql_parser::query::{
Definition, Field, FragmentDefinition, OperationDefinition, Selection,
};
use graphql_parser::query::{Definition, OperationDefinition};
use itertools::Itertools;
use once_cell::sync::Lazy;
use slab::Slab;
use std::any::{Any, TypeId};
use std::any::Any;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
type MsgSender = mpsc::Sender<Arc<dyn Any + Sync + Send>>;
pub(crate) static SUBSCRIPTION_SENDERS: Lazy<Mutex<Slab<MsgSender>>> = Lazy::new(Default::default);
pub(crate) struct SchemaInner<Query, Mutation, Subscription> {
pub(crate) query: QueryRoot<Query>,
pub(crate) mutation: Mutation,
@ -211,6 +204,16 @@ where
Self::build(query, mutation, subscription).finish()
}
#[doc(hidden)]
pub fn data(&self) -> &Data {
&self.0.data
}
#[doc(hidden)]
pub fn registry(&self) -> &Registry {
&self.0.registry
}
/// Start a query and return `QueryBuilder`.
pub fn query(&self, source: &str) -> Result<QueryBuilder<Query, Mutation, Subscription>> {
let extensions = self
@ -233,13 +236,13 @@ where
if let Some(limit_complexity) = self.0.complexity {
if complexity > limit_complexity {
return Err(QueryError::TooComplex.into_error(Pos { line: 0, column: 0 }));
return Err(QueryError::TooComplex.into_error(Pos::default()));
}
}
if let Some(limit_depth) = self.0.depth {
if depth > limit_depth {
return Err(QueryError::TooDeep.into_error(Pos { line: 0, column: 0 }));
return Err(QueryError::TooDeep.into_error(Pos::default()));
}
}
@ -261,16 +264,13 @@ where
.await
}
/// Create subscription stub, typically called inside the `SubscriptionTransport::handle_request` method/
pub fn create_subscription_stub(
/// Create subscription stream, typically called inside the `SubscriptionTransport::handle_request` method
pub fn create_subscription_stream(
&self,
source: &str,
operation_name: Option<&str>,
variables: Variables,
) -> Result<SubscriptionStub<Query, Mutation, Subscription>>
where
Self: Sized,
{
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value>>>> {
let document = parse_query(source).map_err(Into::<Error>::into)?;
check_rules(&self.0.registry, &document)?;
@ -301,7 +301,6 @@ where
QueryError::MissingOperation.into_error(Pos::default())
})?;
let mut types = HashMap::new();
let resolve_id = AtomicUsize::default();
let ctx = ContextSelectionSet {
path_node: None,
@ -315,87 +314,20 @@ where
ctx_data: None,
fragments: &fragments,
};
create_subscription_types::<Subscription>(&ctx, &fragments, &mut types)?;
Ok(SubscriptionStub {
schema: self.clone(),
types,
variables,
variable_definitions: subscription.variable_definitions,
fragments,
ctx_data: None,
})
let mut streams = Vec::new();
create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)?;
Ok(Box::pin(futures::stream::select_all(streams)))
}
/// Create subscription connection, returns `SubscriptionConnectionBuilder`.
/// Create subscription connection, returns `Sink` and `Stream`.
pub fn subscription_connection<T: SubscriptionTransport>(
&self,
transport: T,
) -> SubscriptionConnectionBuilder<Query, Mutation, Subscription, T> {
SubscriptionConnectionBuilder {
schema: self.clone(),
transport,
ctx_data: None,
}
}
}
fn create_subscription_types<T: SubscriptionType>(
ctx: &ContextSelectionSet<'_>,
fragments: &HashMap<String, FragmentDefinition>,
types: &mut HashMap<TypeId, Field>,
) -> Result<()> {
for selection in &ctx.items {
match selection {
Selection::Field(field) => {
if ctx.is_skip(&field.directives)? {
continue;
}
T::create_type(field, types)?;
}
Selection::FragmentSpread(fragment_spread) => {
if ctx.is_skip(&fragment_spread.directives)? {
continue;
}
if let Some(fragment) = fragments.get(&fragment_spread.fragment_name) {
create_subscription_types::<T>(
&ctx.with_selection_set(&fragment.selection_set),
fragments,
types,
)?;
} else {
return Err(QueryError::UnknownFragment {
name: fragment_spread.fragment_name.clone(),
}
.into_error(fragment_spread.position));
}
}
Selection::InlineFragment(inline_fragment) => {
if ctx.is_skip(&inline_fragment.directives)? {
continue;
}
create_subscription_types::<T>(
&ctx.with_selection_set(&inline_fragment.selection_set),
fragments,
types,
)?;
}
}
}
Ok(())
}
/// Publish a message that will be pushed to all subscribed clients.
pub async fn publish<T: Any + Send + Sync + Sized>(msg: T) {
let mut senders = SUBSCRIPTION_SENDERS.lock().await;
let msg = Arc::new(msg);
let mut remove = Vec::new();
for (id, sender) in senders.iter_mut() {
if sender.send(msg.clone()).await.is_err() {
remove.push(id);
}
}
for id in remove {
senders.remove(id);
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
) {
create_connection(self.clone(), transport)
}
}

View File

@ -1,32 +1,25 @@
use crate::context::Data;
use crate::schema::SUBSCRIPTION_SENDERS;
use crate::subscription::SubscriptionStub;
use crate::{ObjectType, Result, Schema, SubscriptionType};
use crate::{ObjectType, Schema, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::{Future, FutureExt, Stream};
use futures::Stream;
use slab::Slab;
use std::any::Any;
use std::collections::VecDeque;
use std::pin::Pin;
use std::sync::Arc;
/// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection`
pub struct SubscriptionStubs<Query, Mutation, Subscription> {
stubs: Slab<SubscriptionStub<Query, Mutation, Subscription>>,
ctx_data: Option<Arc<Data>>,
/// Use to hold all subscription stream for the `SubscriptionConnection`
pub struct SubscriptionStreams {
streams: Slab<Pin<Box<dyn Stream<Item = serde_json::Value>>>>,
}
#[allow(missing_docs)]
impl<Query, Mutation, Subscription> SubscriptionStubs<Query, Mutation, Subscription> {
pub fn add(&mut self, mut stub: SubscriptionStub<Query, Mutation, Subscription>) -> usize {
stub.ctx_data = self.ctx_data.clone();
self.stubs.insert(stub)
impl SubscriptionStreams {
pub fn add(&mut self, stream: Pin<Box<dyn Stream<Item = serde_json::Value>>>) -> usize {
self.streams.insert(stream)
}
pub fn remove(&mut self, id: usize) {
self.stubs.remove(id);
self.streams.remove(id);
}
}
@ -38,12 +31,12 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
type Error;
/// Parse the request data here.
/// If you have a new request, create a `SubscriptionStub` with the `Schema::create_subscription_stub`, and then call `SubscriptionStubs::add`.
/// If you have a new subscribe, create a stream with the `Schema::create_subscription_stream`, and then call `SubscriptionStreams::add`.
/// You can return a `Byte`, which will be sent to the client. If it returns an error, the connection will be broken.
fn handle_request<Query, Mutation, Subscription>(
&mut self,
schema: &Schema<Query, Mutation, Subscription>,
stubs: &mut SubscriptionStubs<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
@ -52,13 +45,12 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
Subscription: SubscriptionType + Sync + Send + 'static;
/// When a response message is generated, you can convert the message to the format you want here.
fn handle_response(&mut self, id: usize, result: Result<serde_json::Value>) -> Option<Bytes>;
fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option<Bytes>;
}
pub async fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
pub fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
schema: Schema<Query, Mutation, Subscription>,
transport: T,
ctx_data: Option<Data>,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
@ -69,23 +61,16 @@ where
Subscription: SubscriptionType + Sync + Send + 'static,
{
let (tx_bytes, rx_bytes) = mpsc::channel(8);
let (tx_msg, rx_msg) = mpsc::channel(8);
let mut senders = SUBSCRIPTION_SENDERS.lock().await;
senders.insert(tx_msg);
(
tx_bytes.clone(),
SubscriptionStream {
schema,
transport,
stubs: SubscriptionStubs {
stubs: Default::default(),
ctx_data: ctx_data.map(Arc::new),
streams: SubscriptionStreams {
streams: Default::default(),
},
rx_bytes,
rx_msg,
send_queue: VecDeque::new(),
resolve_queue: VecDeque::default(),
resolve_fut: None,
},
)
}
@ -94,12 +79,9 @@ where
pub struct SubscriptionStream<Query, Mutation, Subscription, T: SubscriptionTransport> {
schema: Schema<Query, Mutation, Subscription>,
transport: T,
stubs: SubscriptionStubs<Query, Mutation, Subscription>,
streams: SubscriptionStreams,
rx_bytes: mpsc::Receiver<Bytes>,
rx_msg: mpsc::Receiver<Arc<dyn Any + Sync + Send>>,
send_queue: VecDeque<Bytes>,
resolve_queue: VecDeque<Arc<dyn Any + Sync + Send>>,
resolve_fut: Option<Pin<Box<dyn Future<Output = ()>>>>,
}
impl<Query, Mutation, Subscription, T> Stream
@ -125,7 +107,7 @@ where
let this = &mut *self;
match this
.transport
.handle_request(&this.schema, &mut this.stubs, data)
.handle_request(&this.schema, &mut this.streams, data)
{
Ok(Some(bytes)) => {
this.send_queue.push_back(bytes);
@ -139,44 +121,38 @@ where
Poll::Pending => {}
}
if let Some(resolve_fut) = &mut self.resolve_fut {
match resolve_fut.poll_unpin(cx) {
Poll::Ready(_) => {
self.resolve_fut = None;
}
Poll::Pending => return Poll::Pending,
}
} else if let Some(msg) = self.resolve_queue.pop_front() {
// FIXME: I think this code is safe, but I don't know how to implement it in safe code.
let this = &mut *self;
let stubs = &this.stubs as *const SubscriptionStubs<Query, Mutation, Subscription>;
let transport = &mut this.transport as *mut T;
let send_queue = &mut this.send_queue as *mut VecDeque<Bytes>;
let fut = async move {
unsafe {
for (id, stub) in (*stubs).stubs.iter() {
if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() {
if let Some(bytes) = (*transport).handle_response(id, res) {
(*send_queue).push_back(bytes);
// receive msg
let this = &mut *self;
if !this.streams.streams.is_empty() {
loop {
let mut num_closed = 0;
let mut num_pending = 0;
for (id, incoming_stream) in &mut this.streams.streams {
match incoming_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(value)) => {
if let Some(bytes) = this.transport.handle_response(id, value) {
this.send_queue.push_back(bytes);
}
}
Poll::Ready(None) => {
num_closed += 1;
}
Poll::Pending => {
num_pending += 1;
}
}
}
};
self.resolve_fut = Some(Box::pin(fut));
continue;
}
// receive msg
match Pin::new(&mut self.rx_msg).poll_next(cx) {
Poll::Ready(Some(msg)) => {
self.resolve_queue.push_back(msg);
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {
// all pending
return Poll::Pending;
if num_closed == this.streams.streams.len() {
// all closed
return Poll::Ready(None);
} else if num_pending == this.streams.streams.len() {
return Poll::Pending;
}
}
} else {
return Poll::Pending;
}
}
}

View File

@ -1,43 +0,0 @@
use crate::context::Data;
use crate::subscription::create_connection;
use crate::{ObjectType, Schema, SubscriptionStream, SubscriptionTransport, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
use std::any::Any;
/// SubscriptionConnection builder
pub struct SubscriptionConnectionBuilder<Query, Mutation, Subscription, T: SubscriptionTransport> {
pub(crate) schema: Schema<Query, Mutation, Subscription>,
pub(crate) transport: T,
pub(crate) ctx_data: Option<Data>,
}
impl<Query, Mutation, Subscription, T: SubscriptionTransport>
SubscriptionConnectionBuilder<Query, Mutation, Subscription, T>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
if let Some(ctx_data) = &mut self.ctx_data {
ctx_data.insert(data);
} else {
let mut ctx_data = Data::default();
ctx_data.insert(data);
self.ctx_data = Some(ctx_data);
}
self
}
/// Create subscription connection, returns `Sink` and `Stream`.
pub async fn build(
self,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
) {
create_connection(self.schema, self.transport, self.ctx_data).await
}
}

View File

@ -1,13 +1,9 @@
mod connection;
mod connection_builder;
mod subscription_stub;
mod subscription_type;
mod ws_transport;
pub use connection::{
create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport,
create_connection, SubscriptionStream, SubscriptionStreams, SubscriptionTransport,
};
pub use connection_builder::SubscriptionConnectionBuilder;
pub use subscription_stub::SubscriptionStub;
pub use subscription_type::SubscriptionType;
pub use subscription_type::{create_subscription_stream, SubscriptionType};
pub use ws_transport::WebSocketTransport;

View File

@ -1,52 +0,0 @@
use crate::context::Data;
use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables};
use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition};
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
/// Subscription stub
///
/// When a new push message is generated, a JSON object that needs to be pushed can be obtained by
/// `Subscribe::resolve`, and if None is returned, the Subscribe is not subscribed to a message of this type.
pub struct SubscriptionStub<Query, Mutation, Subscription> {
pub(crate) schema: Schema<Query, Mutation, Subscription>,
pub(crate) types: HashMap<TypeId, Field>,
pub(crate) variables: Variables,
pub(crate) variable_definitions: Vec<VariableDefinition>,
pub(crate) fragments: HashMap<String, FragmentDefinition>,
pub(crate) ctx_data: Option<Arc<Data>>,
}
impl<Query, Mutation, Subscription> SubscriptionStub<Query, Mutation, Subscription>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
#[doc(hidden)]
pub async fn resolve(
&self,
msg: &(dyn Any + Send + Sync),
) -> Result<Option<serde_json::Value>> {
let resolve_id = AtomicUsize::default();
let ctx = ContextBase::<()> {
path_node: None,
extensions: &[],
item: (),
resolve_id: &resolve_id,
variables: &self.variables,
variable_definitions: &self.variable_definitions,
registry: &self.schema.0.registry,
data: &self.schema.0.data,
ctx_data: self.ctx_data.as_deref(),
fragments: &self.fragments,
};
self.schema
.0
.subscription
.resolve(&ctx, &self.types, msg)
.await
}
}

View File

@ -1,7 +1,9 @@
use crate::{ContextBase, Result, Type};
use graphql_parser::query::Field;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use crate::context::Environment;
use crate::{Context, ContextSelectionSet, ObjectType, Result, Schema, Type};
use futures::Stream;
use graphql_parser::query::{Selection, TypeCondition};
use std::pin::Pin;
use std::sync::Arc;
/// Represents a GraphQL subscription object
#[async_trait::async_trait]
@ -13,13 +15,79 @@ pub trait SubscriptionType: Type {
}
#[doc(hidden)]
fn create_type(field: &Field, types: &mut HashMap<TypeId, Field>) -> Result<()>;
/// Resolve a subscription message, If no message of this type is subscribed, None is returned.
async fn resolve(
fn create_field_stream<Query, Mutation>(
&self,
ctx: &ContextBase<'_, ()>,
types: &HashMap<TypeId, Field>,
msg: &(dyn Any + Send + Sync),
) -> Result<Option<serde_json::Value>>;
ctx: &Context<'_>,
schema: &Schema<Query, Mutation, Self>,
environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value>>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Self: Send + Sync + 'static + Sized;
}
pub fn create_subscription_stream<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
environment: Arc<Environment>,
ctx: &ContextSelectionSet<'_>,
streams: &mut Vec<Pin<Box<dyn Stream<Item = serde_json::Value>>>>,
) -> Result<()>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static + Sized,
{
for selection in &ctx.items {
match selection {
Selection::Field(field) => {
if ctx.is_skip(&field.directives)? {
continue;
}
streams.push(schema.0.subscription.create_field_stream(
&ctx.with_field(field),
schema,
environment.clone(),
)?)
}
Selection::FragmentSpread(fragment_spread) => {
if ctx.is_skip(&fragment_spread.directives)? {
continue;
}
if let Some(fragment) = ctx.fragments.get(fragment_spread.fragment_name.as_str()) {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&fragment.selection_set),
streams,
)?;
}
}
Selection::InlineFragment(inline_fragment) => {
if ctx.is_skip(&inline_fragment.directives)? {
continue;
}
if let Some(TypeCondition::On(name)) = &inline_fragment.type_condition {
if name.as_str() == Subscription::type_name() {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.selection_set),
streams,
)?;
}
} else {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.selection_set),
streams,
)?;
}
}
}
}
Ok(())
}

View File

@ -1,6 +1,6 @@
use crate::http::{GQLError, GQLRequest, GQLResponse};
use crate::{
ObjectType, QueryResponse, Result, Schema, SubscriptionStubs, SubscriptionTransport,
ObjectType, QueryResponse, Schema, SubscriptionStreams, SubscriptionTransport,
SubscriptionType, Variables,
};
use bytes::Bytes;
@ -27,7 +27,7 @@ impl SubscriptionTransport for WebSocketTransport {
fn handle_request<Query, Mutation, Subscription>(
&mut self,
schema: &Schema<Query, Mutation, Subscription>,
stubs: &mut SubscriptionStubs<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
data: Bytes,
) -> std::result::Result<Option<Bytes>, Self::Error>
where
@ -54,16 +54,15 @@ impl SubscriptionTransport for WebSocketTransport {
.map(|value| Variables::parse_from_json(value).ok())
.flatten()
.unwrap_or_default();
match schema.create_subscription_stub(
match schema.create_subscription_stream(
&request.query,
request.operation_name.as_deref(),
variables,
) {
Ok(stub) => {
let stub_id = stubs.add(stub);
self.id_to_sid.insert(id.clone(), stub_id);
self.sid_to_id.insert(stub_id, id);
Ok(stream) => {
let stream_id = streams.add(stream);
self.id_to_sid.insert(id.clone(), stream_id);
self.sid_to_id.insert(stream_id, id);
Ok(None)
}
Err(err) => Ok(Some(
@ -89,7 +88,7 @@ impl SubscriptionTransport for WebSocketTransport {
if let Some(id) = msg.id {
if let Some(id) = self.id_to_sid.remove(&id) {
self.sid_to_id.remove(&id);
stubs.remove(id);
streams.remove(id);
}
}
Ok(None)
@ -101,15 +100,15 @@ impl SubscriptionTransport for WebSocketTransport {
}
}
fn handle_response(&mut self, id: usize, result: Result<serde_json::Value>) -> Option<Bytes> {
fn handle_response(&mut self, id: usize, value: serde_json::Value) -> Option<Bytes> {
if let Some(id) = self.sid_to_id.get(&id) {
Some(
serde_json::to_vec(&OperationMessage {
ty: "data".to_string(),
id: Some(id.clone()),
payload: Some(
serde_json::to_value(GQLResponse(result.map(|data| QueryResponse {
data,
serde_json::to_value(GQLResponse(Ok(QueryResponse {
data: value,
extensions: None,
})))
.unwrap(),

View File

@ -1,14 +1,13 @@
use crate::context::Environment;
use crate::{
registry, ContextBase, ContextSelectionSet, Error, OutputValueType, QueryError, Result,
SubscriptionType, Type,
registry, Context, ContextSelectionSet, Error, ObjectType, OutputValueType, QueryError, Result,
Schema, SubscriptionType, Type,
};
use graphql_parser::query::Field;
use futures::Stream;
use graphql_parser::Pos;
use serde_json::Value;
use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
/// Empty subscription
///
@ -36,17 +35,22 @@ impl SubscriptionType for EmptySubscription {
true
}
fn create_type(_field: &Field, _types: &mut HashMap<TypeId, Field>) -> Result<()> {
unreachable!()
}
async fn resolve(
fn create_field_stream<Query, Mutation>(
&self,
_ctx: &ContextBase<'_, ()>,
_types: &HashMap<TypeId, Field, RandomState>,
_msg: &(dyn Any + Send + Sync),
) -> Result<Option<Value>> {
unreachable!()
_ctx: &Context<'_>,
_schema: &Schema<Query, Mutation, Self>,
_environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value>>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Self: Send + Sync + 'static + Sized,
{
Err(Error::Query {
pos: Pos::default(),
path: None,
err: QueryError::NotConfiguredSubscriptions,
})
}
}

View File

@ -1,3 +1,5 @@
//! Input value validators
mod int_validators;
mod list_validators;
mod string_validators;