Add SimpleBroker

This commit is contained in:
sunli 2020-04-07 14:30:46 +08:00
parent 7b85062c6e
commit 5c710ff744
12 changed files with 299 additions and 100 deletions

View File

@ -1,36 +1,135 @@
use actix::clock::Duration;
use actix_web::{web, App, HttpServer};
use async_graphql::{EmptyMutation, Schema};
use async_graphql::{Context, FieldResult, Schema, SimpleBroker, ID};
use futures::lock::Mutex;
use futures::{Stream, StreamExt};
use slab::Slab;
use std::sync::Arc;
#[derive(Clone)]
struct Book {
id: ID,
name: String,
author: String,
}
#[async_graphql::Object]
impl Book {
#[field]
async fn id(&self) -> &str {
&self.id
}
#[field]
async fn name(&self) -> &str {
&self.name
}
#[field]
async fn author(&self) -> &str {
&self.author
}
}
type Storage = Arc<Mutex<Slab<Book>>>;
struct QueryRoot;
#[async_graphql::Object]
impl QueryRoot {
#[field]
async fn value(&self) -> i32 {
0
async fn books(&self, ctx: &Context<'_>) -> Vec<Book> {
let books = ctx.data::<Storage>().lock().await;
books.iter().map(|(_, book)| book).cloned().collect()
}
}
struct MutationRoot;
#[async_graphql::Object]
impl MutationRoot {
#[field]
async fn create_book(&self, ctx: &Context<'_>, name: String, author: String) -> ID {
let mut books = ctx.data::<Storage>().lock().await;
let entry = books.vacant_entry();
let id: ID = entry.key().into();
let book = Book {
id: id.clone(),
name,
author,
};
entry.insert(book);
SimpleBroker::publish(BookChanged {
mutation_type: MutationType::Created,
id: id.clone(),
});
id
}
#[field]
async fn delete_book(&self, ctx: &Context<'_>, id: ID) -> FieldResult<bool> {
let mut books = ctx.data::<Storage>().lock().await;
let id = id.parse::<usize>()?;
if books.contains(id) {
books.remove(id);
SimpleBroker::publish(BookChanged {
mutation_type: MutationType::Deleted,
id: id.into(),
});
Ok(true)
} else {
Ok(false)
}
}
}
#[async_graphql::Enum]
enum MutationType {
Created,
Deleted,
}
#[async_graphql::SimpleObject]
#[derive(Clone)]
struct BookChanged {
#[field]
mutation_type: MutationType,
#[field]
id: ID,
}
struct SubscriptionRoot;
#[async_graphql::Subscription]
impl SubscriptionRoot {
#[field]
fn interval(&self, n: i32) -> impl Stream<Item = i32> {
async fn interval(&self, n: i32) -> impl Stream<Item = i32> {
let mut value = 0;
actix_rt::time::interval(Duration::from_secs(1)).map(move |_| {
value += n;
value
})
}
#[field]
async fn books(&self, mutation_type: Option<MutationType>) -> impl Stream<Item = BookChanged> {
if let Some(mutation_type) = mutation_type {
SimpleBroker::<BookChanged>::subscribe()
.filter(move |event| futures::future::ready(event.mutation_type == mutation_type))
.boxed()
} else {
SimpleBroker::<BookChanged>::subscribe().boxed()
}
}
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let schema = Schema::new(QueryRoot, EmptyMutation, SubscriptionRoot);
let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot)
.data(Storage::default())
.finish();
let handler = async_graphql_actix_web::HandlerBuilder::new(schema)
.enable_ui("http://localhost:8000", Some("ws://localhost:8000"))
.enable_subscription()

View File

@ -8,6 +8,9 @@ use futures::channel::mpsc;
use futures::SinkExt;
use std::time::{Duration, Instant};
const HEARTBEAT_INTERVAL: Duration = Duration::from_secs(5);
const CLIENT_TIMEOUT: Duration = Duration::from_secs(10);
pub struct WsSession<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
hb: Instant,
@ -29,10 +32,11 @@ where
}
fn hb(&self, ctx: &mut WebsocketContext<Self>) {
ctx.run_interval(Duration::new(1, 0), |act, ctx| {
if Instant::now().duration_since(act.hb) > Duration::new(10, 0) {
ctx.run_interval(HEARTBEAT_INTERVAL, |act, ctx| {
if Instant::now().duration_since(act.hb) > CLIENT_TIMEOUT {
ctx.stop();
}
ctx.ping(b"");
});
}
}

View File

@ -54,10 +54,10 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
if method.sig.asyncness.is_some() {
if method.sig.asyncness.is_none() {
return Err(Error::new_spanned(
&method.sig.asyncness,
"The subscription stream function must be synchronous",
"The subscription stream function must be asynchronous",
));
}
@ -187,7 +187,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
let schema = schema.clone();
let pos = ctx.position;
let environment = environment.clone();
let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params),*).fuse(), move |msg| {
let stream = #crate_name::futures::stream::StreamExt::then(self.#ident(#(#use_params),*).await.fuse(), move |msg| {
let environment = environment.clone();
let field_selection_set = field_selection_set.clone();
let schema = schema.clone();
@ -248,12 +248,12 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
impl #crate_name::SubscriptionType for SubscriptionRoot {
#[allow(unused_variables)]
#[allow(bare_trait_objects)]
fn create_field_stream<Query, Mutation>(
async fn create_field_stream<Query, Mutation>(
&self,
ctx: &#crate_name::Context<'_>,
schema: &#crate_name::Schema<Query, Mutation, Self>,
environment: std::sync::Arc<#crate_name::Environment>,
) -> #crate_name::Result<std::pin::Pin<Box<dyn futures::Stream<Item = #crate_name::serde_json::Value>>>>
) -> #crate_name::Result<std::pin::Pin<Box<dyn futures::Stream<Item = #crate_name::serde_json::Value> + Send>>>
where
Query: #crate_name::ObjectType + Send + Sync + 'static,
Mutation: #crate_name::ObjectType + Send + Sync + 'static,

View File

@ -112,7 +112,8 @@ pub use registry::CacheControl;
pub use scalars::ID;
pub use schema::Schema;
pub use subscription::{
SubscriptionStream, SubscriptionStreams, SubscriptionTransport, WebSocketTransport,
SimpleBroker, SubscriptionStream, SubscriptionStreams, SubscriptionTransport,
WebSocketTransport,
};
pub use types::{
Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation,

View File

@ -272,12 +272,12 @@ where
}
/// Create subscription stream, typically called inside the `SubscriptionTransport::handle_request` method
pub fn create_subscription_stream(
pub async fn create_subscription_stream(
&self,
source: &str,
operation_name: Option<&str>,
variables: Variables,
) -> Result<impl Stream<Item = serde_json::Value>> {
) -> Result<impl Stream<Item = serde_json::Value> + Send> {
let document = parse_query(source).map_err(Into::<Error>::into)?;
check_rules(&self.0.registry, &document, self.0.validation_mode)?;
@ -323,7 +323,8 @@ where
};
let mut streams = Vec::new();
create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)?;
create_subscription_stream(self, Arc::new(ctx.create_environment()), &ctx, &mut streams)
.await?;
Ok(futures::stream::select_all(streams))
}

View File

@ -4,17 +4,20 @@ use futures::channel::mpsc;
use futures::task::{Context, Poll};
use futures::Stream;
use slab::Slab;
use std::collections::VecDeque;
use std::future::Future;
use std::pin::Pin;
/// Use to hold all subscription stream for the `SubscriptionConnection`
pub struct SubscriptionStreams {
streams: Slab<Pin<Box<dyn Stream<Item = serde_json::Value>>>>,
streams: Slab<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>,
}
#[allow(missing_docs)]
impl SubscriptionStreams {
pub fn add<S: Stream<Item = serde_json::Value> + 'static>(&mut self, stream: S) -> usize {
pub fn add<S: Stream<Item = serde_json::Value> + Send + 'static>(
&mut self,
stream: S,
) -> usize {
self.streams.insert(Box::pin(stream))
}
@ -26,6 +29,7 @@ impl SubscriptionStreams {
/// Subscription transport
///
/// You can customize your transport by implementing this trait.
#[async_trait::async_trait]
pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
/// The error type.
type Error;
@ -33,7 +37,7 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
/// Parse the request data here.
/// If you have a new subscribe, create a stream with the `Schema::create_subscription_stream`, and then call `SubscriptionStreams::add`.
/// You can return a `Byte`, which will be sent to the client. If it returns an error, the connection will be broken.
fn handle_request<Query, Mutation, Subscription>(
async fn handle_request<Query, Mutation, Subscription>(
&mut self,
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
@ -70,7 +74,7 @@ where
streams: Default::default(),
},
rx_bytes,
send_queue: VecDeque::new(),
handle_request_fut: None,
},
)
}
@ -81,7 +85,9 @@ pub struct SubscriptionStream<Query, Mutation, Subscription, T: SubscriptionTran
transport: T,
streams: SubscriptionStreams,
rx_bytes: mpsc::Receiver<Bytes>,
send_queue: VecDeque<Bytes>,
handle_request_fut: Option<
Pin<Box<dyn Future<Output = std::result::Result<Option<Bytes>, T::Error>> + 'static>>,
>,
}
impl<Query, Mutation, Subscription, T> Stream
@ -95,34 +101,44 @@ where
type Item = Bytes;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
loop {
// send bytes
if let Some(bytes) = self.send_queue.pop_front() {
return Poll::Ready(Some(bytes));
}
let this = &mut *self;
loop {
// receive bytes
match Pin::new(&mut self.rx_bytes).poll_next(cx) {
Poll::Ready(Some(data)) => {
let this = &mut *self;
match this
.transport
.handle_request(&this.schema, &mut this.streams, data)
{
Ok(Some(bytes)) => {
this.send_queue.push_back(bytes);
continue;
if let Some(handle_request_fut) = &mut this.handle_request_fut {
match handle_request_fut.as_mut().poll(cx) {
Poll::Ready(Ok(bytes)) => {
this.handle_request_fut = None;
if let Some(bytes) = bytes {
return Poll::Ready(Some(bytes));
}
Ok(None) => {}
Err(_) => return Poll::Ready(None),
continue;
}
Poll::Ready(Err(_)) => return Poll::Ready(None),
Poll::Pending => {}
}
} else {
match Pin::new(&mut this.rx_bytes).poll_next(cx) {
Poll::Ready(Some(data)) => {
// The following code I think is safe.😁
let transport = &mut this.transport as *mut T;
let schema = &this.schema as *const Schema<Query, Mutation, Subscription>;
let streams = &mut this.streams as *mut SubscriptionStreams;
unsafe {
this.handle_request_fut = Some(Box::pin((*transport).handle_request(
&*schema,
&mut *streams,
data,
)));
}
continue;
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Pending => {}
}
// receive msg
let this = &mut *self;
if !this.streams.streams.is_empty() {
loop {
let mut num_closed = 0;
@ -132,7 +148,7 @@ where
match incoming_stream.as_mut().poll_next(cx) {
Poll::Ready(Some(value)) => {
if let Some(bytes) = this.transport.handle_response(id, value) {
this.send_queue.push_back(bytes);
return Poll::Ready(Some(bytes));
}
}
Poll::Ready(None) => {

View File

@ -1,9 +1,11 @@
mod connection;
mod simple_broker;
mod subscription_type;
mod ws_transport;
pub use connection::{
create_connection, SubscriptionStream, SubscriptionStreams, SubscriptionTransport,
};
pub use simple_broker::SimpleBroker;
pub use subscription_type::{create_subscription_stream, SubscriptionType};
pub use ws_transport::WebSocketTransport;

View File

@ -0,0 +1,55 @@
use futures::channel::mpsc::{self, UnboundedReceiver, UnboundedSender};
use futures::task::{Context, Poll};
use futures::{Stream, StreamExt};
use once_cell::sync::OnceCell;
use serde::export::PhantomData;
use slab::Slab;
use std::any::Any;
use std::pin::Pin;
use std::sync::Mutex;
struct Senders<T>(Mutex<Slab<UnboundedSender<T>>>);
struct BrokerStream<T: Sync + Send + Clone + 'static>(usize, UnboundedReceiver<T>);
impl<T: Sync + Send + Clone + 'static> Drop for BrokerStream<T> {
fn drop(&mut self) {
let mut senders = SimpleBroker::<T>::senders().0.lock().unwrap();
senders.remove(self.0);
}
}
impl<T: Sync + Send + Clone + 'static> Stream for BrokerStream<T> {
type Item = T;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.1.poll_next_unpin(cx)
}
}
/// A simple broker based on memory
pub struct SimpleBroker<T>(PhantomData<T>);
impl<T: Sync + Send + Clone + 'static> SimpleBroker<T> {
fn senders() -> &'static Senders<T> {
static SUBSCRIBERS: OnceCell<Box<dyn Any + Send + Sync>> = OnceCell::new();
let instance = SUBSCRIBERS.get_or_init(|| Box::new(Senders::<T>(Mutex::new(Slab::new()))));
instance.downcast_ref::<Senders<T>>().unwrap()
}
/// Publish a message that all subscription streams can receive.
pub fn publish(msg: T) {
let mut senders = Self::senders().0.lock().unwrap();
for (_, sender) in senders.iter_mut() {
sender.start_send(msg.clone()).ok();
}
}
/// Subscribe to the message of the specified type and returns a `Stream`.
pub fn subscribe() -> impl Stream<Item = T> {
let mut senders = Self::senders().0.lock().unwrap();
let (tx, rx) = mpsc::unbounded();
let id = senders.insert(tx);
BrokerStream(id, rx)
}
}

View File

@ -1,6 +1,6 @@
use crate::context::Environment;
use crate::{Context, ContextSelectionSet, ObjectType, Result, Schema, Type};
use futures::Stream;
use futures::{Future, Stream};
use graphql_parser::query::{Selection, TypeCondition};
use std::pin::Pin;
use std::sync::Arc;
@ -15,79 +15,94 @@ pub trait SubscriptionType: Type {
}
#[doc(hidden)]
fn create_field_stream<Query, Mutation>(
async fn create_field_stream<Query, Mutation>(
&self,
ctx: &Context<'_>,
schema: &Schema<Query, Mutation, Self>,
environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value>>>>
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Self: Send + Sync + 'static + Sized;
}
pub fn create_subscription_stream<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
type BoxCreateStreamFuture<'a> = Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>;
pub fn create_subscription_stream<'a, Query, Mutation, Subscription>(
schema: &'a Schema<Query, Mutation, Subscription>,
environment: Arc<Environment>,
ctx: &ContextSelectionSet<'_>,
streams: &mut Vec<Pin<Box<dyn Stream<Item = serde_json::Value>>>>,
) -> Result<()>
ctx: &'a ContextSelectionSet<'_>,
streams: &'a mut Vec<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>,
) -> BoxCreateStreamFuture<'a>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static + Sized,
{
for selection in &ctx.items {
match selection {
Selection::Field(field) => {
if ctx.is_skip(&field.directives)? {
continue;
}
streams.push(schema.0.subscription.create_field_stream(
&ctx.with_field(field),
schema,
environment.clone(),
)?)
}
Selection::FragmentSpread(fragment_spread) => {
if ctx.is_skip(&fragment_spread.directives)? {
continue;
Box::pin(async move {
for selection in &ctx.items {
match selection {
Selection::Field(field) => {
if ctx.is_skip(&field.directives)? {
continue;
}
streams.push(
schema
.0
.subscription
.create_field_stream(
&ctx.with_field(field),
schema,
environment.clone(),
)
.await?,
)
}
Selection::FragmentSpread(fragment_spread) => {
if ctx.is_skip(&fragment_spread.directives)? {
continue;
}
if let Some(fragment) = ctx.fragments.get(fragment_spread.fragment_name.as_str()) {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&fragment.selection_set),
streams,
)?;
}
}
Selection::InlineFragment(inline_fragment) => {
if ctx.is_skip(&inline_fragment.directives)? {
continue;
if let Some(fragment) =
ctx.fragments.get(fragment_spread.fragment_name.as_str())
{
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&fragment.selection_set),
streams,
)
.await?;
}
}
Selection::InlineFragment(inline_fragment) => {
if ctx.is_skip(&inline_fragment.directives)? {
continue;
}
if let Some(TypeCondition::On(name)) = &inline_fragment.type_condition {
if name.as_str() == Subscription::type_name() {
if let Some(TypeCondition::On(name)) = &inline_fragment.type_condition {
if name.as_str() == Subscription::type_name() {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.selection_set),
streams,
)
.await?;
}
} else {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.selection_set),
streams,
)?;
)
.await?;
}
} else {
create_subscription_stream(
schema,
environment.clone(),
&ctx.with_selection_set(&inline_fragment.selection_set),
streams,
)?;
}
}
}
}
Ok(())
Ok(())
})
}

View File

@ -21,10 +21,11 @@ pub struct WebSocketTransport {
sid_to_id: HashMap<usize, String>,
}
#[async_trait::async_trait]
impl SubscriptionTransport for WebSocketTransport {
type Error = String;
fn handle_request<Query, Mutation, Subscription>(
async fn handle_request<Query, Mutation, Subscription>(
&mut self,
schema: &Schema<Query, Mutation, Subscription>,
streams: &mut SubscriptionStreams,
@ -54,11 +55,14 @@ impl SubscriptionTransport for WebSocketTransport {
.map(|value| Variables::parse_from_json(value).ok())
.flatten()
.unwrap_or_default();
match schema.create_subscription_stream(
&request.query,
request.operation_name.as_deref(),
variables,
) {
match schema
.create_subscription_stream(
&request.query,
request.operation_name.as_deref(),
variables,
)
.await
{
Ok(stream) => {
let stream_id = streams.add(stream);
self.id_to_sid.insert(id.clone(), stream_id);

View File

@ -35,12 +35,12 @@ impl SubscriptionType for EmptySubscription {
true
}
fn create_field_stream<Query, Mutation>(
async fn create_field_stream<Query, Mutation>(
&self,
_ctx: &Context<'_>,
_schema: &Schema<Query, Mutation, Self>,
_environment: Arc<Environment>,
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value>>>>
) -> Result<Pin<Box<dyn Stream<Item = serde_json::Value> + Send>>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,

View File

@ -22,12 +22,12 @@ pub async fn test_subscription() {
#[Subscription]
impl SubscriptionRoot {
#[field]
fn values(&self, start: i32, end: i32) -> impl Stream<Item = i32> {
async fn values(&self, start: i32, end: i32) -> impl Stream<Item = i32> {
futures::stream::iter(start..end)
}
#[field]
fn events(&self, start: i32, end: i32) -> impl Stream<Item = Event> {
async fn events(&self, start: i32, end: i32) -> impl Stream<Item = Event> {
futures::stream::iter((start..end).map(|n| Event { a: n, b: n * 10 }))
}
}
@ -41,6 +41,7 @@ pub async fn test_subscription() {
None,
Default::default(),
)
.await
.unwrap();
for i in 10..20 {
assert_eq!(
@ -58,6 +59,7 @@ pub async fn test_subscription() {
None,
Default::default(),
)
.await
.unwrap();
for i in 10..20 {
assert_eq!(