diff --git a/Cargo.toml b/Cargo.toml index 45796e4f..0a53fc18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql" -version = "1.6.3" +version = "1.6.4" authors = ["sunli "] edition = "2018" description = "The GraphQL server library implemented by rust" @@ -15,10 +15,10 @@ readme = "README.md" [features] default = ["bson", "uuid", "url", "chrono-tz", "validators"] -validators = ["regex", "once_cell"] +validators = ["regex"] [dependencies] -async-graphql-derive = { path = "async-graphql-derive", version = "1.6.3" } +async-graphql-derive = { path = "async-graphql-derive", version = "1.6.4" } graphql-parser = "0.2.3" anyhow = "1.0.26" thiserror = "1.0.11" @@ -34,7 +34,8 @@ byteorder = "1.3.4" futures = "0.3.0" parking_lot = "0.10.0" chrono = "0.4.10" -once_cell = { version = "1.3.1", optional = true } +slab = "0.4.2" +once_cell = "1.3.1" regex = { version = "1.3.5", optional = true } bson = { version = "0.14.1", optional = true } uuid = { version = "0.8.1", optional = true } @@ -45,7 +46,6 @@ chrono-tz = { version = "0.5.1", optional = true } async-std = { version = "1.5.0", features = ["attributes"] } actix-web = "2.0.0" actix-rt = "1.0.0" -slab = "0.4.2" tide = "0.6.0" mime = "0.3.16" diff --git a/async-graphql-actix-web/Cargo.toml b/async-graphql-actix-web/Cargo.toml index f1f3d2cd..e8d468be 100644 --- a/async-graphql-actix-web/Cargo.toml +++ b/async-graphql-actix-web/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-actix-web" -version = "0.6.6" +version = "0.6.7" authors = ["sunli "] edition = "2018" description = "async-graphql for actix-web" @@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"] categories = ["network-programming", "asynchronous"] [dependencies] -async-graphql = { path = "..", version = "1.6.3" } +async-graphql = { path = "..", version = "1.6.4" } actix-web = "2.0.0" actix-multipart = "0.2.0" actix-web-actors = "2.0.0" diff --git a/async-graphql-actix-web/examples/subscription.rs b/async-graphql-actix-web/examples/subscription.rs index 903eb459..ce52449d 100644 --- a/async-graphql-actix-web/examples/subscription.rs +++ b/async-graphql-actix-web/examples/subscription.rs @@ -1,6 +1,5 @@ use actix_web::{web, App, HttpServer}; -use async_graphql::{Context, Result, Schema, ID}; -use async_graphql_actix_web::publish_message; +use async_graphql::{publish, Context, Result, Schema, ID}; use futures::lock::Mutex; use slab::Slab; use std::sync::Arc; @@ -58,10 +57,11 @@ impl MutationRoot { author, }; entry.insert(book); - publish_message(BookChanged { + publish(BookChanged { mutation_type: MutationType::Created, id: id.clone(), - }); + }) + .await; id } @@ -71,10 +71,11 @@ impl MutationRoot { let id = id.parse::()?; if books.contains(id) { books.remove(id); - publish_message(BookChanged { + publish(BookChanged { mutation_type: MutationType::Deleted, id: id.into(), - }); + }) + .await; Ok(true) } else { Ok(false) @@ -122,8 +123,9 @@ impl SubscriptionRoot { #[actix_rt::main] async fn main() -> std::io::Result<()> { HttpServer::new(move || { - let schema = - Schema::new(QueryRoot, MutationRoot, SubscriptionRoot).data(Storage::default()); + let schema = Schema::build(QueryRoot, MutationRoot, SubscriptionRoot) + .data(Storage::default()) + .finish(); let handler = async_graphql_actix_web::HandlerBuilder::new(schema) .enable_ui("http://localhost:8000", Some("ws://localhost:8000")) .enable_subscription() diff --git a/async-graphql-actix-web/src/lib.rs b/async-graphql-actix-web/src/lib.rs index 0b55acbf..997b6879 100644 --- a/async-graphql-actix-web/src/lib.rs +++ b/async-graphql-actix-web/src/lib.rs @@ -2,12 +2,6 @@ #![warn(missing_docs)] -#[macro_use] -extern crate serde_derive; -#[macro_use] -extern crate actix_derive; - -mod pubsub; mod session; use crate::session::WsSession; @@ -24,9 +18,8 @@ use mime::Mime; use std::collections::HashMap; use std::future::Future; use std::pin::Pin; -use std::sync::Arc; -pub use pubsub::publish_message; +// pub use pubsub::publish_message; /// Actix-web handler builder pub struct HandlerBuilder { @@ -99,9 +92,9 @@ where HttpRequest, Payload, ) -> Pin>>> - + 'static - + Clone { - let schema = Arc::new(self.schema); + + Clone + + 'static { + let schema = self.schema.clone(); let max_file_size = self.max_file_size; let max_file_count = self.max_file_count; let enable_ui = self.enable_ui; @@ -156,9 +149,9 @@ async fn handle_request( mut payload: Payload, ) -> actix_web::Result where - Query: ObjectType + Send + Sync, - Mutation: ObjectType + Send + Sync, - Subscription: SubscriptionType + Send + Sync, + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, { if let Ok(ct) = get_content_type(req.headers()) { if ct.essence_str() == mime::MULTIPART_FORM_DATA { diff --git a/async-graphql-actix-web/src/pubsub.rs b/async-graphql-actix-web/src/pubsub.rs deleted file mode 100644 index 6836480a..00000000 --- a/async-graphql-actix-web/src/pubsub.rs +++ /dev/null @@ -1,86 +0,0 @@ -use actix::{Actor, Context, Handler, Recipient, Supervised, SystemService}; -use async_graphql::Result; -use slab::Slab; -use std::any::Any; -use std::sync::Arc; - -#[derive(Message)] -#[rtype(result = "std::result::Result<(), ()>")] -pub struct PushMessage(pub Arc); - -#[derive(Message)] -#[rtype(result = "usize")] -struct NewClient { - recipient: Recipient, -} - -#[derive(Message)] -#[rtype(result = "()")] -struct RemoveClient { - id: usize, -} - -#[derive(Message)] -#[rtype(result = "()")] -struct PubMessage(Arc); - -struct ClientInfo { - recipient: Recipient, -} - -#[derive(Default)] -struct PubSubService { - clients: Slab, -} - -impl Actor for PubSubService { - type Context = Context; -} - -impl Handler for PubSubService { - type Result = usize; - - fn handle(&mut self, msg: NewClient, _ctx: &mut Context) -> Self::Result { - self.clients.insert(ClientInfo { - recipient: msg.recipient, - }) - } -} - -impl Handler for PubSubService { - type Result = (); - - fn handle(&mut self, msg: RemoveClient, _ctx: &mut Context) -> Self::Result { - self.clients.remove(msg.id); - } -} - -impl Handler for PubSubService { - type Result = (); - - fn handle(&mut self, msg: PubMessage, _ctx: &mut Context) -> Self::Result { - for (_, client) in &self.clients { - client.recipient.do_send(PushMessage(msg.0.clone())).ok(); - } - } -} - -impl Supervised for PubSubService {} - -impl SystemService for PubSubService {} - -pub async fn new_client(recipient: Recipient) -> Result { - let id = PubSubService::from_registry() - .send(NewClient { recipient }) - .await?; - Ok(id) -} - -pub fn remove_client(id: usize) { - PubSubService::from_registry().do_send(RemoveClient { id }); -} - -/// Publish a message that will be pushed to all subscribed clients. -pub fn publish_message(msg: T) { - PubSubService::from_registry().do_send(PubMessage(Arc::new(msg))); -} diff --git a/async-graphql-actix-web/src/session.rs b/async-graphql-actix-web/src/session.rs index 8ad61559..e786156a 100644 --- a/async-graphql-actix-web/src/session.rs +++ b/async-graphql-actix-web/src/session.rs @@ -1,28 +1,17 @@ -use crate::pubsub::{new_client, remove_client, PushMessage}; use actix::{ - Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, Handler, - ResponseActFuture, Running, StreamHandler, WrapFuture, + Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, }; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; -use async_graphql::http::{GQLError, GQLRequest, GQLResponse}; -use async_graphql::{ObjectType, QueryResult, Schema, Subscribe, SubscriptionType, Variables}; -use std::collections::HashMap; -use std::sync::Arc; +use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; +use bytes::Bytes; +use futures::channel::mpsc; +use futures::SinkExt; use std::time::{Duration, Instant}; -#[derive(Serialize, Deserialize)] -struct OperationMessage { - #[serde(rename = "type")] - ty: String, - id: Option, - payload: Option, -} - pub struct WsSession { - schema: Arc>, + schema: Schema, hb: Instant, - client_id: usize, - subscribes: HashMap>, + sink: Option>, } impl WsSession @@ -31,12 +20,11 @@ where Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - pub fn new(schema: Arc>) -> Self { + pub fn new(schema: Schema) -> Self { Self { schema, hb: Instant::now(), - client_id: 0, - subscribes: Default::default(), + sink: None, } } @@ -59,19 +47,20 @@ where fn started(&mut self, ctx: &mut Self::Context) { self.hb(ctx); - - new_client(ctx.address().recipient()) - .into_actor(self) - .then(|client_id, actor, _| { - actor.client_id = client_id.unwrap(); - async {}.into_actor(actor) - }) - .wait(ctx); - } - - fn stopping(&mut self, _ctx: &mut Self::Context) -> Running { - remove_client(self.client_id); - Running::Stop + let schema = self.schema.clone(); + async move { + schema + .clone() + .subscription_connection(WebSocketTransport::default()) + .await + } + .into_actor(self) + .then(|(sink, stream), actor, ctx| { + actor.sink = Some(sink); + ctx.add_stream(stream); + async {}.into_actor(actor) + }) + .wait(ctx); } } @@ -100,67 +89,11 @@ where self.hb = Instant::now(); } Message::Text(s) => { - if let Ok(msg) = serde_json::from_str::(&s) { - match msg.ty.as_str() { - "connection_init" => { - ctx.text( - serde_json::to_string(&OperationMessage { - ty: "connection_ack".to_string(), - id: None, - payload: None, - }) - .unwrap(), - ); - } - "start" => { - if let (Some(id), Some(payload)) = (msg.id, msg.payload) { - if let Ok(request) = serde_json::from_value::(payload) { - let builder = self.schema.subscribe(&request.query); - let builder = if let Some(variables) = request.variables { - match Variables::parse_from_json(variables) { - Ok(variables) => builder.variables(variables), - Err(_) => builder, - } - } else { - builder - }; - let builder = - if let Some(operation_name) = &request.operation_name { - builder.operator_name(&operation_name) - } else { - builder - }; - let subscribe = match builder.execute() { - Ok(subscribe) => subscribe, - Err(err) => { - ctx.text( - serde_json::to_string(&OperationMessage { - ty: "error".to_string(), - id: Some(id), - payload: Some( - serde_json::to_value(GQLError(&err)) - .unwrap(), - ), - }) - .unwrap(), - ); - return; - } - }; - self.subscribes.insert(id, Arc::new(subscribe)); - } - } - } - "stop" => { - if let Some(id) = msg.id { - self.subscribes.remove(&id); - } - } - "connection_terminate" => { - ctx.stop(); - } - _ => {} - } + if let Some(mut sink) = self.sink.clone() { + async move { sink.send(s.into()).await } + .into_actor(self) + .then(|_, actor, _| async {}.into_actor(actor)) + .wait(ctx); } } Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => { @@ -171,52 +104,14 @@ where } } -impl Handler +impl StreamHandler for WsSession where Query: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static, { - type Result = ResponseActFuture>; - - fn handle(&mut self, msg: PushMessage, _ctx: &mut Self::Context) -> Self::Result { - let subscribes = self.subscribes.clone(); - let schema = self.schema.clone(); - Box::new( - async move { - let mut push_msgs = Vec::new(); - for (id, subscribe) in subscribes { - let res = match subscribe.resolve(&schema, msg.0.as_ref()).await { - Ok(Some(value)) => Some(Ok(value)), - Ok(None) => None, - Err(err) => Some(Err(err)), - }; - if let Some(res) = res { - let push_msg = serde_json::to_string(&OperationMessage { - ty: "data".to_string(), - id: Some(id.clone()), - payload: Some( - serde_json::to_value(GQLResponse(res.map(|data| QueryResult { - data, - extensions: None, - }))) - .unwrap(), - ), - }) - .unwrap(); - push_msgs.push(push_msg); - } - } - push_msgs - } - .into_actor(self) - .map(|msgs, _, ctx| { - for msg in msgs { - ctx.text(msg); - } - Ok(()) - }), - ) + fn handle(&mut self, data: Bytes, ctx: &mut Self::Context) { + ctx.text(unsafe { std::str::from_utf8_unchecked(&data) }); } } diff --git a/async-graphql-derive/Cargo.toml b/async-graphql-derive/Cargo.toml index bfb9e2da..20c19ba4 100644 --- a/async-graphql-derive/Cargo.toml +++ b/async-graphql-derive/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "async-graphql-derive" -version = "1.6.3" +version = "1.6.4" authors = ["sunli "] edition = "2018" description = "Macros for async-graphql" diff --git a/examples/actix-web.rs b/examples/actix-web.rs index 9adba5d0..a52ef7dd 100644 --- a/examples/actix-web.rs +++ b/examples/actix-web.rs @@ -27,9 +27,10 @@ async fn main() -> std::io::Result<()> { HttpServer::new(move || { App::new() .data( - Schema::new(starwars::QueryRoot, EmptyMutation, EmptySubscription) + Schema::build(starwars::QueryRoot, EmptyMutation, EmptySubscription) .data(starwars::StarWars::new()) - .extension(|| async_graphql::extensions::ApolloTracing::default()), + .extension(|| async_graphql::extensions::ApolloTracing::default()) + .finish(), ) .service(web::resource("/").guard(guard::Post()).to(index)) .service(web::resource("/").guard(guard::Get()).to(gql_playgound)) diff --git a/examples/tide.rs b/examples/tide.rs index 23e6995c..6ea166ae 100644 --- a/examples/tide.rs +++ b/examples/tide.rs @@ -28,8 +28,9 @@ async fn gql_graphiql(_request: Request) -> Response { #[async_std::main] async fn main() -> std::io::Result<()> { let mut app = tide::with_state( - Schema::new(starwars::QueryRoot, EmptyMutation, EmptySubscription) - .data(starwars::StarWars::new()), + Schema::build(starwars::QueryRoot, EmptyMutation, EmptySubscription) + .data(starwars::StarWars::new()) + .finish(), ); app.at("/").post(index); app.at("/").get(gql_playground); diff --git a/src/http/mod.rs b/src/http/mod.rs index 3bc04d5f..7e75e18f 100644 --- a/src/http/mod.rs +++ b/src/http/mod.rs @@ -35,9 +35,9 @@ impl GQLRequest { schema: &Schema, ) -> GQLResponse where - Query: ObjectType + Send + Sync, - Mutation: ObjectType + Send + Sync, - Subscription: SubscriptionType + Send + Sync, + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, { match self.prepare(schema) { Ok(query) => GQLResponse(query.execute().await), @@ -51,9 +51,9 @@ impl GQLRequest { schema: &'a Schema, ) -> Result> where - Query: ObjectType + Send + Sync, - Mutation: ObjectType + Send + Sync, - Subscription: SubscriptionType + Send + Sync, + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, { let vars = match self.variables.take() { Some(value) => match Variables::parse_from_json(value) { diff --git a/src/lib.rs b/src/lib.rs index cae16096..665f0ecb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -104,8 +104,11 @@ pub use graphql_parser::query::Value; pub use query::{PreparedQuery, QueryBuilder, QueryResult}; pub use registry::CacheControl; pub use scalars::ID; -pub use schema::Schema; -pub use subscription::SubscribeBuilder; +pub use schema::{publish, Schema}; +pub use subscription::{ + SubscriptionStream, SubscriptionStub, SubscriptionStubs, SubscriptionTransport, + WebSocketTransport, +}; pub use types::{ Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation, Upload, @@ -129,7 +132,7 @@ pub use context::ContextBase; #[doc(hidden)] pub use resolver::{collect_fields, do_resolve}; #[doc(hidden)] -pub use subscription::{Subscribe, SubscriptionType}; +pub use subscription::SubscriptionType; #[doc(hidden)] pub use types::{EnumItem, EnumType}; @@ -481,7 +484,7 @@ pub use async_graphql_derive::InputObject; /// /// #[async_std::main] /// async fn main() { -/// let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription).data("hello".to_string()); +/// let schema = Schema::build(QueryRoot, EmptyMutation, EmptySubscription).data("hello".to_string()).finish(); /// let res = schema.query(r#" /// { /// typeA { diff --git a/src/query.rs b/src/query.rs index e61b0b3c..f461871c 100644 --- a/src/query.rs +++ b/src/query.rs @@ -58,16 +58,16 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr cache_control, complexity, depth, - } = check_rules(&self.schema.registry, &document)?; + } = check_rules(&self.schema.0.registry, &document)?; self.extensions.iter().for_each(|e| e.validation_end()); - if let Some(limit_complexity) = self.schema.complexity { + if let Some(limit_complexity) = self.schema.0.complexity { if complexity > limit_complexity { return Err(QueryError::TooComplex.into()); } } - if let Some(limit_depth) = self.schema.depth { + if let Some(limit_depth) = self.schema.0.depth { if depth > limit_depth { return Err(QueryError::TooDeep.into()); } @@ -83,14 +83,14 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr Definition::Operation(operation_definition) => match operation_definition { OperationDefinition::SelectionSet(s) => { selection_set = Some(s); - root = Some(Root::Query(&self.schema.query)); + root = Some(Root::Query(&self.schema.0.query)); } OperationDefinition::Query(query) if query.name.is_none() || query.name.as_deref() == self.operation_name => { selection_set = Some(query.selection_set); variable_definitions = Some(query.variable_definitions); - root = Some(Root::Query(&self.schema.query)); + root = Some(Root::Query(&self.schema.0.query)); } OperationDefinition::Mutation(mutation) if mutation.name.is_none() @@ -98,7 +98,7 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr { selection_set = Some(mutation.selection_set); variable_definitions = Some(mutation.variable_definitions); - root = Some(Root::Mutation(&self.schema.mutation)); + root = Some(Root::Mutation(&self.schema.0.mutation)); } OperationDefinition::Subscription(subscription) if subscription.name.is_none() @@ -116,7 +116,7 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr Ok(PreparedQuery { extensions: self.extensions, - registry: &self.schema.registry, + registry: &self.schema.0.registry, variables: self.variables.unwrap_or_default(), data: self.data, fragments, diff --git a/src/schema.rs b/src/schema.rs index 67c8cb82..a2149935 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -3,13 +3,33 @@ use crate::extensions::{BoxExtension, Extension}; use crate::model::__DirectiveLocation; use crate::query::QueryBuilder; use crate::registry::{Directive, InputValue, Registry}; +use crate::subscription::{create_connection, SubscriptionStub, SubscriptionTransport}; use crate::types::QueryRoot; -use crate::{ObjectType, SubscribeBuilder, SubscriptionType, Type}; -use std::any::Any; +use crate::validation::check_rules; +use crate::{ + ContextSelectionSet, ObjectType, QueryError, QueryParseError, Result, SubscriptionStream, + SubscriptionType, Type, Variables, +}; +use bytes::Bytes; +use futures::channel::mpsc; +use futures::lock::Mutex; +use futures::SinkExt; +use graphql_parser::parse_query; +use graphql_parser::query::{ + Definition, Field, FragmentDefinition, OperationDefinition, Selection, +}; +use once_cell::sync::Lazy; +use slab::Slab; +use std::any::{Any, TypeId}; use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; -/// GraphQL schema -pub struct Schema { +type MsgSender = mpsc::Sender>; + +pub(crate) static SUBSCRIPTION_SENDERS: Lazy>> = Lazy::new(Default::default); + +pub(crate) struct SchemaInner { pub(crate) query: QueryRoot, pub(crate) mutation: Mutation, pub(crate) subscription: Subscription, @@ -20,15 +40,80 @@ pub struct Schema { pub(crate) extensions: Vec BoxExtension + Send + Sync>>, } +/// Schema builder +pub struct SchemaBuilder(SchemaInner); + impl - Schema + SchemaBuilder { - /// Create a schema + /// Disable introspection query + pub fn disable_introspection(mut self) -> Self { + self.0.query.disable_introspection = true; + self + } + + /// Set limit complexity, Default no limit. + pub fn limit_complexity(mut self, complexity: usize) -> Self { + self.0.complexity = Some(complexity); + self + } + + /// Set limit complexity, Default no limit. + pub fn limit_depth(mut self, depth: usize) -> Self { + self.0.depth = Some(depth); + self + } + + /// Add an extension + pub fn extension E + Send + Sync + 'static, E: Extension>( + mut self, + extension_factory: F, + ) -> Self { + self.0 + .extensions + .push(Box::new(move || Box::new(extension_factory()))); + self + } + + /// Add a global data that can be accessed in the `Context`. + pub fn data(mut self, data: D) -> Self { + self.0.data.insert(data); + self + } + + /// Build schema. + pub fn finish(self) -> Schema { + Schema(Arc::new(self.0)) + } +} + +/// GraphQL schema +pub struct Schema( + pub(crate) Arc>, +); + +impl Clone for Schema { + fn clone(&self) -> Self { + Schema(self.0.clone()) + } +} + +impl Schema +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, +{ + /// Create a schema builder /// /// The root object for the query and Mutation needs to be specified. /// If there is no mutation, you can use `EmptyMutation`. /// If there is no subscription, you can use `EmptySubscription`. - pub fn new(query: Query, mutation: Mutation, subscription: Subscription) -> Self { + pub fn build( + query: Query, + mutation: Mutation, + subscription: Subscription, + ) -> SchemaBuilder { let mut registry = Registry { types: Default::default(), directives: Default::default(), @@ -102,7 +187,7 @@ impl Subscription::create_type_info(&mut registry); } - Self { + SchemaBuilder(SchemaInner { query: QueryRoot { inner: query, disable_introspection: false, @@ -114,64 +199,161 @@ impl complexity: None, depth: None, extensions: Default::default(), - } + }) } - /// Disable introspection query - pub fn disable_introspection(mut self) -> Self { - self.query.disable_introspection = true; - self - } - - /// Set limit complexity, Default no limit. - pub fn limit_complexity(mut self, complexity: usize) -> Self { - self.complexity = Some(complexity); - self - } - - /// Set limit complexity, Default no limit. - pub fn limit_depth(mut self, depth: usize) -> Self { - self.depth = Some(depth); - self - } - - /// Add an extension - pub fn extension E + Send + Sync + 'static, E: Extension>( - mut self, - extension_factory: F, - ) -> Self { - self.extensions - .push(Box::new(move || Box::new(extension_factory()))); - self - } - - /// Add a global data that can be accessed in the `Context`. - pub fn data(mut self, data: D) -> Self { - self.data.insert(data); - self + /// Create a schema + pub fn new( + query: Query, + mutation: Mutation, + subscription: Subscription, + ) -> Schema { + Self::build(query, mutation, subscription).finish() } /// Start a query and return `QueryBuilder`. pub fn query<'a>(&'a self, source: &'a str) -> QueryBuilder<'a, Query, Mutation, Subscription> { QueryBuilder { - extensions: self.extensions.iter().map(|factory| factory()).collect(), + extensions: self.0.extensions.iter().map(|factory| factory()).collect(), schema: self, source, operation_name: None, variables: None, - data: &self.data, + data: &self.0.data, } } - /// Start a subscribe and return `SubscribeBuilder`. - pub fn subscribe<'a>(&'a self, source: &'a str) -> SubscribeBuilder<'a, Subscription> { - SubscribeBuilder { - extensions: Default::default(), - subscription: &self.subscription, - registry: &self.registry, - source, - operation_name: None, - variables: None, + /// Create subscription stub, typically called inside the `SubscriptionTransport::handle_request` method/ + pub fn create_subscription_stub( + &self, + source: &str, + operation_name: Option<&str>, + variables: Variables, + ) -> Result> + where + Self: Sized, + { + let document = parse_query(source).map_err(|err| QueryParseError(err.to_string()))?; + check_rules(&self.0.registry, &document)?; + + let mut fragments = HashMap::new(); + let mut subscription = None; + + for definition in document.definitions { + match definition { + Definition::Operation(OperationDefinition::Subscription(s)) => { + if s.name.as_deref() == operation_name { + subscription = Some(s); + break; + } + } + Definition::Fragment(fragment) => { + fragments.insert(fragment.name.clone(), fragment); + } + _ => {} + } } + + let subscription = subscription.ok_or(if let Some(name) = operation_name { + QueryError::UnknownOperationNamed { + name: name.to_string(), + } + } else { + QueryError::MissingOperation + })?; + + let mut types = HashMap::new(); + let resolve_id = AtomicUsize::default(); + let ctx = ContextSelectionSet { + path_node: None, + extensions: &[], + item: &subscription.selection_set, + resolve_id: &resolve_id, + variables: &variables, + variable_definitions: Some(&subscription.variable_definitions), + registry: &self.0.registry, + data: &Default::default(), + fragments: &fragments, + }; + create_subscription_types::(&ctx, &fragments, &mut types)?; + Ok(SubscriptionStub { + schema: self.clone(), + types, + variables, + variable_definitions: subscription.variable_definitions, + fragments, + }) + } + + /// Create subscription connection, returns `Sink` and `Stream`. + pub async fn subscription_connection( + &self, + transport: T, + ) -> ( + mpsc::Sender, + SubscriptionStream, + ) { + create_connection(self, transport).await + } +} + +fn create_subscription_types( + ctx: &ContextSelectionSet<'_>, + fragments: &HashMap, + types: &mut HashMap, +) -> Result<()> { + for selection in &ctx.items { + match selection { + Selection::Field(field) => { + if ctx.is_skip(&field.directives)? { + continue; + } + T::create_type(field, types)?; + } + Selection::FragmentSpread(fragment_spread) => { + if ctx.is_skip(&fragment_spread.directives)? { + continue; + } + + if let Some(fragment) = fragments.get(&fragment_spread.fragment_name) { + create_subscription_types::( + &ctx.with_selection_set(&fragment.selection_set), + fragments, + types, + )?; + } else { + return Err(QueryError::UnknownFragment { + name: fragment_spread.fragment_name.clone(), + } + .into()); + } + } + Selection::InlineFragment(inline_fragment) => { + if ctx.is_skip(&inline_fragment.directives)? { + continue; + } + create_subscription_types::( + &ctx.with_selection_set(&inline_fragment.selection_set), + fragments, + types, + )?; + } + } + } + Ok(()) +} + +/// Publish a message that will be pushed to all subscribed clients. +pub async fn publish(msg: T) { + let mut senders = SUBSCRIPTION_SENDERS.lock().await; + let msg = Arc::new(msg); + let mut remove = Vec::new(); + for (id, sender) in senders.iter_mut() { + if sender.send(msg.clone()).await.is_err() { + remove.push(id); + } + } + for id in remove { + senders.remove(id); } } diff --git a/src/subscription.rs b/src/subscription.rs deleted file mode 100644 index 4272b491..00000000 --- a/src/subscription.rs +++ /dev/null @@ -1,224 +0,0 @@ -use crate::extensions::BoxExtension; -use crate::registry::Registry; -use crate::validation::check_rules; -use crate::{ - ContextBase, ContextSelectionSet, QueryError, QueryParseError, Result, Schema, Type, Variables, -}; -use graphql_parser::parse_query; -use graphql_parser::query::{ - Definition, Field, FragmentDefinition, OperationDefinition, Selection, SelectionSet, - VariableDefinition, -}; -use std::any::{Any, TypeId}; -use std::collections::HashMap; -use std::sync::atomic::AtomicUsize; - -/// Subscribe stub -/// -/// When a new push message is generated, a JSON object that needs to be pushed can be obtained by -/// `Subscribe::resolve`, and if None is returned, the Subscribe is not subscribed to a message of this type. -pub struct Subscribe { - types: HashMap, - variables: Variables, - variable_definitions: Vec, - fragments: HashMap, -} - -impl Subscribe { - #[allow(missing_docs)] - pub async fn resolve( - &self, - schema: &Schema, - msg: &(dyn Any + Send + Sync), - ) -> Result> - where - Subscription: SubscriptionType + Sync + Send + 'static, - { - let resolve_id = AtomicUsize::default(); - let ctx = ContextBase::<()> { - path_node: None, - extensions: &[], - item: (), - resolve_id: &resolve_id, - variables: &self.variables, - variable_definitions: Some(&self.variable_definitions), - registry: &schema.registry, - data: &schema.data, - fragments: &self.fragments, - }; - schema.subscription.resolve(&ctx, &self.types, msg).await - } -} - -/// Represents a GraphQL subscription object -#[allow(missing_docs)] -#[async_trait::async_trait] -pub trait SubscriptionType: Type { - /// This function returns true of type `EmptySubscription` only - #[doc(hidden)] - fn is_empty() -> bool { - false - } - - fn create_type(field: &Field, types: &mut HashMap) -> Result<()>; - - fn create_subscribe( - &self, - extensions: &[BoxExtension], - registry: &Registry, - selection_set: SelectionSet, - variables: Variables, - variable_definitions: Vec, - fragments: HashMap, - ) -> Result - where - Self: Sized, - { - let mut types = HashMap::new(); - let resolve_id = AtomicUsize::default(); - let ctx = ContextSelectionSet { - path_node: None, - extensions, - item: &selection_set, - resolve_id: &resolve_id, - variables: &variables, - variable_definitions: Some(&variable_definitions), - registry, - data: &Default::default(), - fragments: &fragments, - }; - create_types::(&ctx, &fragments, &mut types)?; - Ok(Subscribe { - types, - variables, - variable_definitions, - fragments, - }) - } - - /// Resolve a subscription message, If no message of this type is subscribed, None is returned. - async fn resolve( - &self, - ctx: &ContextBase<'_, ()>, - types: &HashMap, - msg: &(dyn Any + Send + Sync), - ) -> Result>; -} - -fn create_types( - ctx: &ContextSelectionSet<'_>, - fragments: &HashMap, - types: &mut HashMap, -) -> Result<()> { - for selection in &ctx.items { - match selection { - Selection::Field(field) => { - if ctx.is_skip(&field.directives)? { - continue; - } - T::create_type(field, types)?; - } - Selection::FragmentSpread(fragment_spread) => { - if ctx.is_skip(&fragment_spread.directives)? { - continue; - } - - if let Some(fragment) = fragments.get(&fragment_spread.fragment_name) { - create_types::( - &ctx.with_selection_set(&fragment.selection_set), - fragments, - types, - )?; - } else { - return Err(QueryError::UnknownFragment { - name: fragment_spread.fragment_name.clone(), - } - .into()); - } - } - Selection::InlineFragment(inline_fragment) => { - if ctx.is_skip(&inline_fragment.directives)? { - continue; - } - create_types::( - &ctx.with_selection_set(&inline_fragment.selection_set), - fragments, - types, - )?; - } - } - } - Ok(()) -} - -/// Subscribe builder -pub struct SubscribeBuilder<'a, Subscription> { - pub(crate) subscription: &'a Subscription, - pub(crate) extensions: &'a [BoxExtension], - pub(crate) registry: &'a Registry, - pub(crate) source: &'a str, - pub(crate) operation_name: Option<&'a str>, - pub(crate) variables: Option, -} - -impl<'a, Subscription> SubscribeBuilder<'a, Subscription> -where - Subscription: SubscriptionType, -{ - /// Specify the operation name. - pub fn operator_name(self, name: &'a str) -> Self { - SubscribeBuilder { - operation_name: Some(name), - ..self - } - } - - /// Specify the variables. - pub fn variables(self, vars: Variables) -> Self { - SubscribeBuilder { - variables: Some(vars), - ..self - } - } - - /// Perform a subscription operation and return `Subscribe`. - pub fn execute(self) -> Result { - let document = parse_query(self.source).map_err(|err| QueryParseError(err.to_string()))?; - check_rules(self.registry, &document)?; - - let mut fragments = HashMap::new(); - let mut subscription = None; - - for definition in document.definitions { - match definition { - Definition::Operation(OperationDefinition::Subscription(s)) => { - if s.name.as_deref() == self.operation_name { - subscription = Some(s); - break; - } - } - Definition::Fragment(fragment) => { - fragments.insert(fragment.name.clone(), fragment); - } - _ => {} - } - } - - let subscription = subscription.ok_or(if let Some(name) = self.operation_name { - QueryError::UnknownOperationNamed { - name: name.to_string(), - } - } else { - QueryError::MissingOperation - })?; - - self.subscription.create_subscribe( - self.extensions, - self.registry, - subscription.selection_set, - self.variables.unwrap_or_default(), - subscription.variable_definitions, - fragments, - ) - } -} diff --git a/src/subscription/connection.rs b/src/subscription/connection.rs new file mode 100644 index 00000000..eeba103d --- /dev/null +++ b/src/subscription/connection.rs @@ -0,0 +1,180 @@ +use crate::schema::SUBSCRIPTION_SENDERS; +use crate::subscription::SubscriptionStub; +use crate::{ObjectType, Result, Schema, SubscriptionType}; +use bytes::Bytes; +use futures::channel::mpsc; +use futures::task::{Context, Poll}; +use futures::{Future, FutureExt, Stream}; +use slab::Slab; +use std::any::Any; +use std::collections::VecDeque; +use std::pin::Pin; +use std::sync::Arc; + +/// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection` +pub struct SubscriptionStubs( + Slab>, +); + +impl Default for SubscriptionStubs { + fn default() -> Self { + Self(Slab::new()) + } +} + +#[allow(missing_docs)] +impl SubscriptionStubs { + pub fn add(&mut self, stub: SubscriptionStub) -> usize { + self.0.insert(stub) + } + + pub fn remove(&mut self, id: usize) { + self.0.remove(id); + } +} + +/// Subscription transport +/// +/// You can customize your transport by implementing this trait. +pub trait SubscriptionTransport: Send + Sync + Unpin + 'static { + /// Parse the request data here. + /// If you have a new request, create a `SubscriptionStub` with the `Schema::create_subscription_stub`, and then call `SubscriptionStubs::add`. + /// You can return a `Byte`, which will be sent to the client. If it returns an error, the connection will be broken. + fn handle_request( + &mut self, + schema: &Schema, + stubs: &mut SubscriptionStubs, + data: Bytes, + ) -> Result> + where + Query: ObjectType + Sync + Send + 'static, + Mutation: ObjectType + Sync + Send + 'static, + Subscription: SubscriptionType + Sync + Send + 'static; + + /// When a response message is generated, you can convert the message to the format you want here. + fn handle_response(&mut self, id: usize, result: Result) -> Option; +} + +pub async fn create_connection( + schema: &Schema, + transport: T, +) -> ( + mpsc::Sender, + SubscriptionStream, +) +where + Query: ObjectType + Sync + Send + 'static, + Mutation: ObjectType + Sync + Send + 'static, + Subscription: SubscriptionType + Sync + Send + 'static, +{ + let (tx_bytes, rx_bytes) = mpsc::channel(8); + let (tx_msg, rx_msg) = mpsc::channel(8); + let mut senders = SUBSCRIPTION_SENDERS.lock().await; + senders.insert(tx_msg); + ( + tx_bytes.clone(), + SubscriptionStream { + schema: schema.clone(), + transport, + stubs: Default::default(), + rx_bytes, + rx_msg, + send_queue: VecDeque::new(), + resolve_queue: VecDeque::default(), + resolve_fut: None, + }, + ) +} + +#[allow(missing_docs)] +pub struct SubscriptionStream { + schema: Schema, + transport: T, + stubs: SubscriptionStubs, + rx_bytes: mpsc::Receiver, + rx_msg: mpsc::Receiver>, + send_queue: VecDeque, + resolve_queue: VecDeque>, + resolve_fut: Option>>>, +} + +impl Stream + for SubscriptionStream +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, + T: SubscriptionTransport, +{ + type Item = Bytes; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + loop { + // send bytes + if let Some(bytes) = self.send_queue.pop_front() { + println!("{}", String::from_utf8(bytes.to_vec()).unwrap()); + return Poll::Ready(Some(bytes)); + } + + // receive bytes + match Pin::new(&mut self.rx_bytes).poll_next(cx) { + Poll::Ready(Some(data)) => { + let this = &mut *self; + match this + .transport + .handle_request(&this.schema, &mut this.stubs, data) + { + Ok(Some(bytes)) => { + this.send_queue.push_back(bytes); + continue; + } + Ok(None) => {} + Err(_) => return Poll::Ready(None), + } + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => {} + } + + if let Some(resolve_fut) = &mut self.resolve_fut { + match resolve_fut.poll_unpin(cx) { + Poll::Ready(_) => { + self.resolve_fut = None; + } + Poll::Pending => return Poll::Pending, + } + } else if let Some(msg) = self.resolve_queue.pop_front() { + // FIXME: I think this code is safe, but I don't know how to implement it in safe code. + let this = &mut *self; + let stubs = &this.stubs as *const SubscriptionStubs; + let transport = &mut this.transport as *mut T; + let send_queue = &mut this.send_queue as *mut VecDeque; + let fut = async move { + unsafe { + for (id, stub) in (*stubs).0.iter() { + if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() { + if let Some(bytes) = (*transport).handle_response(id, res) { + (*send_queue).push_back(bytes); + } + } + } + } + }; + self.resolve_fut = Some(Box::pin(fut)); + continue; + } + + // receive msg + match Pin::new(&mut self.rx_msg).poll_next(cx) { + Poll::Ready(Some(msg)) => { + self.resolve_queue.push_back(msg); + } + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => { + // all pending + return Poll::Pending; + } + } + } + } +} diff --git a/src/subscription/mod.rs b/src/subscription/mod.rs new file mode 100644 index 00000000..a2861bab --- /dev/null +++ b/src/subscription/mod.rs @@ -0,0 +1,11 @@ +mod connection; +mod subscribe_stub; +mod subscription_type; +mod ws_transport; + +pub use connection::{ + create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport, +}; +pub use subscribe_stub::SubscriptionStub; +pub use subscription_type::SubscriptionType; +pub use ws_transport::WebSocketTransport; diff --git a/src/subscription/subscribe_stub.rs b/src/subscription/subscribe_stub.rs new file mode 100644 index 00000000..df409fa8 --- /dev/null +++ b/src/subscription/subscribe_stub.rs @@ -0,0 +1,48 @@ +use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables}; +use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition}; +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::sync::atomic::AtomicUsize; + +/// Subscription stub +/// +/// When a new push message is generated, a JSON object that needs to be pushed can be obtained by +/// `Subscribe::resolve`, and if None is returned, the Subscribe is not subscribed to a message of this type. +pub struct SubscriptionStub { + pub(crate) schema: Schema, + pub(crate) types: HashMap, + pub(crate) variables: Variables, + pub(crate) variable_definitions: Vec, + pub(crate) fragments: HashMap, +} + +impl SubscriptionStub +where + Query: ObjectType + Send + Sync + 'static, + Mutation: ObjectType + Send + Sync + 'static, + Subscription: SubscriptionType + Send + Sync + 'static, +{ + #[doc(hidden)] + pub async fn resolve( + &self, + msg: &(dyn Any + Send + Sync), + ) -> Result> { + let resolve_id = AtomicUsize::default(); + let ctx = ContextBase::<()> { + path_node: None, + extensions: &[], + item: (), + resolve_id: &resolve_id, + variables: &self.variables, + variable_definitions: Some(&self.variable_definitions), + registry: &self.schema.0.registry, + data: &self.schema.0.data, + fragments: &self.fragments, + }; + self.schema + .0 + .subscription + .resolve(&ctx, &self.types, msg) + .await + } +} diff --git a/src/subscription/subscription_type.rs b/src/subscription/subscription_type.rs new file mode 100644 index 00000000..8168136c --- /dev/null +++ b/src/subscription/subscription_type.rs @@ -0,0 +1,25 @@ +use crate::{ContextBase, Result, Type}; +use graphql_parser::query::Field; +use std::any::{Any, TypeId}; +use std::collections::HashMap; + +/// Represents a GraphQL subscription object +#[async_trait::async_trait] +pub trait SubscriptionType: Type { + /// This function returns true of type `EmptySubscription` only + #[doc(hidden)] + fn is_empty() -> bool { + false + } + + #[doc(hidden)] + fn create_type(field: &Field, types: &mut HashMap) -> Result<()>; + + /// Resolve a subscription message, If no message of this type is subscribed, None is returned. + async fn resolve( + &self, + ctx: &ContextBase<'_, ()>, + types: &HashMap, + msg: &(dyn Any + Send + Sync), + ) -> Result>; +} diff --git a/src/subscription/ws_transport.rs b/src/subscription/ws_transport.rs new file mode 100644 index 00000000..ddd953a0 --- /dev/null +++ b/src/subscription/ws_transport.rs @@ -0,0 +1,126 @@ +use crate::http::{GQLError, GQLRequest, GQLResponse}; +use crate::{ + ObjectType, QueryResult, Result, Schema, SubscriptionStubs, SubscriptionTransport, + SubscriptionType, Variables, +}; +use bytes::Bytes; +use std::collections::HashMap; + +#[derive(Serialize, Deserialize)] +struct OperationMessage { + #[serde(rename = "type")] + ty: String, + id: Option, + payload: Option, +} + +/// WebSocket transport +#[derive(Default)] +pub struct WebSocketTransport { + id_to_sid: HashMap, + sid_to_id: HashMap, +} + +impl SubscriptionTransport for WebSocketTransport { + fn handle_request( + &mut self, + schema: &Schema, + stubs: &mut SubscriptionStubs, + data: Bytes, + ) -> Result> + where + Query: ObjectType + Sync + Send + 'static, + Mutation: ObjectType + Sync + Send + 'static, + Subscription: SubscriptionType + Sync + Send + 'static, + { + match serde_json::from_slice::(&data) { + Ok(msg) => match msg.ty.as_str() { + "connection_init" => Ok(Some( + serde_json::to_vec(&OperationMessage { + ty: "connection_ack".to_string(), + id: None, + payload: None, + }) + .unwrap() + .into(), + )), + "start" => { + if let (Some(id), Some(payload)) = (msg.id, msg.payload) { + if let Ok(request) = serde_json::from_value::(payload) { + let variables = if let Some(value) = request.variables { + match Variables::parse_from_json(value) { + Ok(variables) => variables, + Err(_) => Default::default(), + } + } else { + Default::default() + }; + + match schema.create_subscription_stub( + &request.query, + request.operation_name.as_deref(), + variables, + ) { + Ok(stub) => { + let stub_id = stubs.add(stub); + self.id_to_sid.insert(id.clone(), stub_id); + self.sid_to_id.insert(stub_id, id); + Ok(None) + } + Err(err) => Ok(Some( + serde_json::to_vec(&OperationMessage { + ty: "error".to_string(), + id: Some(id), + payload: Some( + serde_json::to_value(GQLError(&err)).unwrap(), + ), + }) + .unwrap() + .into(), + )), + } + } else { + Ok(None) + } + } else { + Ok(None) + } + } + "stop" => { + if let Some(id) = msg.id { + if let Some(id) = self.id_to_sid.remove(&id) { + self.sid_to_id.remove(&id); + stubs.remove(id); + } + } + Ok(None) + } + "connection_terminate" => Err(anyhow::anyhow!("connection_terminate")), + _ => Err(anyhow::anyhow!("unknown op")), + }, + Err(err) => Err(err.into()), + } + } + + fn handle_response(&mut self, id: usize, result: Result) -> Option { + if let Some(id) = self.sid_to_id.get(&id) { + Some( + serde_json::to_vec(&OperationMessage { + ty: "data".to_string(), + id: Some(id.clone()), + payload: Some( + serde_json::to_value(GQLResponse(result.map(|data| QueryResult { + data, + extensions: None, + }))) + .unwrap(), + ), + }) + .unwrap() + .into(), + ) + } else { + None + } + } +}