diff --git a/Cargo.toml b/Cargo.toml index 045aed8f..6420a5f2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,9 @@ categories = ["network-programming", "asynchronous"] readme = "README.md" [features] -default = ["apollo_tracing", "uuid", "bson", "chrono", "chrono-tz", "log", "multipart", "tracing", "url", "unblock", "string_number"] +default = ["apollo_tracing", "apollo_persisted_queries", "uuid", "bson", "chrono", "chrono-tz", "log", "multipart", "tracing", "url", "unblock", "string_number"] apollo_tracing = ["chrono"] +apollo_persisted_queries = ["lru"] multipart = ["multer", "bytes", "tempfile"] unblock = ["blocking"] string_number = ["num-traits"] @@ -50,6 +51,7 @@ log = { version = "0.4.11", optional = true } tracing = { version = "0.1.19", optional = true } url = { version = "2.1.1", optional = true } num-traits = { version = "0.2.12", optional = true } +lru = { version = "0.6.0", optional = true } bytes = { version = "0.5.4", optional = true } multer = { version = "1.2.2", optional = true } diff --git a/src/error.rs b/src/error.rs index 3792b196..28dca464 100644 --- a/src/error.rs +++ b/src/error.rs @@ -430,6 +430,10 @@ pub enum Error { /// List of errors. errors: RuleErrors, }, + + /// Other errors + #[error("Other error: {0}")] + Other(String), } /// A collection of RuleError. diff --git a/src/extensions/apollo_persisted_queries.rs b/src/extensions/apollo_persisted_queries.rs new file mode 100644 index 00000000..9404bee9 --- /dev/null +++ b/src/extensions/apollo_persisted_queries.rs @@ -0,0 +1,176 @@ +//! Apollo persisted queries extension. + +use crate::extensions::{Error, Extension, ExtensionContext, ExtensionFactory}; +use crate::{Request, Result}; +use futures::lock::Mutex; +use serde::Deserialize; +use std::sync::Arc; + +#[derive(Deserialize)] +struct PersistedQuery { + version: i32, + #[serde(rename = "sha256Hash")] + sha256_hash: String, +} + +/// Cache storage for persisted queries. +#[async_trait::async_trait] +pub trait CacheStorage: Send + Sync + Clone + 'static { + /// Load the query by `key`. + async fn get(&self, key: String) -> Option; + + /// Save the query by `key`. + async fn set(&self, key: String, query: String); +} + +/// Memory-based LRU cache. +#[derive(Clone)] +pub struct LruCacheStorage(Arc>>); + +impl LruCacheStorage { + /// Creates a new LRU Cache that holds at most `cap` items. + pub fn new(cap: usize) -> Self { + Self(Arc::new(Mutex::new(lru::LruCache::new(cap)))) + } +} + +#[async_trait::async_trait] +impl CacheStorage for LruCacheStorage { + async fn get(&self, key: String) -> Option { + let mut cache = self.0.lock().await; + cache.get(&key).cloned() + } + + async fn set(&self, key: String, query: String) { + let mut cache = self.0.lock().await; + cache.put(key, query); + } +} + +/// Apollo persisted queries extension. +/// +/// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/) +#[cfg_attr(feature = "nightly", doc(cfg(feature = "apollo_persisted_queries")))] +pub struct ApolloPersistedQueries(T); + +impl ApolloPersistedQueries { + /// Creates an apollo persisted queries extension. + pub fn new(cache_storage: T) -> ApolloPersistedQueries { + Self(cache_storage) + } +} + +impl ExtensionFactory for ApolloPersistedQueries { + fn create(&self) -> Box { + Box::new(ApolloPersistedQueriesExtension { + storage: self.0.clone(), + }) + } +} + +struct ApolloPersistedQueriesExtension { + storage: T, +} + +#[async_trait::async_trait] +impl Extension for ApolloPersistedQueriesExtension { + async fn prepare_request( + &mut self, + _ctx: &ExtensionContext<'_>, + mut request: Request, + ) -> Result { + if let Some(value) = request.extensions.remove("persistedQuery") { + let persisted_query: PersistedQuery = serde_json::from_value(value).map_err(|_| { + Error::Other("Invalid \"PersistedQuery\" extension configuration.".to_string()) + })?; + if persisted_query.version != 1 { + return Err(Error::Other ( + format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), + )); + } + + if request.query.is_empty() { + if let Some(query) = self.storage.get(persisted_query.sha256_hash).await { + Ok(Request { query, ..request }) + } else { + Err(Error::Other("PersistedQueryNotFound".to_string())) + } + } else { + self.storage + .set(persisted_query.sha256_hash, request.query.clone()) + .await; + Ok(request) + } + } else { + Ok(request) + } + } +} + +#[cfg(test)] +mod tests { + #[async_std::test] + async fn test() { + use super::*; + use crate::*; + + struct Query; + + #[Object(internal)] + impl Query { + async fn value(&self) -> i32 { + 100 + } + } + + let schema = Schema::build(Query, EmptyMutation, EmptySubscription) + .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256))) + .finish(); + + let mut request = Request::new("{ value }"); + request.extensions.insert( + "persistedQuery".to_string(), + serde_json::json!({ + "version": 1, + "sha256Hash": "abc", + }), + ); + + assert_eq!( + schema.execute(request).await.into_result().unwrap().data, + serde_json::json!({ + "value": 100 + }) + ); + + let mut request = Request::new(""); + request.extensions.insert( + "persistedQuery".to_string(), + serde_json::json!({ + "version": 1, + "sha256Hash": "abc", + }), + ); + + assert_eq!( + schema.execute(request).await.into_result().unwrap().data, + serde_json::json!({ + "value": 100 + }) + ); + + let mut request = Request::new(""); + request.extensions.insert( + "persistedQuery".to_string(), + serde_json::json!({ + "version": 1, + "sha256Hash": "def", + }), + ); + + assert_eq!( + schema.execute(request).await.into_result().unwrap_err(), + Error::Other("PersistedQueryNotFound".to_string()) + ); + } +} diff --git a/src/extensions/logger.rs b/src/extensions/logger.rs index 1836d642..76cb5757 100644 --- a/src/extensions/logger.rs +++ b/src/extensions/logger.rs @@ -110,6 +110,12 @@ impl Extension for LoggerExtension { error!(target: "async-graphql", "[ValidationError] pos: [{}], query: \"{}\", variables: {}, {}", locations, self.query, self.variables, error.message) } } + Error::Other(err) => error!( + target: "async-graphql", "[OtherError] query: \"{}\", variables: {}, {}", + self.query, + self.variables, + err + ), } } } diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 54c8012e..e61568f7 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -1,5 +1,7 @@ //! Extensions for schema +#[cfg(feature = "apollo_persisted_queries")] +pub mod apollo_persisted_queries; #[cfg(feature = "apollo_tracing")] mod apollo_tracing; #[cfg(feature = "log")] @@ -8,7 +10,7 @@ mod logger; mod tracing; use crate::context::{QueryPathNode, ResolveId}; -use crate::{Data, FieldResult, Result, Variables}; +use crate::{Data, FieldResult, Request, Result, Variables}; #[cfg(feature = "apollo_tracing")] pub use self::apollo_tracing::ApolloTracing; @@ -84,6 +86,7 @@ pub struct ResolveInfo<'a> { } /// Represents a GraphQL extension +#[async_trait::async_trait] #[allow(unused_variables)] pub trait Extension: Sync + Send + 'static { /// If this extension needs to output data to query results, you need to specify a name. @@ -91,6 +94,15 @@ pub trait Extension: Sync + Send + 'static { None } + /// Called at the prepare request + async fn prepare_request( + &mut self, + ctx: &ExtensionContext<'_>, + request: Request, + ) -> Result { + Ok(request) + } + /// Called at the begin of the parse. fn parse_start( &mut self, @@ -143,7 +155,20 @@ impl ErrorLogger for Result { } } +#[async_trait::async_trait] impl Extension for Extensions { + async fn prepare_request( + &mut self, + ctx: &ExtensionContext<'_>, + request: Request, + ) -> Result { + let mut request = request; + for e in self.0.iter_mut() { + request = e.prepare_request(ctx, request).await?; + } + Ok(request) + } + fn parse_start( &mut self, ctx: &ExtensionContext<'_>, diff --git a/src/request.rs b/src/request.rs index 57646eae..4a6ca6ed 100644 --- a/src/request.rs +++ b/src/request.rs @@ -2,6 +2,7 @@ use crate::parser::types::UploadValue; use crate::{Data, ParseRequestError, Value, Variables}; use serde::{Deserialize, Deserializer}; use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::fs::File; @@ -29,6 +30,10 @@ pub struct Request { /// **This data is only valid for this request** #[serde(skip)] pub data: Data, + + /// The extensions config of the request. + #[serde(default)] + pub extensions: HashMap, } fn deserialize_variables<'de, D: Deserializer<'de>>( @@ -45,6 +50,7 @@ impl Request { operation_name: None, variables: Variables::default(), data: Data::default(), + extensions: Default::default(), } } diff --git a/src/schema.rs b/src/schema.rs index e5b011ab..67088d51 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -321,7 +321,7 @@ where // TODO: Remove the allow #[allow(clippy::type_complexity)] - fn prepare_request(&self, request: Request) -> Result<(QueryEnvInner, CacheControl)> { + async fn prepare_request(&self, request: Request) -> Result<(QueryEnvInner, CacheControl)> { // create extension instances let extensions = spin::Mutex::new(Extensions( self.0 @@ -330,6 +330,18 @@ where .map(|factory| factory.create()) .collect_vec(), )); + + let request = extensions + .lock() + .prepare_request( + &ExtensionContext { + schema_data: &self.env.data, + query_data: &Default::default(), + }, + request, + ) + .await?; + let ctx_extension = ExtensionContext { schema_data: &self.env.data, query_data: &request.data, @@ -451,7 +463,7 @@ where /// Execute an GraphQL query. pub async fn execute(&self, request: impl Into) -> Response { let request = request.into(); - match self.prepare_request(request) { + match self.prepare_request(request).await { Ok((env, cache_control)) => self .execute_once(QueryEnv::new(env)) .await @@ -482,7 +494,7 @@ where async_stream::stream! { let request = request.into(); - let (mut env, cache_control) = match schema.prepare_request(request) { + let (mut env, cache_control) = match schema.prepare_request(request).await { Ok(res) => res, Err(err) => { yield Response::from(err); diff --git a/src/serialize_resp.rs b/src/serialize_resp.rs index a9c267ff..30191b04 100644 --- a/src/serialize_resp.rs +++ b/src/serialize_resp.rs @@ -80,6 +80,13 @@ impl Serialize for Error { } seq.end() } + Error::Other(err) => { + let mut seq = serializer.serialize_seq(Some(1))?; + seq.serialize_element(&serde_json::json! ({ + "message": err, + }))?; + seq.end() + } } } }