Add context data support
This commit is contained in:
sunli 2020-03-31 11:19:18 +08:00
parent 9f7a890c0c
commit a13204273e
16 changed files with 286 additions and 63 deletions

View File

@ -1,6 +1,6 @@
[package] [package]
name = "async-graphql" name = "async-graphql"
version = "1.6.7" version = "1.6.8"
authors = ["sunli <scott_s829@163.com>"] authors = ["sunli <scott_s829@163.com>"]
edition = "2018" edition = "2018"
description = "The GraphQL server library implemented by rust" description = "The GraphQL server library implemented by rust"
@ -18,7 +18,7 @@ default = ["bson", "uuid", "url", "chrono-tz", "validators"]
validators = ["regex"] validators = ["regex"]
[dependencies] [dependencies]
async-graphql-derive = { path = "async-graphql-derive", version = "1.6.7" } async-graphql-derive = { path = "async-graphql-derive", version = "1.6.8" }
graphql-parser = "=0.2.3" graphql-parser = "=0.2.3"
anyhow = "1.0.26" anyhow = "1.0.26"
thiserror = "1.0.11" thiserror = "1.0.11"
@ -26,7 +26,6 @@ async-trait = "0.1.24"
serde = "1.0.104" serde = "1.0.104"
serde_derive = "1.0.104" serde_derive = "1.0.104"
serde_json = "1.0.48" serde_json = "1.0.48"
fnv = "1.0.6"
bytes = "0.5.4" bytes = "0.5.4"
Inflector = "0.11.4" Inflector = "0.11.4"
base64 = "0.12.0" base64 = "0.12.0"

View File

@ -1,6 +1,6 @@
[package] [package]
name = "async-graphql-actix-web" name = "async-graphql-actix-web"
version = "0.6.10" version = "0.6.11"
authors = ["sunli <scott_s829@163.com>"] authors = ["sunli <scott_s829@163.com>"]
edition = "2018" edition = "2018"
description = "async-graphql for actix-web" description = "async-graphql for actix-web"
@ -13,7 +13,7 @@ keywords = ["futures", "async", "graphql"]
categories = ["network-programming", "asynchronous"] categories = ["network-programming", "asynchronous"]
[dependencies] [dependencies]
async-graphql = { path = "..", version = "1.6.7" } async-graphql = { path = "..", version = "1.6.8" }
actix-web = "2.0.0" actix-web = "2.0.0"
actix-multipart = "0.2.0" actix-multipart = "0.2.0"
actix-web-actors = "2.0.0" actix-web-actors = "2.0.0"

View File

@ -0,0 +1,36 @@
use actix_web::{web, App, HttpServer};
use async_graphql::{Context, EmptyMutation, EmptySubscription, Schema};
struct MyToken(Option<String>);
struct QueryRoot;
#[async_graphql::Object]
impl QueryRoot {
#[field]
async fn current_token<'a>(&self, ctx: &'a Context<'_>) -> Option<&'a str> {
ctx.data::<MyToken>().0.as_deref()
}
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let schema = Schema::new(QueryRoot, EmptyMutation, EmptySubscription);
let handler = async_graphql_actix_web::HandlerBuilder::new(schema)
.enable_subscription()
.enable_ui("http://localhost:8000", None)
.on_request(|req, builder| {
builder.data(MyToken(
req.headers()
.get("Token")
.and_then(|value| value.to_str().map(ToString::to_string).ok()),
))
})
.build();
App::new().service(web::resource("/").to(handler))
})
.bind("127.0.0.1:8000")?
.run()
.await
}

View File

