Support subscription

This commit is contained in:
sunli 2020-03-17 17:26:59 +08:00
parent 8ce7365686
commit e6cfaf134e
38 changed files with 1490 additions and 339 deletions

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql"
version = "1.1.1"
version = "1.2.0"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "The GraphQL server library implemented by rust"
@ -17,7 +17,7 @@ readme = "README.md"
default = ["chrono", "uuid"]
[dependencies]
async-graphql-derive = { path = "async-graphql-derive", version = "1.0.0" }
async-graphql-derive = { path = "async-graphql-derive", version = "1.2.0" }
graphql-parser = "0.2.3"
anyhow = "1.0.26"
thiserror = "1.0.11"
@ -26,6 +26,7 @@ serde = "1.0.104"
serde_derive = "1.0.104"
serde_json = "1.0.48"
fnv = "1.0.6"
bytes = "0.5.4"
chrono = { version = "0.4.10", optional = true }
uuid = { version = "0.8.1", optional = true }

View File

@ -46,6 +46,7 @@ Open `http://localhost:8000` in browser
* Custom scalar.
* Minimal overhead.
* Easy integration (hyper, actix_web, tide ...).
* Upload files && Subscription (async-graphql-actix-web)
## Goals
@ -93,6 +94,9 @@ Open `http://localhost:8000` in browser
- [X] Schema
- [X] Multipart Request (https://github.com/jaydenseric/graphql-multipart-request-spec)
- [X] Actix-web
- [X] Subscription
- [X] Filter
- [X] WebSocket transport
- [X] Validation rules
- [X] ArgumentsOfCorrectType
- [X] DefaultValuesOfCorrectType

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql-actix-web"
version = "0.1.0"
version = "0.2.0"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "The GraphQL server library implemented by rust"
@ -13,12 +13,19 @@ keywords = ["futures", "async", "graphql"]
categories = ["network-programming", "asynchronous"]
[dependencies]
async-graphql = { path = "..", version = "1.1.0" }
async-graphql = { path = "..", version = "1.2.0" }
actix-web = "2.0.0"
actix-multipart = "0.2.0"
actix-web-actors = "2.0.0"
actix = "0.9.0"
futures = "0.3.0"
serde_json = "1.0.48"
mime = "0.3.16"
bytes = "0.5.4"
serde = "1.0.104"
serde_derive = "1.0.104"
serde_json = "1.0.48"
slab = "0.4.2"
actix_derive = "0.5.0"
[dev-dependencies]
actix-rt = "1.0.0"

View File

@ -0,0 +1,132 @@
use actix_web::{web, App, HttpServer};
use async_graphql::{Context, Result, Schema, ID};
use async_graphql_actix_web::publish_message;
use futures::lock::Mutex;
use slab::Slab;
use std::sync::Arc;
#[derive(Clone)]
struct Book {
id: ID,
name: String,
author: String,
}
#[async_graphql::Object]
impl Book {
#[field]
async fn id(&self) -> &str {
&self.id
}
#[field]
async fn name(&self) -> &str {
&self.name
}
#[field]
async fn author(&self) -> &str {
&self.author
}
}
type Storage = Arc<Mutex<Slab<Book>>>;
struct QueryRoot;
#[async_graphql::Object]
impl QueryRoot {
#[field]
async fn books(&self, ctx: &Context<'_>) -> Vec<Book> {
let books = ctx.data::<Storage>().lock().await;
books.iter().map(|(_, book)| book).cloned().collect()
}
}
struct MutationRoot;
#[async_graphql::Object]
impl MutationRoot {
#[field]
async fn create_book(&self, ctx: &Context<'_>, name: String, author: String) -> ID {
let mut books = ctx.data::<Storage>().lock().await;
let entry = books.vacant_entry();
let id: ID = entry.key().into();
entry.insert(Book {
id: id.clone(),
name,
author,
});
publish_message(BookChanged {
mutation_type: MutationType::Created,
id: id.clone(),
});
id
}
#[field]
async fn delete_book(&self, ctx: &Context<'_>, id: ID) -> Result<bool> {
let mut books = ctx.data::<Storage>().lock().await;
let id = id.parse::<usize>()?;
if books.contains(id) {
books.remove(id);
publish_message(BookChanged {
mutation_type: MutationType::Deleted,
id: id.into(),
});
Ok(true)
} else {
Ok(false)
}
}
}
#[async_graphql::Enum]
enum MutationType {
Created,
Deleted,
}
struct BookChanged {
mutation_type: MutationType,
id: ID,
}
#[async_graphql::Object]
impl BookChanged {
#[field]
async fn mutation_type(&self) -> &MutationType {
&self.mutation_type
}
#[field]
async fn id(&self) -> &ID {
&self.id
}
}
struct SubscriptionRoot;
#[async_graphql::Subscription]
impl SubscriptionRoot {
#[field]
fn books(&self, changed: &BookChanged, name: Option<String>) -> bool {
true
}
}
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let schema =
Schema::new(QueryRoot, MutationRoot, SubscriptionRoot).data(Storage::default());
let handler = async_graphql_actix_web::HandlerBuilder::new(schema)
.enable_ui("http://localhost:8000", Some("ws://localhost:8000"))
.enable_subscription()
.build();
App::new().service(web::resource("/").to(handler))
})
.bind("127.0.0.1:8000")?
.run()
.await
}

View File

@ -1,5 +1,5 @@
use actix_web::{web, App, HttpServer};
use async_graphql::{Schema, Upload};
use async_graphql::{GQLEmptySubscription, Schema, Upload};
struct QueryRoot;
@ -36,8 +36,10 @@ impl MutationRoot {
#[actix_rt::main]
async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
let schema = Schema::new(QueryRoot, MutationRoot);
let handler = async_graphql_actix_web::HandlerBuilder::new(schema).build();
let schema = Schema::new(QueryRoot, MutationRoot, GQLEmptySubscription);
let handler = async_graphql_actix_web::HandlerBuilder::new(schema)
.enable_subscription()
.build();
App::new().service(web::resource("/").to(handler))
})
.bind("127.0.0.1:8000")?

View File

