Added Apollo persisted queries extension. #280

This commit is contained in:
Sunli 2020-09-30 11:37:12 +08:00
parent 7ca82d9a9f
commit a95a532857
8 changed files with 243 additions and 5 deletions

View File

@ -14,8 +14,9 @@ categories = ["network-programming", "asynchronous"]
readme = "README.md"
[features]
default = ["apollo_tracing", "uuid", "bson", "chrono", "chrono-tz", "log", "multipart", "tracing", "url", "unblock", "string_number"]
default = ["apollo_tracing", "apollo_persisted_queries", "uuid", "bson", "chrono", "chrono-tz", "log", "multipart", "tracing", "url", "unblock", "string_number"]
apollo_tracing = ["chrono"]
apollo_persisted_queries = ["lru"]
multipart = ["multer", "bytes", "tempfile"]
unblock = ["blocking"]
string_number = ["num-traits"]
@ -50,6 +51,7 @@ log = { version = "0.4.11", optional = true }
tracing = { version = "0.1.19", optional = true }
url = { version = "2.1.1", optional = true }
num-traits = { version = "0.2.12", optional = true }
lru = { version = "0.6.0", optional = true }
bytes = { version = "0.5.4", optional = true }
multer = { version = "1.2.2", optional = true }

View File

@ -430,6 +430,10 @@ pub enum Error {
/// List of errors.
errors: RuleErrors,
},
/// Other errors
#[error("Other error: {0}")]
Other(String),
}
/// A collection of RuleError.

View File

@ -0,0 +1,176 @@
//! Apollo persisted queries extension.
use crate::extensions::{Error, Extension, ExtensionContext, ExtensionFactory};
use crate::{Request, Result};
use futures::lock::Mutex;
use serde::Deserialize;
use std::sync::Arc;
#[derive(Deserialize)]
struct PersistedQuery {
version: i32,
#[serde(rename = "sha256Hash")]
sha256_hash: String,
}
/// Cache storage for persisted queries.
#[async_trait::async_trait]
pub trait CacheStorage: Send + Sync + Clone + 'static {
/// Load the query by `key`.
async fn get(&self, key: String) -> Option<String>;
/// Save the query by `key`.
async fn set(&self, key: String, query: String);
}
/// Memory-based LRU cache.
#[derive(Clone)]
pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, String>>>);
impl LruCacheStorage {
/// Creates a new LRU Cache that holds at most `cap` items.
pub fn new(cap: usize) -> Self {
Self(Arc::new(Mutex::new(lru::LruCache::new(cap))))
}
}
#[async_trait::async_trait]
impl CacheStorage for LruCacheStorage {
async fn get(&self, key: String) -> Option<String> {
let mut cache = self.0.lock().await;
cache.get(&key).cloned()
}
async fn set(&self, key: String, query: String) {
let mut cache = self.0.lock().await;
cache.put(key, query);
}
}
/// Apollo persisted queries extension.
///
/// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/)
#[cfg_attr(feature = "nightly", doc(cfg(feature = "apollo_persisted_queries")))]
pub struct ApolloPersistedQueries<T>(T);
impl<T: CacheStorage> ApolloPersistedQueries<T> {
/// Creates an apollo persisted queries extension.
pub fn new(cache_storage: T) -> ApolloPersistedQueries<T> {
Self(cache_storage)
}
}
impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
fn create(&self) -> Box<dyn Extension> {
Box::new(ApolloPersistedQueriesExtension {
storage: self.0.clone(),
})
}
}
struct ApolloPersistedQueriesExtension<T> {
storage: T,
}
#[async_trait::async_trait]
impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
async fn prepare_request(
&mut self,
_ctx: &ExtensionContext<'_>,
mut request: Request,
) -> Result<Request> {
if let Some(value) = request.extensions.remove("persistedQuery") {
let persisted_query: PersistedQuery = serde_json::from_value(value).map_err(|_| {
Error::Other("Invalid \"PersistedQuery\" extension configuration.".to_string())
})?;
if persisted_query.version != 1 {
return Err(Error::Other (
format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version),
));
}
if request.query.is_empty() {
if let Some(query) = self.storage.get(persisted_query.sha256_hash).await {
Ok(Request { query, ..request })
} else {
Err(Error::Other("PersistedQueryNotFound".to_string()))
}
} else {
self.storage
.set(persisted_query.sha256_hash, request.query.clone())
.await;
Ok(request)
}
} else {
Ok(request)
}
}
}
#[cfg(test)]
mod tests {
#[async_std::test]
async fn test() {
use super::*;
use crate::*;
struct Query;
#[Object(internal)]
impl Query {
async fn value(&self) -> i32 {
100
}
}
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
.extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
.finish();
let mut request = Request::new("{ value }");
request.extensions.insert(
"persistedQuery".to_string(),
serde_json::json!({
"version": 1,
"sha256Hash": "abc",
}),
);
assert_eq!(
schema.execute(request).await.into_result().unwrap().data,
serde_json::json!({
"value": 100
})
);
let mut request = Request::new("");
request.extensions.insert(
"persistedQuery".to_string(),
serde_json::json!({
"version": 1,
"sha256Hash": "abc",
}),
);
assert_eq!(
schema.execute(request).await.into_result().unwrap().data,
serde_json::json!({
"value": 100
})
);
let mut request = Request::new("");
request.extensions.insert(
"persistedQuery".to_string(),
serde_json::json!({
"version": 1,
"sha256Hash": "def",
}),
);
assert_eq!(
schema.execute(request).await.into_result().unwrap_err(),
Error::Other("PersistedQueryNotFound".to_string())
);
}
}