@ -11,15 +11,31 @@ use actix_web::web::{BytesMut, Payload};
use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder}; use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder};
use actix_web_actors::ws; use actix_web_actors::ws;
use async_graphql::http::{GQLRequest, GQLResponse}; use async_graphql::http::{GQLRequest, GQLResponse};
use async_graphql::{ObjectType, Schema, SubscriptionType}; use async_graphql::{
ObjectType, QueryBuilder, Schema, SubscriptionConnectionBuilder, SubscriptionType,
WebSocketTransport,
};
use bytes::Bytes; use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
use mime::Mime; use mime::Mime;
use std::collections::HashMap; use std::collections::HashMap;
use std::future::Future; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc;
// pub use pubsub::publish_message; type BoxOnRequestFn<Query, Mutation, Subscription> = Arc<
dyn for<'a> Fn(
&HttpRequest,
QueryBuilder<'a, Query, Mutation, Subscription>,
) -> QueryBuilder<'a, Query, Mutation, Subscription>,
>;
type BoxOnConnectFn<Query, Mutation, Subscription> = Arc<
dyn Fn(
&HttpRequest,
SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
) -> SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
>;
/// Actix-web handler builder /// Actix-web handler builder
pub struct HandlerBuilder<Query, Mutation, Subscription> { pub struct HandlerBuilder<Query, Mutation, Subscription> {
@ -28,6 +44,8 @@ pub struct HandlerBuilder<Query, Mutation, Subscription> {
max_file_count: usize, max_file_count: usize,
enable_subscription: bool, enable_subscription: bool,
enable_ui: Option<(String, Option<String>)>, enable_ui: Option<(String, Option<String>)>,
on_request: Option<BoxOnRequestFn<Query, Mutation, Subscription>>,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
} }
impl<Query, Mutation, Subscription> HandlerBuilder<Query, Mutation, Subscription> impl<Query, Mutation, Subscription> HandlerBuilder<Query, Mutation, Subscription>
@ -44,6 +62,8 @@ where
max_file_count: 9, max_file_count: 9,
enable_subscription: false, enable_subscription: false,
enable_ui: None, enable_ui: None,
on_request: None,
on_connect: None,
} }
} }
@ -85,6 +105,41 @@ where
} }
} }
/// When a new request arrives, you can use this closure to append your own data to the `QueryBuilder`.
pub fn on_request<
F: for<'a> Fn(
&HttpRequest,
QueryBuilder<'a, Query, Mutation, Subscription>,
) -> QueryBuilder<'a, Query, Mutation, Subscription>
+ 'static,
>(
self,
f: F,
) -> Self {
Self {
on_request: Some(Arc::new(f)),
..self
}
}
/// When there is a new subscription connection, you can use this closure to append your own data to the `SubscriptionConnectionBuilder`.
pub fn on_connect<
F: Fn(
&HttpRequest,
SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>,
)
-> SubscriptionConnectionBuilder<Query, Mutation, Subscription, WebSocketTransport>
+ 'static,
>(
self,
f: F,
) -> Self {
Self {
on_connect: Some(Arc::new(f)),
..self
}
}
/// Create an HTTP handler. /// Create an HTTP handler.
pub fn build( pub fn build(
self, self,
@ -99,10 +154,14 @@ where
let max_file_count = self.max_file_count; let max_file_count = self.max_file_count;
let enable_ui = self.enable_ui; let enable_ui = self.enable_ui;
let enable_subscription = self.enable_subscription; let enable_subscription = self.enable_subscription;
let on_request = self.on_request;
let on_connect = self.on_connect;
move |req: HttpRequest, payload: Payload| { move |req: HttpRequest, payload: Payload| {
let schema = schema.clone(); let schema = schema.clone();
let enable_ui = enable_ui.clone(); let enable_ui = enable_ui.clone();
let on_request = on_request.clone();
let on_connect = on_connect.clone();
Box::pin(async move { Box::pin(async move {
if req.method() == Method::GET { if req.method() == Method::GET {
@ -111,7 +170,11 @@ where
if let Ok(s) = s.to_str() { if let Ok(s) = s.to_str() {
if s.to_ascii_lowercase().contains("websocket") { if s.to_ascii_lowercase().contains("websocket") {
return ws::start_with_protocols( return ws::start_with_protocols(
WsSession::new(schema.clone()), WsSession::new(
schema.clone(),
req.clone(),
on_connect.clone(),
),
&["graphql-ws"], &["graphql-ws"],
&req, &req,
payload, payload,
@ -132,7 +195,15 @@ where
} }
if req.method() == Method::POST { if req.method() == Method::POST {
handle_request(&schema, max_file_size, max_file_count, req, payload).await handle_request(
&schema,
max_file_size,
max_file_count,
req,
payload,
on_request.as_ref(),
)
.await
} else { } else {
Ok(HttpResponse::MethodNotAllowed().finish()) Ok(HttpResponse::MethodNotAllowed().finish())
} }
@ -147,6 +218,7 @@ async fn handle_request<Query, Mutation, Subscription>(
max_file_count: usize, max_file_count: usize,
req: HttpRequest, req: HttpRequest,
mut payload: Payload, mut payload: Payload,
on_request: Option<&BoxOnRequestFn<Query, Mutation, Subscription>>,
) -> actix_web::Result<HttpResponse> ) -> actix_web::Result<HttpResponse>
where where
Query: ObjectType + Send + Sync + 'static, Query: ObjectType + Send + Sync + 'static,
@ -171,7 +243,15 @@ where
.map_err(actix_web::error::ErrorBadRequest)? .map_err(actix_web::error::ErrorBadRequest)?
}; };
let mut prepared = match gql_request.prepare(schema) { let mut builder = gql_request
.builder(schema)
.map_err(actix_web::error::ErrorBadRequest)?;
if let Some(on_request) = on_request {
builder = on_request(&req, builder);
}
let mut prepared = match builder.prepare() {
Ok(prepared) => prepared, Ok(prepared) => prepared,
Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?), Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?),
}; };
@ -243,7 +323,13 @@ where
let mut gql_req = web::Json::<GQLRequest>::from_request(&req, &mut payload.0) let mut gql_req = web::Json::<GQLRequest>::from_request(&req, &mut payload.0)
.await? .await?
.into_inner(); .into_inner();
let prepared = match gql_req.prepare(schema) { let mut builder = gql_req
.builder(schema)
.map_err(actix_web::error::ErrorBadRequest)?;
if let Some(on_request) = on_request {
builder = on_request(&req, builder);
}
let prepared = match builder.prepare() {
Ok(prepared) => prepared, Ok(prepared) => prepared,
Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?), Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?),
}; };

View File

@ -1,6 +1,8 @@
use crate::BoxOnConnectFn;
use actix::{ use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture, Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, StreamHandler, WrapFuture,
}; };
use actix_web::HttpRequest;
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext}; use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport}; use async_graphql::{ObjectType, Schema, SubscriptionType, WebSocketTransport};
use bytes::Bytes; use bytes::Bytes;
@ -9,9 +11,11 @@ use futures::SinkExt;
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
pub struct WsSession<Query, Mutation, Subscription> { pub struct WsSession<Query, Mutation, Subscription> {
req: HttpRequest,
schema: Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
hb: Instant, hb: Instant,
sink: Option<mpsc::Sender<Bytes>>, sink: Option<mpsc::Sender<Bytes>>,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
} }
impl<Query, Mutation, Subscription> WsSession<Query, Mutation, Subscription> impl<Query, Mutation, Subscription> WsSession<Query, Mutation, Subscription>
@ -20,11 +24,17 @@ where
Mutation: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static, Subscription: SubscriptionType + Send + Sync + 'static,
{ {
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self { pub fn new(
schema: Schema<Query, Mutation, Subscription>,
req: HttpRequest,
on_connect: Option<BoxOnConnectFn<Query, Mutation, Subscription>>,
) -> Self {
Self { Self {
req,
schema, schema,
hb: Instant::now(), hb: Instant::now(),
sink: None, sink: None,
on_connect,
} }
} }
@ -48,11 +58,16 @@ where
fn started(&mut self, ctx: &mut Self::Context) { fn started(&mut self, ctx: &mut Self::Context) {
self.hb(ctx); self.hb(ctx);
let schema = self.schema.clone(); let schema = self.schema.clone();
let on_connect = self.on_connect.clone();
let req = self.req.clone();
async move { async move {
schema let mut builder = schema
.clone() .clone()
.subscription_connection(WebSocketTransport::default()) .subscription_connection(WebSocketTransport::default());
.await if let Some(on_connect) = on_connect {
builder = on_connect(&req, builder);
}
builder.build().await
} }
.into_actor(self) .into_actor(self)
.then(|(sink, stream), actor, ctx| { .then(|(sink, stream), actor, ctx| {

View File

@ -1,6 +1,6 @@
[package] [package]
name = "async-graphql-derive" name = "async-graphql-derive"
version = "1.6.7" version = "1.6.8"
authors = ["sunli <scott_s829@163.com>"] authors = ["sunli <scott_s829@163.com>"]
edition = "2018" edition = "2018"
description = "Macros for async-graphql" description = "Macros for async-graphql"

View File

@ -2,13 +2,11 @@ use crate::extensions::BoxExtension;
use crate::registry::Registry; use crate::registry::Registry;
use crate::{ErrorWithPosition, InputValueType, QueryError, Result, Type}; use crate::{ErrorWithPosition, InputValueType, QueryError, Result, Type};
use bytes::Bytes; use bytes::Bytes;
use fnv::FnvHasher;
use graphql_parser::query::{ use graphql_parser::query::{
Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition, Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition,
}; };
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::collections::{BTreeMap, HashMap}; use std::collections::{BTreeMap, HashMap};
use std::hash::BuildHasherDefault;
use std::ops::{Deref, DerefMut}; use std::ops::{Deref, DerefMut};
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
@ -131,7 +129,7 @@ fn json_value_to_gql_value(value: serde_json::Value) -> Value {
} }
#[derive(Default)] #[derive(Default)]
pub struct Data(HashMap<TypeId, Box<dyn Any + Sync + Send>, BuildHasherDefault<FnvHasher>>); pub struct Data(BTreeMap<TypeId, Box<dyn Any + Sync + Send>>);
impl Data { impl Data {
pub fn insert<D: Any + Send + Sync>(&mut self, data: D) { pub fn insert<D: Any + Send + Sync>(&mut self, data: D) {
@ -220,6 +218,7 @@ pub struct ContextBase<'a, T> {
pub(crate) variable_definitions: Option<&'a [VariableDefinition]>, pub(crate) variable_definitions: Option<&'a [VariableDefinition]>,
pub(crate) registry: &'a Registry, pub(crate) registry: &'a Registry,
pub(crate) data: &'a Data, pub(crate) data: &'a Data,
pub(crate) ctx_data: Option<&'a Data>,
pub(crate) fragments: &'a HashMap<String, FragmentDefinition>, pub(crate) fragments: &'a HashMap<String, FragmentDefinition>,
} }
@ -257,6 +256,7 @@ impl<'a, T> ContextBase<'a, T> {
variable_definitions: self.variable_definitions, variable_definitions: self.variable_definitions,
registry: self.registry, registry: self.registry,
data: self.data, data: self.data,
ctx_data: self.ctx_data,
fragments: self.fragments, fragments: self.fragments,
} }
} }
@ -275,15 +275,16 @@ impl<'a, T> ContextBase<'a, T> {
variable_definitions: self.variable_definitions, variable_definitions: self.variable_definitions,
registry: self.registry, registry: self.registry,
data: self.data, data: self.data,
ctx_data: self.ctx_data,
fragments: self.fragments, fragments: self.fragments,
} }
} }
/// Gets the global data defined in the `Schema`. /// Gets the global data defined in the `Context` or `Schema`.
pub fn data<D: Any + Send + Sync>(&self) -> &D { pub fn data<D: Any + Send + Sync>(&self) -> &D {
self.data self.ctx_data
.0 .and_then(|ctx_data| ctx_data.0.get(&TypeId::of::<D>()))
.get(&TypeId::of::<D>()) .or_else(|| self.data.0.get(&TypeId::of::<D>()))
.and_then(|d| d.downcast_ref::<D>()) .and_then(|d| d.downcast_ref::<D>())
.expect("The specified data type does not exist.") .expect("The specified data type does not exist.")
} }
@ -413,6 +414,7 @@ impl<'a> ContextBase<'a, &'a SelectionSet> {
variable_definitions: self.variable_definitions, variable_definitions: self.variable_definitions,
registry: self.registry, registry: self.registry,
data: self.data, data: self.data,
ctx_data: self.ctx_data,
fragments: self.fragments, fragments: self.fragments,
} }
} }

View File

@ -1,12 +1,10 @@
use crate::extensions::{Extension, ResolveInfo}; use crate::extensions::{Extension, ResolveInfo};
use crate::QueryPathSegment; use crate::QueryPathSegment;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use fnv::FnvHasher;
use parking_lot::Mutex; use parking_lot::Mutex;
use serde::ser::SerializeMap; use serde::ser::SerializeMap;
use serde::{Serialize, Serializer}; use serde::{Serialize, Serializer};
use std::collections::HashMap; use std::collections::BTreeMap;
use std::hash::BuildHasherDefault;
use std::ops::Deref; use std::ops::Deref;
struct PendingResolve { struct PendingResolve {
@ -50,7 +48,7 @@ impl Serialize for ResolveStat {
struct Inner { struct Inner {
start_time: DateTime<Utc>, start_time: DateTime<Utc>,
end_time: DateTime<Utc>, end_time: DateTime<Utc>,
pending_resolves: HashMap<usize, PendingResolve, BuildHasherDefault<FnvHasher>>, pending_resolves: BTreeMap<usize, PendingResolve>,
resolves: Vec<ResolveStat>, resolves: Vec<ResolveStat>,
} }

View File

@ -8,7 +8,10 @@ pub use playground_source::playground_source;
use crate::error::{ExtendedError, RuleError, RuleErrors}; use crate::error::{ExtendedError, RuleError, RuleErrors};
use crate::query::PreparedQuery; use crate::query::PreparedQuery;
use crate::{ObjectType, PositionError, QueryResult, Result, Schema, SubscriptionType, Variables}; use crate::{
ObjectType, PositionError, QueryBuilder, QueryResult, Result, Schema, SubscriptionType,
Variables,
};
use graphql_parser::Pos; use graphql_parser::Pos;
use serde::ser::{SerializeMap, SerializeSeq}; use serde::ser::{SerializeMap, SerializeSeq};
use serde::{Serialize, Serializer}; use serde::{Serialize, Serializer};
@ -45,11 +48,11 @@ impl GQLRequest {
} }
} }
/// Prepare a query and return a `PreparedQuery` object that gets some information about the query. /// Create query builder
pub fn prepare<'a, Query, Mutation, Subscription>( pub fn builder<'a, Query, Mutation, Subscription>(
&'a mut self, &'a mut self,
schema: &'a Schema<Query, Mutation, Subscription>, schema: &'a Schema<Query, Mutation, Subscription>,
) -> Result<PreparedQuery<'a, Query, Mutation>> ) -> Result<QueryBuilder<'a, Query, Mutation, Subscription>>
where where
Query: ObjectType + Send + Sync + 'static, Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static, Mutation: ObjectType + Send + Sync + 'static,
@ -71,7 +74,20 @@ impl GQLRequest {
Some(name) => query.operator_name(name), Some(name) => query.operator_name(name),
None => query, None => query,
}; };
query.prepare() Ok(query)
}
/// Prepare a query and return a `PreparedQuery` object that gets some information about the query.
pub fn prepare<'a, Query, Mutation, Subscription>(
&'a mut self,
schema: &'a Schema<Query, Mutation, Subscription>,
) -> Result<PreparedQuery<'a, Query, Mutation>>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
self.builder(schema)?.prepare()
} }
} }