@ -1,9 +1,20 @@
#[macro_use]
extern crate serde_derive;
#[macro_use]
extern crate actix_derive;
mod pubsub;
mod session;
use crate::session::WsSession;
use actix_multipart::Multipart;
use actix_web::http::{header, HeaderMap};
use actix_web::web::Payload;
use actix_web::http::{header, HeaderMap, Method};
use actix_web::web::{BytesMut, Payload};
use actix_web::{web, FromRequest, HttpRequest, HttpResponse, Responder};
use actix_web_actors::ws;
use async_graphql::http::{GQLRequest, GQLResponse};
use async_graphql::{GQLObject, Schema};
use async_graphql::{GQLObject, GQLSubscription, Schema};
use bytes::Bytes;
use futures::StreamExt;
use mime::Mime;
use std::collections::HashMap;
@ -11,23 +22,32 @@ use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
pub struct HandlerBuilder<Query, Mutation> {
schema: Schema<Query, Mutation>,
pub use pubsub::publish_message;
pub struct HandlerBuilder<Query, Mutation, Subscription> {
schema: Schema<Query, Mutation, Subscription>,
max_file_size: Option<usize>,
enable_subscription: bool,
enable_ui: Option<(String, Option<String>)>,
}
impl<Query, Mutation> HandlerBuilder<Query, Mutation>
impl<Query, Mutation, Subscription> HandlerBuilder<Query, Mutation, Subscription>
where
Query: GQLObject + Send + Sync + 'static,
Mutation: GQLObject + Send + Sync + 'static,
Subscription: GQLSubscription + Send + Sync + 'static,
{
pub fn new(schema: Schema<Query, Mutation>) -> Self {
/// Create an HTTP handler builder
pub fn new(schema: Schema<Query, Mutation, Subscription>) -> Self {
Self {
schema,
max_file_size: Some(1024 * 1024 * 2),
enable_subscription: false,
enable_ui: None,
}
}
/// Set the maximum file size for upload, no limit by default.
pub fn max_file_size(self, size: usize) -> Self {
Self {
max_file_size: Some(size),
@ -35,6 +55,29 @@ where
}
}
/// Enable GraphQL playground
///
/// 'endpoint' is the endpoint of the GraphQL Request.
/// 'subscription_endpoint' is the endpoint of the GraphQL Subscription.
pub fn enable_ui(self, endpoint: &str, subscription_endpoint: Option<&str>) -> Self {
Self {
enable_ui: Some((
endpoint.to_string(),
subscription_endpoint.map(|s| s.to_string()),
)),
..self
}
}
/// Enable GraphQL Subscription.
pub fn enable_subscription(self) -> Self {
Self {
enable_subscription: true,
..self
}
}
/// Create an HTTP handler.
pub fn build(
self,
) -> impl Fn(
@ -45,120 +88,156 @@ where
+ Clone {
let schema = Arc::new(self.schema);
let max_file_size = self.max_file_size;
let enable_ui = self.enable_ui;
let enable_subscription = self.enable_subscription;
move |req: HttpRequest, mut payload: Payload| {
move |req: HttpRequest, payload: Payload| {
let schema = schema.clone();
let enable_ui = enable_ui.clone();
Box::pin(async move {
if req.method() != "POST" {
return Ok(HttpResponse::MethodNotAllowed().finish());
if req.method() == Method::GET {
if enable_subscription {
if let Some(s) = req.headers().get(&header::UPGRADE) {
if let Ok(s) = s.to_str() {
if s.to_ascii_lowercase().contains("websocket") {
return ws::start_with_protocols(
WsSession::new(schema.clone()),
&["graphql-ws"],
&req,
payload,
);
}
}
}
}
if let Some((endpoint, subscription_endpoint)) = &enable_ui {
return Ok(HttpResponse::Ok()
.content_type("text/html; charset=utf-8")
.body(async_graphql::http::playground_source(
endpoint,
subscription_endpoint.as_deref(),
)));
}
}
if let Ok(ct) = get_content_type(req.headers()) {
if ct.essence_str() == mime::MULTIPART_FORM_DATA {
let mut multipart = Multipart::from_request(&req, &mut payload.0).await?;
// read operators
let mut gql_request = {
let data = read_multipart(&mut multipart, "operations").await?;
serde_json::from_slice::<GQLRequest>(&data)
.map_err(|err| actix_web::error::ErrorBadRequest(err))?
};
// read map
let mut map = {
let data = read_multipart(&mut multipart, "map").await?;
serde_json::from_slice::<HashMap<String, Vec<String>>>(&data)
.map_err(|err| actix_web::error::ErrorBadRequest(err))?
};
let mut query = match gql_request.prepare(&schema) {
Ok(query) => query,
Err(err) => {
return Ok(web::Json(GQLResponse(Err(err)))
.respond_to(&req)
.await?)
}
};
if !query.is_upload() {
return Err(actix_web::error::ErrorBadRequest(
"It's not an upload operation",
));
}
// read files
while let Some(field) = multipart.next().await {
let mut field = field?;
if let Some(content_disposition) = field.content_disposition() {
if let (Some(name), Some(filename)) = (
content_disposition.get_name(),
content_disposition.get_filename(),
) {
if let Some(var_paths) = map.remove(name) {
let content_type = field.content_type().to_string();
let mut data = Vec::<u8>::new();
while let Some(part) = field.next().await {
let part = part.map_err(|err| {
actix_web::error::ErrorBadRequest(err)
})?;
data.extend(&part);
if let Some(max_file_size) = max_file_size {
if data.len() > max_file_size {
return Err(
actix_web::error::ErrorPayloadTooLarge(
"payload to large",
),
);
}
}
}
for var_path in var_paths {
query.set_upload(
&var_path,
filename,
Some(&content_type),
data.clone(),
);
}
} else {
return Err(actix_web::error::ErrorBadRequest(
"bad request",
));
}
} else {
return Err(actix_web::error::ErrorBadRequest("bad request"));
}
} else {
return Err(actix_web::error::ErrorBadRequest("bad request"));
}
}
if !map.is_empty() {
return Err(actix_web::error::ErrorBadRequest("missing files"));
}
Ok(web::Json(GQLResponse(query.execute().await))
.respond_to(&req)
.await?)
} else if ct.essence_str() == mime::APPLICATION_JSON {
let gql_req =
web::Json::<GQLRequest>::from_request(&req, &mut payload.0).await?;
Ok(web::Json(gql_req.into_inner().execute(&schema).await)
.respond_to(&req)
.await?)
} else {
Ok(HttpResponse::UnsupportedMediaType().finish())
}
if req.method() == Method::POST {
handle_request(&schema, max_file_size, req, payload).await
} else {
Ok(HttpResponse::UnsupportedMediaType().finish())
Ok(HttpResponse::MethodNotAllowed().finish())
}
})
}
}
}
async fn handle_request<Query, Mutation, Subscription>(
schema: &Schema<Query, Mutation, Subscription>,
max_file_size: Option<usize>,
req: HttpRequest,
mut payload: Payload,
) -> actix_web::Result<HttpResponse>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
Subscription: GQLSubscription + Send + Sync,
{
if let Ok(ct) = get_content_type(req.headers()) {
if ct.essence_str() == mime::MULTIPART_FORM_DATA {
let mut multipart = Multipart::from_request(&req, &mut payload.0).await?;
// read operators
let mut gql_request = {
let data = read_multipart(&mut multipart, "operations").await?;
serde_json::from_slice::<GQLRequest>(&data)
.map_err(|err| actix_web::error::ErrorBadRequest(err))?
};
// read map
let mut map = {
let data = read_multipart(&mut multipart, "map").await?;
serde_json::from_slice::<HashMap<String, Vec<String>>>(&data)
.map_err(|err| actix_web::error::ErrorBadRequest(err))?
};
let mut query = match gql_request.prepare(schema) {
Ok(query) => query,
Err(err) => return Ok(web::Json(GQLResponse(Err(err))).respond_to(&req).await?),
};
if !query.is_upload() {
return Err(actix_web::error::ErrorBadRequest(
"It's not an upload operation",
));
}
// read files
while let Some(field) = multipart.next().await {
let mut field = field?;
if let Some(content_disposition) = field.content_disposition() {
if let (Some(name), Some(filename)) = (
content_disposition.get_name(),
content_disposition.get_filename(),
) {
if let Some(var_paths) = map.remove(name) {
let content_type = field.content_type().to_string();
let mut data = BytesMut::new();
while let Some(part) = field.next().await {
let part =
part.map_err(|err| actix_web::error::ErrorBadRequest(err))?;
data.extend(&part);
if let Some(max_file_size) = max_file_size {
if data.len() > max_file_size {
return Err(actix_web::error::ErrorPayloadTooLarge(
"payload to large",
));
}
}
}
let data = data.freeze();
for var_path in var_paths {
query.set_upload(
&var_path,
filename,
Some(&content_type),
data.clone(),
);
}
} else {
return Err(actix_web::error::ErrorBadRequest("bad request"));
}
} else {
return Err(actix_web::error::ErrorBadRequest("bad request"));
}
} else {
return Err(actix_web::error::ErrorBadRequest("bad request"));
}
}
if !map.is_empty() {
return Err(actix_web::error::ErrorBadRequest("missing files"));
}
Ok(web::Json(GQLResponse(query.execute().await))
.respond_to(&req)
.await?)
} else if ct.essence_str() == mime::APPLICATION_JSON {
let gql_req = web::Json::<GQLRequest>::from_request(&req, &mut payload.0).await?;
Ok(web::Json(gql_req.into_inner().execute(&schema).await)
.respond_to(&req)
.await?)
} else {
Ok(HttpResponse::UnsupportedMediaType().finish())
}
} else {
Ok(HttpResponse::UnsupportedMediaType().finish())
}
}
fn get_content_type(headers: &HeaderMap) -> actix_web::Result<Mime> {
if let Some(content_type) = headers.get(&header::CONTENT_TYPE) {
if let Ok(content_type) = content_type.to_str() {
@ -172,7 +251,7 @@ fn get_content_type(headers: &HeaderMap) -> actix_web::Result<Mime> {
))
}
async fn read_multipart(multipart: &mut Multipart, name: &str) -> actix_web::Result<Vec<u8>> {
async fn read_multipart(multipart: &mut Multipart, name: &str) -> actix_web::Result<Bytes> {
let data = match multipart.next().await {
Some(Ok(mut field)) => {
if let Some(content_disposition) = field.content_disposition() {
@ -184,7 +263,7 @@ async fn read_multipart(multipart: &mut Multipart, name: &str) -> actix_web::Res
)));
}
let mut data = Vec::<u8>::new();
let mut data = BytesMut::new();
while let Some(part) = field.next().await {
let part = part.map_err(|err| actix_web::error::ErrorBadRequest(err))?;
data.extend(&part);
@ -200,5 +279,5 @@ async fn read_multipart(multipart: &mut Multipart, name: &str) -> actix_web::Res
Some(Err(err)) => return Err(err.into()),
None => return Err(actix_web::error::ErrorBadRequest("bad request")),
};
Ok(data)
Ok(data.freeze())
}

View File

@ -0,0 +1,86 @@
use actix::{Actor, Context, Handler, Recipient, Supervised, SystemService};
use async_graphql::Result;
use slab::Slab;
use std::any::Any;
use std::sync::Arc;
#[derive(Message)]
#[rtype(result = "std::result::Result<(), ()>")]
pub struct PushMessage(pub Arc<dyn Any + Sync + Send>);
#[derive(Message)]
#[rtype(result = "usize")]
struct NewClient {
recipient: Recipient<PushMessage>,
}
#[derive(Message)]
#[rtype(result = "()")]
struct RemoveClient {
id: usize,
}
#[derive(Message)]
#[rtype(result = "()")]
struct PubMessage(Arc<dyn Any + Sync + Send>);
struct ClientInfo {
recipient: Recipient<PushMessage>,
}
#[derive(Default)]
struct PubSubService {
clients: Slab<ClientInfo>,
}
impl Actor for PubSubService {
type Context = Context<Self>;
}
impl Handler<NewClient> for PubSubService {
type Result = usize;
fn handle(&mut self, msg: NewClient, _ctx: &mut Context<Self>) -> Self::Result {
self.clients.insert(ClientInfo {
recipient: msg.recipient,
})
}
}
impl Handler<RemoveClient> for PubSubService {
type Result = ();
fn handle(&mut self, msg: RemoveClient, _ctx: &mut Context<Self>) -> Self::Result {
self.clients.remove(msg.id);
}
}
impl Handler<PubMessage> for PubSubService {
type Result = ();
fn handle(&mut self, msg: PubMessage, _ctx: &mut Context<Self>) -> Self::Result {
for (_, client) in &self.clients {
client.recipient.do_send(PushMessage(msg.0.clone())).ok();
}
}
}
impl Supervised for PubSubService {}
impl SystemService for PubSubService {}
pub async fn new_client(recipient: Recipient<PushMessage>) -> Result<usize> {
let id = PubSubService::from_registry()
.send(NewClient { recipient })
.await?;
Ok(id)
}
pub fn remove_client(id: usize) {
PubSubService::from_registry().do_send(RemoveClient { id });
}
/// Publish a message that will be pushed to all subscribed clients.
pub fn publish_message<T: Any + Send + Sync + Sized>(msg: T) {
PubSubService::from_registry().do_send(PubMessage(Arc::new(msg)));
}

View File

@ -0,0 +1,206 @@
use crate::pubsub::{new_client, remove_client, PushMessage};
use actix::{
Actor, ActorContext, ActorFuture, AsyncContext, ContextFutureSpawner, Handler,
ResponseActFuture, Running, StreamHandler, WrapFuture,
};
use actix_web_actors::ws::{Message, ProtocolError, WebsocketContext};
use async_graphql::http::{GQLError, GQLRequest, GQLResponse};
use async_graphql::{GQLObject, GQLSubscription, Schema, Subscribe, Variables};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
#[derive(Serialize, Deserialize)]
struct OperationMessage {
#[serde(rename = "type")]
ty: String,
id: Option<String>,
payload: Option<serde_json::Value>,
}
pub struct WsSession<Query, Mutation, Subscription> {
schema: Arc<Schema<Query, Mutation, Subscription>>,
hb: Instant,
client_id: usize,
subscribes: HashMap<String, Arc<Subscribe>>,
}
impl<Query, Mutation, Subscription> WsSession<Query, Mutation, Subscription>
where
Query: GQLObject + Send + Sync + 'static,
Mutation: GQLObject + Send + Sync + 'static,
Subscription: GQLSubscription + Send + Sync + 'static,
{
pub fn new(schema: Arc<Schema<Query, Mutation, Subscription>>) -> Self {
Self {
schema,
hb: Instant::now(),
client_id: 0,
subscribes: Default::default(),
}
}
}
impl<Query, Mutation, Subscription> Actor for WsSession<Query, Mutation, Subscription>
where
Query: GQLObject + Sync + Send + 'static,
Mutation: GQLObject + Sync + Send + 'static,
Subscription: GQLSubscription + Send + Sync + 'static,
{
type Context = WebsocketContext<Self>;
fn started(&mut self, ctx: &mut Self::Context) {
new_client(ctx.address().recipient())
.into_actor(self)
.then(|client_id, actor, _| {
actor.client_id = client_id.unwrap();
async {}.into_actor(actor)
})
.wait(ctx);
}
fn stopping(&mut self, _ctx: &mut Self::Context) -> Running {
remove_client(self.client_id);
Running::Stop
}
}
impl<Query, Mutation, Subscription> StreamHandler<Result<Message, ProtocolError>>
for WsSession<Query, Mutation, Subscription>
where
Query: GQLObject + Sync + Send + 'static,
Mutation: GQLObject + Sync + Send + 'static,
Subscription: GQLSubscription + Send + Sync + 'static,
{
fn handle(&mut self, msg: Result<Message, ProtocolError>, ctx: &mut Self::Context) {
let msg = match msg {
Err(_) => {
ctx.stop();
return;
}
Ok(msg) => msg,
};
match msg {
Message::Ping(msg) => {
self.hb = Instant::now();
ctx.pong(&msg);
}
Message::Pong(_) => {
self.hb = Instant::now();
}
Message::Text(s) => {
if let Ok(msg) = serde_json::from_str::<OperationMessage>(&s) {
match msg.ty.as_str() {
"connection_init" => {
ctx.text(
serde_json::to_string(&OperationMessage {
ty: "connection_ack".to_string(),
id: None,
payload: None,
})
.unwrap(),
);
}
"start" => {
if let (Some(id), Some(payload)) = (msg.id, msg.payload) {
if let Ok(request) = serde_json::from_value::<GQLRequest>(payload) {
let builder = self.schema.subscribe(&request.query);
let builder = if let Some(variables) = request.variables {
match Variables::parse_from_json(variables) {
Ok(variables) => builder.variables(variables),
Err(_) => builder,
}
} else {
builder
};
let builder =
if let Some(operation_name) = &request.operation_name {
builder.operator_name(&operation_name)
} else {
builder
};
let subscribe = match builder.execute() {
Ok(subscribe) => subscribe,
Err(err) => {
ctx.text(
serde_json::to_string(&OperationMessage {
ty: "error".to_string(),
id: Some(id),
payload: Some(
serde_json::to_value(GQLError(&err))
.unwrap(),
),
})
.unwrap(),
);
return;
}
};
self.subscribes.insert(id, Arc::new(subscribe));
}
}
}
"stop" => {
if let Some(id) = msg.id {
self.subscribes.remove(&id);
}
}
"connection_terminate" => {
ctx.stop();
}
_ => {}
}
}
}
Message::Binary(_) | Message::Close(_) | Message::Continuation(_) => {
ctx.stop();
}
Message::Nop => {}
}
}
}
impl<Query, Mutation, Subscription> Handler<PushMessage>
for WsSession<Query, Mutation, Subscription>
where
Query: GQLObject + Send + Sync + 'static,
Mutation: GQLObject + Send + Sync + 'static,
Subscription: GQLSubscription + Send + Sync + 'static,
{
type Result = ResponseActFuture<Self, std::result::Result<(), ()>>;
fn handle(&mut self, msg: PushMessage, _ctx: &mut Self::Context) -> Self::Result {
let subscribes = self.subscribes.clone();
let schema = self.schema.clone();
Box::new(
async move {
let mut push_msgs = Vec::new();
for (id, subscribe) in subscribes {
let res = match subscribe.resolve(&schema, msg.0.as_ref()).await {
Ok(Some(value)) => Some(Ok(value)),
Ok(None) => None,
Err(err) => Some(Err(err)),
};
if let Some(res) = res {
let push_msg = serde_json::to_string(&OperationMessage {
ty: "data".to_string(),
id: Some(id.clone()),
payload: Some(serde_json::to_value(GQLResponse(res)).unwrap()),
})
.unwrap();
push_msgs.push(push_msg);
}
}
push_msgs
}
.into_actor(self)
.map(|msgs, _, ctx| {
for msg in msgs {
ctx.text(msg);
}
Ok(())
}),
)
}
}

View File

@ -1,6 +1,6 @@
[package]
name = "async-graphql-derive"
version = "1.0.0"
version = "1.2.0"
authors = ["sunli <scott_s829@163.com>"]
edition = "2018"
description = "The GraphQL server library implemented by rust"

View File

@ -114,6 +114,33 @@ pub fn generate(enum_args: &args::Enum, input: &DeriveInput) -> Result<TokenStre
#crate_name::GQLEnum::resolve_enum(value)
}
}
impl #crate_name::GQLType for &#ident {
fn type_name() -> std::borrow::Cow<'static, str> {
std::borrow::Cow::Borrowed(#gql_typename)
}
fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String {
registry.create_type::<Self, _>(|registry| {
#crate_name::registry::Type::Enum {
name: #gql_typename,
description: #desc,
enum_values: {
let mut enum_items = std::collections::HashMap::new();
#(#schema_enum_items)*
enum_items
},
}
})
}
}
#[#crate_name::async_trait::async_trait]
impl #crate_name::GQLOutputValue for &#ident {
async fn resolve(value: &Self, _: &#crate_name::ContextSelectionSet<'_>) -> #crate_name::Result<serde_json::Value> {
#crate_name::GQLEnum::resolve_enum(*value)
}
}
};
Ok(expanded.into())
}