View File

@ -110,6 +110,12 @@ impl Extension for LoggerExtension {
error!(target: "async-graphql", "[ValidationError] pos: [{}], query: \"{}\", variables: {}, {}", locations, self.query, self.variables, error.message)
}
}
Error::Other(err) => error!(
target: "async-graphql", "[OtherError] query: \"{}\", variables: {}, {}",
self.query,
self.variables,
err
),
}
}
}

View File

@ -1,5 +1,7 @@
//! Extensions for schema
#[cfg(feature = "apollo_persisted_queries")]
pub mod apollo_persisted_queries;
#[cfg(feature = "apollo_tracing")]
mod apollo_tracing;
#[cfg(feature = "log")]
@ -8,7 +10,7 @@ mod logger;
mod tracing;
use crate::context::{QueryPathNode, ResolveId};
use crate::{Data, FieldResult, Result, Variables};
use crate::{Data, FieldResult, Request, Result, Variables};
#[cfg(feature = "apollo_tracing")]
pub use self::apollo_tracing::ApolloTracing;
@ -84,6 +86,7 @@ pub struct ResolveInfo<'a> {
}
/// Represents a GraphQL extension
#[async_trait::async_trait]
#[allow(unused_variables)]
pub trait Extension: Sync + Send + 'static {
/// If this extension needs to output data to query results, you need to specify a name.
@ -91,6 +94,15 @@ pub trait Extension: Sync + Send + 'static {
None
}
/// Called at the prepare request
async fn prepare_request(
&mut self,
ctx: &ExtensionContext<'_>,
request: Request,
) -> Result<Request> {
Ok(request)
}
/// Called at the begin of the parse.
fn parse_start(
&mut self,
@ -143,7 +155,20 @@ impl<T> ErrorLogger for Result<T> {
}
}
#[async_trait::async_trait]
impl Extension for Extensions {
async fn prepare_request(
&mut self,
ctx: &ExtensionContext<'_>,
request: Request,
) -> Result<Request> {
let mut request = request;
for e in self.0.iter_mut() {
request = e.prepare_request(ctx, request).await?;
}
Ok(request)
}
fn parse_start(
&mut self,
ctx: &ExtensionContext<'_>,

View File

@ -2,6 +2,7 @@ use crate::parser::types::UploadValue;
use crate::{Data, ParseRequestError, Value, Variables};
use serde::{Deserialize, Deserializer};
use std::any::Any;
use std::collections::HashMap;
use std::fmt::{self, Debug, Formatter};
use std::fs::File;
@ -29,6 +30,10 @@ pub struct Request {
/// **This data is only valid for this request**
#[serde(skip)]
pub data: Data,
/// The extensions config of the request.
#[serde(default)]
pub extensions: HashMap<String, serde_json::Value>,
}
fn deserialize_variables<'de, D: Deserializer<'de>>(
@ -45,6 +50,7 @@ impl Request {
operation_name: None,
variables: Variables::default(),
data: Data::default(),
extensions: Default::default(),
}
}

View File

@ -321,7 +321,7 @@ where
// TODO: Remove the allow
#[allow(clippy::type_complexity)]
fn prepare_request(&self, request: Request) -> Result<(QueryEnvInner, CacheControl)> {
async fn prepare_request(&self, request: Request) -> Result<(QueryEnvInner, CacheControl)> {
// create extension instances
let extensions = spin::Mutex::new(Extensions(
self.0
@ -330,6 +330,18 @@ where
.map(|factory| factory.create())
.collect_vec(),
));
let request = extensions
.lock()
.prepare_request(
&ExtensionContext {
schema_data: &self.env.data,
query_data: &Default::default(),
},
request,
)
.await?;
let ctx_extension = ExtensionContext {
schema_data: &self.env.data,
query_data: &request.data,
@ -451,7 +463,7 @@ where
/// Execute an GraphQL query.
pub async fn execute(&self, request: impl Into<Request>) -> Response {
let request = request.into();
match self.prepare_request(request) {
match self.prepare_request(request).await {
Ok((env, cache_control)) => self
.execute_once(QueryEnv::new(env))
.await
@ -482,7 +494,7 @@ where
async_stream::stream! {
let request = request.into();
let (mut env, cache_control) = match schema.prepare_request(request) {
let (mut env, cache_control) = match schema.prepare_request(request).await {
Ok(res) => res,
Err(err) => {
yield Response::from(err);

View File

@ -80,6 +80,13 @@ impl Serialize for Error {
}
seq.end()
}
Error::Other(err) => {
let mut seq = serializer.serialize_seq(Some(1))?;
seq.serialize_element(&serde_json::json! ({
"message": err,
}))?;
seq.end()
}
}
}
}