View File

@ -39,6 +39,7 @@
//! * Custom extension //! * Custom extension
//! * Apollo Tracing extension //! * Apollo Tracing extension
//! * Limit query complexity/depth //! * Limit query complexity/depth
//! * Error Extensions
//! //!
//! ## Integrations //! ## Integrations
//! //!
@ -108,8 +109,8 @@ pub use registry::CacheControl;
pub use scalars::ID; pub use scalars::ID;
pub use schema::{publish, Schema}; pub use schema::{publish, Schema};
pub use subscription::{ pub use subscription::{
SubscriptionStream, SubscriptionStub, SubscriptionStubs, SubscriptionTransport, SubscriptionConnectionBuilder, SubscriptionStream, SubscriptionStub, SubscriptionStubs,
WebSocketTransport, SubscriptionTransport, WebSocketTransport,
}; };
pub use types::{ pub use types::{
Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation, Connection, DataSource, EmptyEdgeFields, EmptyMutation, EmptySubscription, QueryOperation,

View File

@ -10,6 +10,7 @@ use graphql_parser::parse_query;
use graphql_parser::query::{ use graphql_parser::query::{
Definition, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition, Definition, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition,
}; };
use std::any::Any;
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
@ -26,6 +27,7 @@ pub struct QueryBuilder<'a, Query, Mutation, Subscription> {
pub(crate) operation_name: Option<&'a str>, pub(crate) operation_name: Option<&'a str>,
pub(crate) variables: Option<Variables>, pub(crate) variables: Option<Variables>,
pub(crate) data: &'a Data, pub(crate) data: &'a Data,
pub(crate) ctx_data: Option<Data>,
} }
impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscription> { impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscription> {
@ -45,6 +47,18 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr
} }
} }
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
if let Some(ctx_data) = &mut self.ctx_data {
ctx_data.insert(data);
} else {
let mut ctx_data = Data::default();
ctx_data.insert(data);
self.ctx_data = Some(ctx_data);
}
self
}
/// Prepare query /// Prepare query
pub fn prepare(self) -> Result<PreparedQuery<'a, Query, Mutation>> { pub fn prepare(self) -> Result<PreparedQuery<'a, Query, Mutation>> {
self.extensions self.extensions
@ -119,6 +133,7 @@ impl<'a, Query, Mutation, Subscription> QueryBuilder<'a, Query, Mutation, Subscr
registry: &self.schema.0.registry, registry: &self.schema.0.registry,
variables: self.variables.unwrap_or_default(), variables: self.variables.unwrap_or_default(),
data: self.data, data: self.data,
ctx_data: self.ctx_data,
fragments, fragments,
selection_set: selection_set.ok_or({ selection_set: selection_set.ok_or({
if let Some(name) = self.operation_name { if let Some(name) = self.operation_name {
@ -161,6 +176,7 @@ pub struct PreparedQuery<'a, Query, Mutation> {
registry: &'a Registry, registry: &'a Registry,
variables: Variables, variables: Variables,
data: &'a Data, data: &'a Data,
ctx_data: Option<Data>,
fragments: HashMap<String, FragmentDefinition>, fragments: HashMap<String, FragmentDefinition>,
selection_set: SelectionSet, selection_set: SelectionSet,
variable_definitions: Option<Vec<VariableDefinition>>, variable_definitions: Option<Vec<VariableDefinition>>,
@ -210,6 +226,7 @@ impl<'a, Query, Mutation> PreparedQuery<'a, Query, Mutation> {
variable_definitions: self.variable_definitions.as_deref(), variable_definitions: self.variable_definitions.as_deref(),
registry: self.registry, registry: self.registry,
data: self.data, data: self.data,
ctx_data: self.ctx_data.as_ref(),
fragments: &self.fragments, fragments: &self.fragments,
}; };

View File

@ -3,14 +3,13 @@ use crate::extensions::{BoxExtension, Extension};
use crate::model::__DirectiveLocation; use crate::model::__DirectiveLocation;
use crate::query::QueryBuilder; use crate::query::QueryBuilder;
use crate::registry::{Directive, InputValue, Registry}; use crate::registry::{Directive, InputValue, Registry};
use crate::subscription::{create_connection, SubscriptionStub, SubscriptionTransport}; use crate::subscription::{SubscriptionConnectionBuilder, SubscriptionStub, SubscriptionTransport};
use crate::types::QueryRoot; use crate::types::QueryRoot;
use crate::validation::check_rules; use crate::validation::check_rules;
use crate::{ use crate::{
ContextSelectionSet, ObjectType, QueryError, QueryParseError, Result, SubscriptionStream, ContextSelectionSet, ObjectType, QueryError, QueryParseError, Result, SubscriptionType, Type,
SubscriptionType, Type, Variables, Variables,
}; };
use bytes::Bytes;
use futures::channel::mpsc; use futures::channel::mpsc;
use futures::lock::Mutex; use futures::lock::Mutex;
use futures::SinkExt; use futures::SinkExt;
@ -75,7 +74,7 @@ impl<Query: ObjectType, Mutation: ObjectType, Subscription: SubscriptionType>
self self
} }
/// Add a global data that can be accessed in the `Context`. /// Add a global data that can be accessed in the `Schema`, you access it with `Context::data`.
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self { pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
self.0.data.insert(data); self.0.data.insert(data);
self self
@ -220,6 +219,7 @@ where
operation_name: None, operation_name: None,
variables: None, variables: None,
data: &self.0.data, data: &self.0.data,
ctx_data: None,
} }
} }
@ -273,6 +273,7 @@ where
variable_definitions: Some(&subscription.variable_definitions), variable_definitions: Some(&subscription.variable_definitions),
registry: &self.0.registry, registry: &self.0.registry,
data: &Default::default(), data: &Default::default(),
ctx_data: None,
fragments: &fragments, fragments: &fragments,
}; };
create_subscription_types::<Subscription>(&ctx, &fragments, &mut types)?; create_subscription_types::<Subscription>(&ctx, &fragments, &mut types)?;
@ -282,18 +283,20 @@ where
variables, variables,
variable_definitions: subscription.variable_definitions, variable_definitions: subscription.variable_definitions,
fragments, fragments,
ctx_data: None,
}) })
} }
/// Create subscription connection, returns `Sink` and `Stream`. /// Create subscription connection, returns `SubscriptionConnectionBuilder`.
pub async fn subscription_connection<T: SubscriptionTransport>( pub fn subscription_connection<T: SubscriptionTransport>(
&self, &self,
transport: T, transport: T,
) -> ( ) -> SubscriptionConnectionBuilder<Query, Mutation, Subscription, T> {
mpsc::Sender<Bytes>, SubscriptionConnectionBuilder {
SubscriptionStream<Query, Mutation, Subscription, T>, schema: self.clone(),
) { transport,
create_connection(self, transport).await ctx_data: None,
}
} }
} }