View File

@ -6,6 +6,7 @@ mod input_object;
mod interface;
mod object;
mod output_type;
mod subscription;
mod union;
mod utils;
@ -82,3 +83,17 @@ pub fn Union(args: TokenStream, input: TokenStream) -> TokenStream {
Err(err) => err.to_compile_error().into(),
}
}
#[proc_macro_attribute]
#[allow(non_snake_case)]
pub fn Subscription(args: TokenStream, input: TokenStream) -> TokenStream {
let object_args = match args::Object::parse(parse_macro_input!(args as AttributeArgs)) {
Ok(object_args) => object_args,
Err(err) => return err.to_compile_error().into(),
};
let mut item_impl = parse_macro_input!(input as ItemImpl);
match subscription::generate(&object_args, &mut item_impl) {
Ok(expanded) => expanded,
Err(err) => err.to_compile_error().into(),
}
}

View File

@ -71,7 +71,6 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}
} else if let FnArg::Typed(pat) = arg {
if idx == 0 {
// 第一个参数必须是self
return Err(Error::new_spanned(
pat,
"The self receiver must be the first parameter.",
@ -149,6 +148,7 @@ pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<
}
None => quote! { || #crate_name::Value::Null },
};
get_params.push(quote! {
let #ident: #ty = ctx.param_value(#name, #default)?;
});

View File

@ -0,0 +1,279 @@
use crate::args;
use crate::utils::{build_value_repr, get_crate_name};
use inflector::Inflector;
use proc_macro::TokenStream;
use quote::quote;
use syn::{Error, FnArg, ImplItem, ItemImpl, Pat, Result, ReturnType, Type};
pub fn generate(object_args: &args::Object, item_impl: &mut ItemImpl) -> Result<TokenStream> {
let crate_name = get_crate_name(object_args.internal);
let (self_ty, self_name) = match item_impl.self_ty.as_ref() {
Type::Path(path) => (
path,
path.path
.segments
.last()
.map(|s| s.ident.to_string())
.unwrap(),
),
_ => return Err(Error::new_spanned(&item_impl.self_ty, "Invalid type")),
};
let generics = &item_impl.generics;
let gql_typename = object_args
.name
.clone()
.unwrap_or_else(|| self_name.clone());
let desc = object_args
.desc
.as_ref()
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
let mut create_types = Vec::new();
let mut filters = Vec::new();
let mut schema_fields = Vec::new();
for item in &mut item_impl.items {
if let ImplItem::Method(method) = item {
if let Some(field) = args::Field::parse(&method.attrs)? {
let ident = &method.sig.ident;
let field_name = field
.name
.clone()
.unwrap_or_else(|| method.sig.ident.to_string().to_camel_case());
let field_desc = field
.desc
.as_ref()
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
let field_deprecation = field
.deprecation
.as_ref()
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
if method.sig.inputs.len() < 2 {
return Err(Error::new_spanned(
&method.sig.inputs,
"The filter function needs at least two arguments",
));
}
if method.sig.asyncness.is_some() {
return Err(Error::new_spanned(
&method.sig.inputs,
"The filter function must be synchronous",
));
}
let mut res_typ_ok = false;
if let ReturnType::Type(_, res_ty) = &method.sig.output {
if let Type::Path(p) = res_ty.as_ref() {
if p.path.is_ident("bool") {
res_typ_ok = true;
}
}
}
if !res_typ_ok {
return Err(Error::new_spanned(
&method.sig.output,
"The filter function must return a boolean value",
));
}
match &method.sig.inputs[0] {
FnArg::Receiver(_) => {}
_ => {
return Err(Error::new_spanned(
&method.sig.inputs[0],
"The first argument must be self receiver",
));
}
}
let ty = if let FnArg::Typed(ty) = &method.sig.inputs[1] {
match ty.ty.as_ref() {
Type::Reference(r) => r.elem.as_ref().clone(),
_ => {
return Err(Error::new_spanned(ty, "Incorrect object type"));
}
}
} else {
return Err(Error::new_spanned(
&method.sig.inputs[1],
"Incorrect object type",
));
};
let mut args = Vec::new();
for arg in method.sig.inputs.iter_mut().skip(2) {
if let FnArg::Typed(pat) = arg {
match (&*pat.pat, &*pat.ty) {
(Pat::Ident(arg_ident), Type::Path(arg_ty)) => {
args.push((arg_ident, arg_ty, args::Argument::parse(&pat.attrs)?));
pat.attrs.clear();
}
_ => {
return Err(Error::new_spanned(arg, "Incorrect argument type"));
}
}
} else {
return Err(Error::new_spanned(arg, "Incorrect argument type"));
}
}
let mut schema_args = Vec::new();
let mut use_params = Vec::new();
let mut get_params = Vec::new();
for (
ident,
ty,
args::Argument {
name,
desc,
default,
},
) in args
{
let name = name
.clone()
.unwrap_or_else(|| ident.ident.to_string().to_camel_case());
let desc = desc
.as_ref()
.map(|s| quote! {Some(#s)})
.unwrap_or_else(|| quote! {None});
let schema_default = default
.as_ref()
.map(|v| {
let s = v.to_string();
quote! {Some(#s)}
})
.unwrap_or_else(|| quote! {None});
schema_args.push(quote! {
args.insert(#name, #crate_name::registry::InputValue {
name: #name,
description: #desc,
ty: <#ty as #crate_name::GQLType>::create_type_info(registry),
default_value: #schema_default,
});
});
use_params.push(quote! { #ident });
let default = match &default {
Some(default) => {
let repr = build_value_repr(&crate_name, &default);
quote! {|| #repr }
}
None => quote! { || #crate_name::Value::Null },
};
get_params.push(quote! {
let #ident: #ty = ctx_field.param_value(#name, #default)?;
});
}
schema_fields.push(quote! {
fields.insert(#field_name, #crate_name::registry::Field {
name: #field_name,
description: #field_desc,
args: {
let mut args = std::collections::HashMap::new();
#(#schema_args)*
args
},
ty: <#ty as #crate_name::GQLType>::create_type_info(registry),
deprecation: #field_deprecation,
});
});
create_types.push(quote! {
if field.name.as_str() == #field_name {
types.insert(std::any::TypeId::of::<#ty>(), field);
continue;
}
});
filters.push(quote! {
if let Some(msg) = msg.downcast_ref::<#ty>() {
#(#get_params)*
if self.#ident(msg, #(#use_params)*) {
let ctx_selection_set = ctx_field.with_item(&field.selection_set);
let value =
#crate_name::GQLOutputValue::resolve(msg, &ctx_selection_set).await?;
return Ok(Some(value));
}
}
});
method.attrs.clear();
}
}
}
let expanded = quote! {
#item_impl
impl #generics #crate_name::GQLType for #self_ty {
fn type_name() -> std::borrow::Cow<'static, str> {
std::borrow::Cow::Borrowed(#gql_typename)
}
fn create_type_info(registry: &mut #crate_name::registry::Registry) -> String {
registry.create_type::<Self, _>(|registry| #crate_name::registry::Type::Object {
name: #gql_typename,
description: #desc,
fields: {
let mut fields = std::collections::HashMap::new();
#(#schema_fields)*
fields
},
})
}
}
#[#crate_name::async_trait::async_trait]
impl #crate_name::GQLSubscription for SubscriptionRoot {
fn create_types(
selection_set: #crate_name::graphql_parser::query::SelectionSet,
) -> #crate_name::Result<std::collections::HashMap<std::any::TypeId, #crate_name::graphql_parser::query::Field>> {
use #crate_name::ErrorWithPosition;
let mut types = std::collections::HashMap::new();
for selection in selection_set.items {
match selection {
#crate_name::graphql_parser::query::Selection::Field(field) => {
#(#create_types)*
#crate_name::anyhow::bail!(#crate_name::QueryError::FieldNotFound {
field_name: field.name.clone(),
object: #gql_typename.to_string(),
}
.with_position(field.position));
}
_ => {}
}
}
Ok(types)
}
async fn resolve(
&self,
ctx: &#crate_name::ContextBase<'_, ()>,
types: &std::collections::HashMap<std::any::TypeId, #crate_name::graphql_parser::query::Field>,
msg: &(dyn std::any::Any + Send + Sync),
) -> #crate_name::Result<Option<serde_json::Value>> {
let tid = msg.type_id();
if let Some(field) = types.get(&tid) {
let ctx_field = ctx.with_item(field);
#(#filters)*
}
Ok(None)
}
}
};
Ok(expanded.into())
}

View File

@ -2,9 +2,9 @@ mod starwars;
use actix_web::{guard, web, App, HttpResponse, HttpServer};
use async_graphql::http::{graphiql_source, playground_source, GQLRequest, GQLResponse};
use async_graphql::{GQLEmptyMutation, Schema};
use async_graphql::{GQLEmptyMutation, GQLEmptySubscription, Schema};
type StarWarsSchema = Schema<starwars::QueryRoot, GQLEmptyMutation>;
type StarWarsSchema = Schema<starwars::QueryRoot, GQLEmptyMutation, GQLEmptySubscription>;
async fn index(s: web::Data<StarWarsSchema>, req: web::Json<GQLRequest>) -> web::Json<GQLResponse> {
web::Json(req.into_inner().execute(&s).await)
@ -13,7 +13,7 @@ async fn index(s: web::Data<StarWarsSchema>, req: web::Json<GQLRequest>) -> web:
async fn gql_playgound() -> HttpResponse {
HttpResponse::Ok()
.content_type("text/html; charset=utf-8")
.body(playground_source("/"))
.body(playground_source("/", None))
}
async fn gql_graphiql() -> HttpResponse {
@ -27,7 +27,8 @@ async fn main() -> std::io::Result<()> {
HttpServer::new(move || {
App::new()
.data(
Schema::new(starwars::QueryRoot, GQLEmptyMutation).data(starwars::StarWars::new()),
Schema::new(starwars::QueryRoot, GQLEmptyMutation, GQLEmptySubscription)
.data(starwars::StarWars::new()),
)
.service(web::resource("/").guard(guard::Post()).to(index))
.service(web::resource("/").guard(guard::Get()).to(gql_playgound))

View File

@ -1,11 +1,11 @@
mod starwars;
use async_graphql::http::{graphiql_source, playground_source, GQLRequest};
use async_graphql::{GQLEmptyMutation, Schema};
use async_graphql::{GQLEmptyMutation, GQLEmptySubscription, Schema};
use mime;
use tide::{self, Request, Response};
type StarWarsSchema = Schema<starwars::QueryRoot, GQLEmptyMutation>;
type StarWarsSchema = Schema<starwars::QueryRoot, GQLEmptyMutation, GQLEmptySubscription>;
async fn index(mut request: Request<StarWarsSchema>) -> Response {
let gql_request: GQLRequest = request.body_json().await.unwrap();
@ -16,7 +16,7 @@ async fn index(mut request: Request<StarWarsSchema>) -> Response {
async fn gql_playground(_request: Request<StarWarsSchema>) -> Response {
Response::new(200)
.body_string(playground_source("/"))
.body_string(playground_source("/", None))
.set_mime(mime::TEXT_HTML_UTF_8)
}
async fn gql_graphiql(_request: Request<StarWarsSchema>) -> Response {
@ -28,7 +28,8 @@ async fn gql_graphiql(_request: Request<StarWarsSchema>) -> Response {
#[async_std::main]
async fn main() -> std::io::Result<()> {
let mut app = tide::with_state(
Schema::new(starwars::QueryRoot, GQLEmptyMutation).data(starwars::StarWars::new()),
Schema::new(starwars::QueryRoot, GQLEmptyMutation, GQLEmptySubscription)
.data(starwars::StarWars::new()),
);
app.at("/").post(index);
app.at("/").get(gql_playground);

View File

@ -1,5 +1,6 @@
use crate::registry::Registry;
use crate::{ErrorWithPosition, GQLInputValue, GQLType, QueryError, Result};
use bytes::Bytes;
use fnv::FnvHasher;
use graphql_parser::query::{
Directive, Field, FragmentDefinition, SelectionSet, Value, VariableDefinition,
@ -10,7 +11,7 @@ use std::hash::BuildHasherDefault;
use std::ops::{Deref, DerefMut};
/// Variables of query
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Variables(Value);
impl Default for Variables {
@ -42,7 +43,7 @@ impl DerefMut for Variables {
}
impl Variables {
pub(crate) fn parse_from_json(value: serde_json::Value) -> Result<Self> {
pub fn parse_from_json(value: serde_json::Value) -> Result<Self> {
let gql_value = json_value_to_gql_value(value);
if let Value::Object(_) = gql_value {
Ok(Variables(gql_value))
@ -56,7 +57,7 @@ impl Variables {
var_path: &str,
filename: &str,
content_type: Option<&str>,
content: Vec<u8>,
content: Bytes,
) {
let mut it = var_path.split(".").peekable();

View File

@ -36,6 +36,9 @@ pub enum QueryError {
#[error("Schema is not configured for mutations.")]
NotConfiguredMutations,
#[error("Schema is not configured for subscriptions.")]
NotConfiguredSubscriptions,
#[error("Invalid value for enum \"{ty}\".")]
InvalidEnumValue { ty: String, value: String },

View File

@ -5,8 +5,8 @@ pub use graphiql_source::graphiql_source;
pub use playground_source::playground_source;
use crate::error::{RuleError, RuleErrors};
use crate::schema::PreparedQuery;
use crate::{GQLObject, PositionError, Result, Schema, Variables};
use crate::query::PreparedQuery;
use crate::{GQLObject, GQLSubscription, PositionError, Result, Schema, Variables};
use graphql_parser::Pos;
use serde::ser::{SerializeMap, SerializeSeq};
use serde::{Serialize, Serializer};
@ -21,10 +21,14 @@ pub struct GQLRequest {
}
impl GQLRequest {
pub async fn execute<Query, Mutation>(mut self, schema: &Schema<Query, Mutation>) -> GQLResponse
pub async fn execute<Query, Mutation, Subscription>(
mut self,
schema: &Schema<Query, Mutation, Subscription>,
) -> GQLResponse
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
Subscription: GQLSubscription + Send + Sync,
{
match self.prepare(schema) {
Ok(query) => GQLResponse(query.execute().await),
@ -32,13 +36,14 @@ impl GQLRequest {
}
}
pub fn prepare<'a, Query, Mutation>(
pub fn prepare<'a, Query, Mutation, Subscription>(
&'a mut self,
schema: &'a Schema<Query, Mutation>,
schema: &'a Schema<Query, Mutation, Subscription>,
) -> Result<PreparedQuery<'a, Query, Mutation>>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
Subscription: GQLSubscription + Send + Sync,
{
let vars = match self.variables.take() {
Some(value) => match Variables::parse_from_json(value) {
@ -81,7 +86,7 @@ impl Serialize for GQLResponse {
}
}
struct GQLError<'a>(&'a anyhow::Error);
pub struct GQLError<'a>(pub &'a anyhow::Error);
impl<'a> Deref for GQLError<'a> {
type Target = anyhow::Error;

View File

@ -1,4 +1,7 @@
pub fn playground_source(graphql_endpoint_url: &str) -> String {
pub fn playground_source(
graphql_endpoint_url: &str,
subscription_endpoint: Option<&str>,
) -> String {
r##"
<!DOCTYPE html>
@ -533,10 +536,14 @@ pub fn playground_source(graphql_endpoint_url: &str) -> String {
const root = document.getElementById('root');
root.classList.add('playgroundIn');
GraphQLPlayground.init(root, { endpoint: 'GRAPHQL_URL' })
GraphQLPlayground.init(root, { endpoint: GRAPHQL_URL, subscriptionEndpoint: GRAPHQL_SUBSCRIPTION_URL })
})
</script>
</body>
</html>
"##.replace("GRAPHQL_URL", graphql_endpoint_url)
"##.replace("GRAPHQL_URL", &format!("'{}'", graphql_endpoint_url))
.replace("GRAPHQL_SUBSCRIPTION_URL", &match subscription_endpoint {
Some(url) => format!("'{}'", url),
None => "null".to_string()
})
}

View File

@ -57,9 +57,11 @@ mod base;
mod context;
mod error;
mod model;
mod query;
mod resolver;
mod scalars;
mod schema;
mod subscription;
mod types;
mod validation;
@ -78,9 +80,11 @@ pub use base::GQLScalar;
pub use context::{Context, Variables};
pub use error::{ErrorWithPosition, PositionError, QueryError, QueryParseError};
pub use graphql_parser::query::Value;
pub use query::{PreparedQuery, QueryBuilder};
pub use scalars::ID;
pub use schema::{QueryBuilder, Schema};
pub use types::{GQLEmptyMutation, Upload};
pub use schema::Schema;
pub use subscription::SubscribeBuilder;
pub use types::{GQLEmptyMutation, GQLEmptySubscription, Upload};
pub type Result<T> = anyhow::Result<T>;
pub type Error = anyhow::Error;
@ -97,6 +101,8 @@ pub use context::ContextBase;
#[doc(hidden)]
pub use resolver::do_resolve;
#[doc(hidden)]
pub use subscription::{GQLSubscription, Subscribe};
#[doc(hidden)]
pub use types::{GQLEnum, GQLEnumItem};
/// Define a GraphQL object
@ -424,3 +430,5 @@ pub use async_graphql_derive::Interface;
///
/// It's similar to Interface, but it doesn't have fields.
pub use async_graphql_derive::Union;
pub use async_graphql_derive::Subscription;

View File

@ -43,7 +43,11 @@ impl<'a> __Schema<'a> {
desc = "If this server support subscription, the type that subscription operations will be rooted at."
)]
async fn subscription_type(&self) -> Option<__Type<'a>> {
None
if let Some(ty) = &self.registry.subscription_type {
Some(__Type::new_simple(self.registry, &self.registry.types[ty]))
} else {
None
}
}
#[field(desc = "A list of all directives supported by this server.")]

179
src/query.rs Normal file
View File

@ -0,0 +1,179 @@
use crate::context::Data;
use crate::registry::Registry;
use crate::types::QueryRoot;
use crate::validation::check_rules;
use crate::{ContextBase, GQLOutputValue, Result};
use crate::{GQLObject, QueryError, QueryParseError, Variables};
use bytes::Bytes;
use graphql_parser::parse_query;
use graphql_parser::query::{
Definition, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition,
};
use std::collections::HashMap;
enum Root<'a, Query, Mutation> {
Query(&'a QueryRoot<Query>),
Mutation(&'a Mutation),
}
/// Query builder
pub struct QueryBuilder<'a, Query, Mutation> {
pub(crate) query: &'a QueryRoot<Query>,
pub(crate) mutation: &'a Mutation,
pub(crate) registry: &'a Registry,
pub(crate) source: &'a str,
pub(crate) operation_name: Option<&'a str>,
pub(crate) variables: Option<Variables>,
pub(crate) data: &'a Data,
}
impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> {
/// Specify the operation name.
pub fn operator_name(self, name: &'a str) -> Self {
QueryBuilder {
operation_name: Some(name),
..self
}
}
/// Specify the variables.
pub fn variables(self, vars: Variables) -> Self {
QueryBuilder {
variables: Some(vars),
..self
}
}
/// Prepare query
pub fn prepare(self) -> Result<PreparedQuery<'a, Query, Mutation>> {
let document = parse_query(self.source).map_err(|err| QueryParseError(err.to_string()))?;
check_rules(self.registry, &document)?;
let mut fragments = HashMap::new();
let mut selection_set = None;
let mut variable_definitions = None;
let mut root = None;
for definition in document.definitions {
match definition {
Definition::Operation(operation_definition) => match operation_definition {
OperationDefinition::SelectionSet(s) => {
selection_set = Some(s);
root = Some(Root::Query(self.query));
}
OperationDefinition::Query(query)
if query.name.is_none() || query.name.as_deref() == self.operation_name =>
{
selection_set = Some(query.selection_set);
variable_definitions = Some(query.variable_definitions);
root = Some(Root::Query(self.query));
}
OperationDefinition::Mutation(mutation)
if mutation.name.is_none()
|| mutation.name.as_deref() == self.operation_name =>
{
selection_set = Some(mutation.selection_set);
variable_definitions = Some(mutation.variable_definitions);
root = Some(Root::Mutation(self.mutation));
}
OperationDefinition::Subscription(subscription)
if subscription.name.is_none()
|| subscription.name.as_deref() == self.operation_name =>
{
return Err(QueryError::NotSupported.into());
}
_ => {}
},
Definition::Fragment(fragment) => {
fragments.insert(fragment.name.clone(), fragment);
}
}
}
Ok(PreparedQuery {
registry: self.registry,
variables: self.variables.unwrap_or_default(),
data: self.data,
fragments,
selection_set: selection_set.ok_or({
if let Some(name) = self.operation_name {
QueryError::UnknownOperationNamed {
name: name.to_string(),
}
} else {
QueryError::MissingOperation
}
})?,
root: root.unwrap(),
variable_definitions,
})
}
/// Execute the query.
pub async fn execute(self) -> Result<serde_json::Value>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
{
self.prepare()?.execute().await
}
}
pub struct PreparedQuery<'a, Query, Mutation> {
root: Root<'a, Query, Mutation>,
registry: &'a Registry,
variables: Variables,
data: &'a Data,
fragments: HashMap<String, FragmentDefinition>,
selection_set: SelectionSet,
variable_definitions: Option<Vec<VariableDefinition>>,
}
impl<'a, Query, Mutation> PreparedQuery<'a, Query, Mutation> {
/// Detects whether any parameter contains the Upload type
pub fn is_upload(&self) -> bool {
if let Some(variable_definitions) = &self.variable_definitions {
for d in variable_definitions {
if let Some(ty) = self.registry.basic_type_by_parsed_type(&d.var_type) {
if ty.name() == "Upload" {
return true;
}
}
}
}
false
}
/// Set upload files
pub fn set_upload(
&mut self,
var_path: &str,
filename: &str,
content_type: Option<&str>,
content: Bytes,
) {
self.variables
.set_upload(var_path, filename, content_type, content);
}
/// Execute the query.
pub async fn execute(self) -> Result<serde_json::Value>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
{
let ctx = ContextBase {
item: &self.selection_set,
variables: &self.variables,
variable_definitions: self.variable_definitions.as_deref(),
registry: self.registry.clone(),
data: self.data,
fragments: &self.fragments,
};
match self.root {
Root::Query(query) => return GQLOutputValue::resolve(query, &ctx).await,
Root::Mutation(mutation) => return GQLOutputValue::resolve(mutation, &ctx).await,
}
}
}

View File

@ -179,6 +179,7 @@ pub struct Registry {
pub implements: HashMap<String, HashSet<String>>,
pub query_type: String,
pub mutation_type: Option<String>,
pub subscription_type: Option<String>,
}
impl Registry {

View File

@ -21,6 +21,24 @@ impl DerefMut for ID {
}
}
impl From<String> for ID {
fn from(value: String) -> Self {
ID(value)
}
}
impl<'a> From<&'a str> for ID {
fn from(value: &'a str) -> Self {
ID(value.to_string())
}
}
impl From<usize> for ID {
fn from(value: usize) -> Self {
ID(value.to_string())
}
}
impl GQLScalar for ID {
fn type_name() -> &'static str {
"ID"

View File

@ -1,32 +1,29 @@
use crate::context::Data;
use crate::model::__DirectiveLocation;
use crate::query::QueryBuilder;
use crate::registry::{Directive, InputValue, Registry};
use crate::types::QueryRoot;
use crate::validation::check_rules;
use crate::{
ContextBase, GQLObject, GQLOutputValue, GQLType, QueryError, QueryParseError, Result, Variables,
};
use graphql_parser::parse_query;
use graphql_parser::query::{
Definition, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition,
};
use crate::{GQLObject, GQLSubscription, GQLType, SubscribeBuilder};
use std::any::Any;
use std::collections::HashMap;
/// GraphQL schema
pub struct Schema<Query, Mutation> {
pub struct Schema<Query, Mutation, Subscription> {
query: QueryRoot<Query>,
mutation: Mutation,
registry: Registry,
data: Data,
pub(crate) subscription: Subscription,
pub(crate) registry: Registry,
pub(crate) data: Data,
}
impl<Query: GQLObject, Mutation: GQLObject> Schema<Query, Mutation> {
impl<Query: GQLObject, Mutation: GQLObject, Subscription: GQLSubscription>
Schema<Query, Mutation, Subscription>
{
/// Create a schema.
///
/// The root object for the query and Mutation needs to be specified.
/// If there is no mutation, you can use `GQLEmptyMutation`.
pub fn new(query: Query, mutation: Mutation) -> Self {
pub fn new(query: Query, mutation: Mutation, subscription: Subscription) -> Self {
let mut registry = Registry {
types: Default::default(),
directives: Default::default(),
@ -37,6 +34,11 @@ impl<Query: GQLObject, Mutation: GQLObject> Schema<Query, Mutation> {
} else {
Some(Mutation::type_name().to_string())
},
subscription_type: if Subscription::is_empty() {
None
} else {
Some(Subscription::type_name().to_string())
},
};
registry.add_directive(Directive {
@ -89,10 +91,14 @@ impl<Query: GQLObject, Mutation: GQLObject> Schema<Query, Mutation> {
if !Mutation::is_empty() {
Mutation::create_type_info(&mut registry);
}
if !Subscription::is_empty() {
Subscription::create_type_info(&mut registry);
}
Self {
query: QueryRoot { inner: query },
mutation,
subscription,
registry,
data: Default::default(),
}
@ -105,184 +111,25 @@ impl<Query: GQLObject, Mutation: GQLObject> Schema<Query, Mutation> {
}
/// Start a query and return `QueryBuilder`.
pub fn query<'a>(&'a self, query_source: &'a str) -> QueryBuilder<'a, Query, Mutation> {
pub fn query<'a>(&'a self, source: &'a str) -> QueryBuilder<'a, Query, Mutation> {
QueryBuilder {
query: &self.query,
mutation: &self.mutation,
registry: &self.registry,
query_source,
source,
operation_name: None,
variables: None,
data: &self.data,
}
}
}
enum Root<'a, Query, Mutation> {
Query(&'a QueryRoot<Query>),
Mutation(&'a Mutation),
}
/// Query builder
pub struct QueryBuilder<'a, Query, Mutation> {
query: &'a QueryRoot<Query>,
mutation: &'a Mutation,
registry: &'a Registry,
query_source: &'a str,
operation_name: Option<&'a str>,
variables: Option<Variables>,
data: &'a Data,
}
impl<'a, Query, Mutation> QueryBuilder<'a, Query, Mutation> {
/// Specify the operation name.
pub fn operator_name(self, name: &'a str) -> Self {
QueryBuilder {
operation_name: Some(name),
..self
}
}
/// Specify the variables.
pub fn variables(self, vars: Variables) -> Self {
QueryBuilder {
variables: Some(vars),
..self
}
}
/// Prepare query
pub fn prepare(self) -> Result<PreparedQuery<'a, Query, Mutation>> {
let document =
parse_query(self.query_source).map_err(|err| QueryParseError(err.to_string()))?;
check_rules(self.registry, &document)?;
let mut fragments = HashMap::new();
let mut selection_set = None;
let mut variable_definitions = None;
let mut root = None;
for definition in document.definitions {
match definition {
Definition::Operation(operation_definition) => match operation_definition {
OperationDefinition::SelectionSet(s) => {
selection_set = Some(s);
root = Some(Root::Query(self.query));
}
OperationDefinition::Query(query)
if query.name.is_none() || query.name.as_deref() == self.operation_name =>
{
selection_set = Some(query.selection_set);
variable_definitions = Some(query.variable_definitions);
root = Some(Root::Query(self.query));
}
OperationDefinition::Mutation(mutation)
if mutation.name.is_none()
|| mutation.name.as_deref() == self.operation_name =>
{
selection_set = Some(mutation.selection_set);
variable_definitions = Some(mutation.variable_definitions);
root = Some(Root::Mutation(self.mutation));
}
OperationDefinition::Subscription(subscription)
if subscription.name.is_none()
|| subscription.name.as_deref() == self.operation_name =>
{
return Err(QueryError::NotSupported.into());
}
_ => {}
},
Definition::Fragment(fragment) => {
fragments.insert(fragment.name.clone(), fragment);
}
}
}
Ok(PreparedQuery {
registry: self.registry,
variables: self.variables.unwrap_or_default(),
data: self.data,
fragments,
selection_set: selection_set.ok_or({
if let Some(name) = self.operation_name {
QueryError::UnknownOperationNamed {
name: name.to_string(),
}
} else {
QueryError::MissingOperation
}
})?,
root: root.unwrap(),
variable_definitions,
})
}
/// Execute the query.
pub async fn execute(self) -> Result<serde_json::Value>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
{
self.prepare()?.execute().await
}
}
pub struct PreparedQuery<'a, Query, Mutation> {
root: Root<'a, Query, Mutation>,
registry: &'a Registry,
variables: Variables,
data: &'a Data,
fragments: HashMap<String, FragmentDefinition>,
selection_set: SelectionSet,
variable_definitions: Option<Vec<VariableDefinition>>,
}
impl<'a, Query, Mutation> PreparedQuery<'a, Query, Mutation> {
/// Detects whether any parameter contains the Upload type
pub fn is_upload(&self) -> bool {
if let Some(variable_definitions) = &self.variable_definitions {
for d in variable_definitions {
if let Some(ty) = self.registry.basic_type_by_parsed_type(&d.var_type) {
if ty.name() == "Upload" {
return true;
}
}
}
}
false
}
/// Set upload files
pub fn set_upload(
&mut self,
var_path: &str,
filename: &str,
content_type: Option<&str>,
content: Vec<u8>,
) {
self.variables
.set_upload(var_path, filename, content_type, content);
}
/// Execute the query.
pub async fn execute(self) -> Result<serde_json::Value>
where
Query: GQLObject + Send + Sync,
Mutation: GQLObject + Send + Sync,
{
let ctx = ContextBase {
item: &self.selection_set,
variables: &self.variables,
variable_definitions: self.variable_definitions.as_deref(),
registry: self.registry.clone(),
data: self.data,
fragments: &self.fragments,
};
match self.root {
Root::Query(query) => return GQLOutputValue::resolve(query, &ctx).await,
Root::Mutation(mutation) => return GQLOutputValue::resolve(mutation, &ctx).await,
pub fn subscribe<'a>(&'a self, source: &'a str) -> SubscribeBuilder<'a, Subscription> {
SubscribeBuilder {
subscription: &self.subscription,
registry: &self.registry,
source,
operation_name: None,
variables: None,
}
}
}

139
src/subscription.rs Normal file
View File

@ -0,0 +1,139 @@
use crate::registry::Registry;
use crate::validation::check_rules;
use crate::{ContextBase, GQLType, QueryError, QueryParseError, Result, Schema, Variables};
use graphql_parser::parse_query;
use graphql_parser::query::{
Definition, Field, FragmentDefinition, OperationDefinition, SelectionSet, VariableDefinition,
};
use std::any::{Any, TypeId};
use std::collections::HashMap;
pub struct Subscribe {
pub types: HashMap<TypeId, Field>,
pub variables: Variables,
pub variable_definitions: Vec<VariableDefinition>,
pub fragments: HashMap<String, FragmentDefinition>,
}
impl Subscribe {
pub async fn resolve<Query, Mutation, Subscription>(
&self,
schema: &Schema<Query, Mutation, Subscription>,
msg: &(dyn Any + Send + Sync),
) -> Result<Option<serde_json::Value>>
where
Subscription: GQLSubscription + Sync + Send + 'static,
{
let ctx = ContextBase::<()> {
item: (),
variables: &self.variables,
variable_definitions: Some(&self.variable_definitions),
registry: &schema.registry,
data: &Default::default(),
fragments: &self.fragments,
};
schema.subscription.resolve(&ctx, &self.types, msg).await
}
}
/// Represents a GraphQL subscription object
#[async_trait::async_trait]
pub trait GQLSubscription: GQLType {
/// This function returns true of type `GQLEmptySubscription` only
#[doc(hidden)]
fn is_empty() -> bool {
return false;
}
fn create_types(selection_set: SelectionSet) -> Result<HashMap<TypeId, Field>>;
fn create_subscribe(
&self,
selection_set: SelectionSet,
variables: Variables,
variable_definitions: Vec<VariableDefinition>,
fragments: HashMap<String, FragmentDefinition>,
) -> Result<Subscribe> {
Ok(Subscribe {
types: Self::create_types(selection_set)?,
variables,
variable_definitions,
fragments,
})
}
/// Resolve a subscription message, If no message of this type is subscribed, None is returned.
async fn resolve(
&self,
ctx: &ContextBase<'_, ()>,
types: &HashMap<TypeId, Field>,
msg: &(dyn Any + Send + Sync),
) -> Result<Option<serde_json::Value>>;
}
pub struct SubscribeBuilder<'a, Subscription> {
pub(crate) subscription: &'a Subscription,
pub(crate) registry: &'a Registry,
pub(crate) source: &'a str,
pub(crate) operation_name: Option<&'a str>,
pub(crate) variables: Option<Variables>,
}
impl<'a, Subscription> SubscribeBuilder<'a, Subscription>
where
Subscription: GQLSubscription,
{
/// Specify the operation name.
pub fn operator_name(self, name: &'a str) -> Self {
SubscribeBuilder {
operation_name: Some(name),
..self
}
}
/// Specify the variables.
pub fn variables(self, vars: Variables) -> Self {
SubscribeBuilder {
variables: Some(vars),
..self
}
}
pub fn execute(self) -> Result<Subscribe> {
let document = parse_query(self.source).map_err(|err| QueryParseError(err.to_string()))?;
check_rules(self.registry, &document)?;
let mut fragments = HashMap::new();
let mut subscription = None;
for definition in document.definitions {
match definition {
Definition::Operation(OperationDefinition::Subscription(s)) => {
if s.name.as_deref() == self.operation_name {
subscription = Some(s);
break;
}
}
Definition::Fragment(fragment) => {
fragments.insert(fragment.name.clone(), fragment);
}
_ => {}
}
}
let subscription = subscription.ok_or(if let Some(name) = self.operation_name {
QueryError::UnknownOperationNamed {
name: name.to_string(),
}
} else {
QueryError::MissingOperation
})?;
self.subscription.create_subscribe(
subscription.selection_set,
self.variables.unwrap_or_default(),
subscription.variable_definitions,
fragments,
)
}
}

View File

@ -0,0 +1,39 @@
use crate::{registry, ContextBase, GQLSubscription, GQLType, QueryError, Result};
use graphql_parser::query::{Field, SelectionSet};
use serde_json::Value;
use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
pub struct GQLEmptySubscription;
impl GQLType for GQLEmptySubscription {
fn type_name() -> Cow<'static, str> {
Cow::Borrowed("EmptyMutation")
}
fn create_type_info(registry: &mut registry::Registry) -> String {
registry.create_type::<Self, _>(|_| registry::Type::Object {
name: "EmptySubscription",
description: None,
fields: Default::default(),
})
}
}
#[async_trait::async_trait]
impl GQLSubscription for GQLEmptySubscription {
fn create_types(_selection_set: SelectionSet) -> Result<HashMap<TypeId, Field, RandomState>> {
return Err(QueryError::NotConfiguredSubscriptions.into());
}
async fn resolve(
&self,
_ctx: &ContextBase<'_, ()>,
_types: &HashMap<TypeId, Field, RandomState>,
_msg: &(dyn Any + Send + Sync),
) -> Result<Option<Value>> {
return Err(QueryError::NotConfiguredSubscriptions.into());
}
}

View File

@ -1,4 +1,5 @@
mod empty_mutation;
mod empty_subscription;
mod r#enum;
mod list;
mod optional;
@ -6,6 +7,7 @@ mod query_root;
mod upload;
pub use empty_mutation::GQLEmptyMutation;
pub use empty_subscription::GQLEmptySubscription;
pub use query_root::QueryRoot;
pub use r#enum::{GQLEnum, GQLEnumItem};
pub use upload::Upload;

View File

@ -32,6 +32,7 @@ pub fn check_rules(registry: &Registry, doc: &Document) -> Result<()> {
.with(rules::VariablesAreInputTypes)
.with(rules::VariableInAllowedPosition::default())
.with(rules::ScalarLeafs)
.with(rules::NoComposeLeafs)
.with(rules::PossibleFragmentSpreads::default())
.with(rules::ProvidedNonNullArguments)
.with(rules::KnownDirectives::default())

View File

@ -26,13 +26,27 @@ impl<'a> Visitor<'a> for LoneAnonymousOperation {
operation_definition: &'a OperationDefinition,
) {
if let Some(operation_count) = self.operation_count {
if let OperationDefinition::SelectionSet(s) = operation_definition {
if operation_count > 1 {
ctx.report_error(
vec![s.span.0, s.span.1],
"This anonymous operation must be the only defined operation",
);
let (err, pos) = match operation_definition {
OperationDefinition::SelectionSet(s) => (operation_count > 1, s.span.0),
OperationDefinition::Query(query) if query.name.is_none() => {
(operation_count > 1, query.position)
}
OperationDefinition::Mutation(mutation) if mutation.name.is_none() => {
(operation_count > 1, mutation.position)
}
OperationDefinition::Subscription(subscription) if subscription.name.is_none() => {
(operation_count > 1, subscription.position)
}
_ => {
return;
}
};
if err {
ctx.report_error(
vec![pos],
"This anonymous operation must be the only defined operation",
);
}
}
}

View File

@ -7,6 +7,7 @@ mod known_directives;
mod known_fragment_names;
mod known_type_names;
mod lone_anonymous_operation;
mod no_compose_leafs;
mod no_fragment_cycles;
mod no_undefined_variables;
mod no_unused_fragments;
@ -32,6 +33,7 @@ pub use known_directives::KnownDirectives;
pub use known_fragment_names::KnownFragmentNames;
pub use known_type_names::KnownTypeNames;
pub use lone_anonymous_operation::LoneAnonymousOperation;
pub use no_compose_leafs::NoComposeLeafs;
pub use no_fragment_cycles::NoFragmentCycles;
pub use no_undefined_variables::NoUndefinedVariables;
pub use no_unused_fragments::NoUnusedFragments;

View File

@ -0,0 +1,27 @@
use crate::validation::context::ValidatorContext;
use crate::validation::visitor::Visitor;
use graphql_parser::query::Field;
#[derive(Default)]
pub struct NoComposeLeafs;
impl<'a> Visitor<'a> for NoComposeLeafs {
fn enter_field(&mut self, ctx: &mut ValidatorContext<'a>, field: &'a Field) {
if let Some(ty) = ctx.parent_type() {
if let Some(schema_field) = ty.field_by_name(&field.name) {
if let Some(ty) = ctx.registry.basic_type_by_typename(&schema_field.ty) {
if ty.is_composite() && field.selection_set.items.is_empty() {
ctx.report_error(
vec![field.position],
format!(
"Field \"{}\" of type \"{}\" must have a selection of subfields",
field.name,
ty.name()
),
)
}
}
}
}
}
}

View File

@ -373,10 +373,18 @@ fn visit_operation_definition<'a, V: Visitor<'a>>(
}
}
OperationDefinition::Subscription(subscription) => {
ctx.report_error(vec![subscription.position], "Not supported.");
// visit_variable_definitions(v, ctx, &subscription.variable_definitions);
// visit_directives(v, ctx, &subscription.directives);
// visit_selection_set(v, ctx, &subscription.selection_set);
if let Some(subscription_type) = &ctx.registry.subscription_type {
ctx.with_type(&ctx.registry.types[subscription_type], |ctx| {
visit_variable_definitions(v, ctx, &subscription.variable_definitions);
visit_directives(v, ctx, &subscription.directives);
visit_selection_set(v, ctx, &subscription.selection_set);
});
} else {
ctx.report_error(
vec![subscription.position],
"Schema is not configured for subscriptions.",
);
}
}
}
v.exit_operation_definition(ctx, operation);

View File

@ -35,7 +35,11 @@ pub async fn test_enum_type() {
}
}
let schema = Schema::new(Root { value: MyEnum::A }, GQLEmptyMutation);
let schema = Schema::new(
Root { value: MyEnum::A },
GQLEmptyMutation,
GQLEmptySubscription,
);
let query = format!(
r#"{{
value

View File

@ -72,7 +72,7 @@ pub async fn test_input_object_default_value() {
}
}
let schema = Schema::new(Root, GQLEmptyMutation);
let schema = Schema::new(Root, GQLEmptyMutation, GQLEmptySubscription);
let query = format!(
r#"{{
a(input:{{e:777}}) {{

View File

@ -39,6 +39,7 @@ pub async fn test_list_type() {
value: vec![1, 2, 3, 4, 5],
},
GQLEmptyMutation,
GQLEmptySubscription,
);
let json_value: serde_json::Value = vec![1, 2, 3, 4, 5].into();
let query = format!(

View File

@ -51,6 +51,7 @@ pub async fn test_optional_type() {
value2: None,
},
GQLEmptyMutation,
GQLEmptySubscription,
);
let query = format!(
r#"{{

View File

@ -31,7 +31,7 @@ macro_rules! test_scalars {
}
}
let schema = Schema::new(Root { value: $value }, GQLEmptyMutation);
let schema = Schema::new(Root { value: $value }, GQLEmptyMutation, GQLEmptySubscription);
let json_value: serde_json::Value = $value.into();
let query = format!("{{ value testArg(input: {0}) testInput(input: {{value: {0}}}) }}", json_value);
assert_eq!(