View File

@ -1,3 +1,4 @@
use crate::context::Data;
use crate::schema::SUBSCRIPTION_SENDERS; use crate::schema::SUBSCRIPTION_SENDERS;
use crate::subscription::SubscriptionStub; use crate::subscription::SubscriptionStub;
use crate::{ObjectType, Result, Schema, SubscriptionType}; use crate::{ObjectType, Result, Schema, SubscriptionType};
@ -12,24 +13,20 @@ use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
/// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection` /// Subscription stubs, use to hold all subscription information for the `SubscriptionConnection`
pub struct SubscriptionStubs<Query, Mutation, Subscription>( pub struct SubscriptionStubs<Query, Mutation, Subscription> {
Slab<SubscriptionStub<Query, Mutation, Subscription>>, stubs: Slab<SubscriptionStub<Query, Mutation, Subscription>>,
); ctx_data: Option<Arc<Data>>,
impl<Query, Mutation, Subscription> Default for SubscriptionStubs<Query, Mutation, Subscription> {
fn default() -> Self {
Self(Slab::new())
}
} }
#[allow(missing_docs)] #[allow(missing_docs)]
impl<Query, Mutation, Subscription> SubscriptionStubs<Query, Mutation, Subscription> { impl<Query, Mutation, Subscription> SubscriptionStubs<Query, Mutation, Subscription> {
pub fn add(&mut self, stub: SubscriptionStub<Query, Mutation, Subscription>) -> usize { pub fn add(&mut self, mut stub: SubscriptionStub<Query, Mutation, Subscription>) -> usize {
self.0.insert(stub) stub.ctx_data = self.ctx_data.clone();
self.stubs.insert(stub)
} }
pub fn remove(&mut self, id: usize) { pub fn remove(&mut self, id: usize) {
self.0.remove(id); self.stubs.remove(id);
} }
} }
@ -56,8 +53,9 @@ pub trait SubscriptionTransport: Send + Sync + Unpin + 'static {
} }
pub async fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>( pub async fn create_connection<Query, Mutation, Subscription, T: SubscriptionTransport>(
schema: &Schema<Query, Mutation, Subscription>, schema: Schema<Query, Mutation, Subscription>,
transport: T, transport: T,
ctx_data: Option<Data>,
) -> ( ) -> (
mpsc::Sender<Bytes>, mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>, SubscriptionStream<Query, Mutation, Subscription, T>,
@ -74,9 +72,12 @@ where
( (
tx_bytes.clone(), tx_bytes.clone(),
SubscriptionStream { SubscriptionStream {
schema: schema.clone(), schema,
transport, transport,
stubs: Default::default(), stubs: SubscriptionStubs {
stubs: Default::default(),
ctx_data: ctx_data.map(Arc::new),
},
rx_bytes, rx_bytes,
rx_msg, rx_msg,
send_queue: VecDeque::new(), send_queue: VecDeque::new(),
@ -151,7 +152,7 @@ where
let send_queue = &mut this.send_queue as *mut VecDeque<Bytes>; let send_queue = &mut this.send_queue as *mut VecDeque<Bytes>;
let fut = async move { let fut = async move {
unsafe { unsafe {
for (id, stub) in (*stubs).0.iter() { for (id, stub) in (*stubs).stubs.iter() {
if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() { if let Some(res) = stub.resolve(msg.as_ref()).await.transpose() {
if let Some(bytes) = (*transport).handle_response(id, res) { if let Some(bytes) = (*transport).handle_response(id, res) {
(*send_queue).push_back(bytes); (*send_queue).push_back(bytes);

View File

@ -0,0 +1,43 @@
use crate::context::Data;
use crate::subscription::create_connection;
use crate::{ObjectType, Schema, SubscriptionStream, SubscriptionTransport, SubscriptionType};
use bytes::Bytes;
use futures::channel::mpsc;
use std::any::Any;
/// SubscriptionConnection builder
pub struct SubscriptionConnectionBuilder<Query, Mutation, Subscription, T: SubscriptionTransport> {
pub(crate) schema: Schema<Query, Mutation, Subscription>,
pub(crate) transport: T,
pub(crate) ctx_data: Option<Data>,
}
impl<Query, Mutation, Subscription, T: SubscriptionTransport>
SubscriptionConnectionBuilder<Query, Mutation, Subscription, T>
where
Query: ObjectType + Send + Sync + 'static,
Mutation: ObjectType + Send + Sync + 'static,
Subscription: SubscriptionType + Send + Sync + 'static,
{
/// Add a context data that can be accessed in the `Context`, you access it with `Context::data`.
pub fn data<D: Any + Send + Sync>(mut self, data: D) -> Self {
if let Some(ctx_data) = &mut self.ctx_data {
ctx_data.insert(data);
} else {
let mut ctx_data = Data::default();
ctx_data.insert(data);
self.ctx_data = Some(ctx_data);
}
self
}
/// Create subscription connection, returns `Sink` and `Stream`.
pub async fn build(
self,
) -> (
mpsc::Sender<Bytes>,
SubscriptionStream<Query, Mutation, Subscription, T>,
) {
create_connection(self.schema, self.transport, self.ctx_data).await
}
}

View File

@ -1,11 +1,13 @@
mod connection; mod connection;
mod subscribe_stub; mod connection_builder;
mod subscription_stub;
mod subscription_type; mod subscription_type;
mod ws_transport; mod ws_transport;
pub use connection::{ pub use connection::{
create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport, create_connection, SubscriptionStream, SubscriptionStubs, SubscriptionTransport,
}; };
pub use subscribe_stub::SubscriptionStub; pub use connection_builder::SubscriptionConnectionBuilder;
pub use subscription_stub::SubscriptionStub;
pub use subscription_type::SubscriptionType; pub use subscription_type::SubscriptionType;
pub use ws_transport::WebSocketTransport; pub use ws_transport::WebSocketTransport;

View File

@ -1,8 +1,10 @@
use crate::context::Data;
use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables}; use crate::{ContextBase, ObjectType, Result, Schema, SubscriptionType, Variables};
use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition}; use graphql_parser::query::{Field, FragmentDefinition, VariableDefinition};
use std::any::{Any, TypeId}; use std::any::{Any, TypeId};
use std::collections::HashMap; use std::collections::HashMap;
use std::sync::atomic::AtomicUsize; use std::sync::atomic::AtomicUsize;
use std::sync::Arc;
/// Subscription stub /// Subscription stub
/// ///
@ -14,6 +16,7 @@ pub struct SubscriptionStub<Query, Mutation, Subscription> {
pub(crate) variables: Variables, pub(crate) variables: Variables,
pub(crate) variable_definitions: Vec<VariableDefinition>, pub(crate) variable_definitions: Vec<VariableDefinition>,
pub(crate) fragments: HashMap<String, FragmentDefinition>, pub(crate) fragments: HashMap<String, FragmentDefinition>,
pub(crate) ctx_data: Option<Arc<Data>>,
} }
impl<Query, Mutation, Subscription> SubscriptionStub<Query, Mutation, Subscription> impl<Query, Mutation, Subscription> SubscriptionStub<Query, Mutation, Subscription>
@ -37,6 +40,7 @@ where
variable_definitions: Some(&self.variable_definitions), variable_definitions: Some(&self.variable_definitions),
registry: &self.schema.0.registry, registry: &self.schema.0.registry,
data: &self.schema.0.data, data: &self.schema.0.data,
ctx_data: self.ctx_data.as_deref(),
fragments: &self.fragments, fragments: &self.fragments,
}; };
self.schema self.schema