//! Apollo persisted queries extension. use std::sync::Arc; use async_graphql_parser::types::ExecutableDocument; use futures_util::lock::Mutex; use serde::Deserialize; use sha2::{Digest, Sha256}; use crate::{ extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest}, from_value, Request, ServerError, ServerResult, }; #[derive(Deserialize)] struct PersistedQuery { version: i32, #[serde(rename = "sha256Hash")] sha256_hash: String, } /// Cache storage for persisted queries. #[async_trait::async_trait] pub trait CacheStorage: Send + Sync + Clone + 'static { /// Load the query by `key`. async fn get(&self, key: String) -> Option; /// Save the query by `key`. async fn set(&self, key: String, query: ExecutableDocument); } /// Memory-based LRU cache. #[derive(Clone)] pub struct LruCacheStorage(Arc>>); impl LruCacheStorage { /// Creates a new LRU Cache that holds at most `cap` items. pub fn new(cap: usize) -> Self { Self(Arc::new(Mutex::new(lru::LruCache::new(cap)))) } } #[async_trait::async_trait] impl CacheStorage for LruCacheStorage { async fn get(&self, key: String) -> Option { let mut cache = self.0.lock().await; cache.get(&key).cloned() } async fn set(&self, key: String, query: ExecutableDocument) { let mut cache = self.0.lock().await; cache.put(key, query); } } /// Apollo persisted queries extension. /// /// [Reference](https://www.apollographql.com/docs/react/api/link/persisted-queries/) #[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))] pub struct ApolloPersistedQueries(T); impl ApolloPersistedQueries { /// Creates an apollo persisted queries extension. pub fn new(cache_storage: T) -> ApolloPersistedQueries { Self(cache_storage) } } impl ExtensionFactory for ApolloPersistedQueries { fn create(&self) -> Arc { Arc::new(ApolloPersistedQueriesExtension { storage: self.0.clone(), }) } } struct ApolloPersistedQueriesExtension { storage: T, } #[async_trait::async_trait] impl Extension for ApolloPersistedQueriesExtension { async fn prepare_request( &self, ctx: &ExtensionContext<'_>, mut request: Request, next: NextPrepareRequest<'_>, ) -> ServerResult { let res = if let Some(value) = request.extensions.remove("persistedQuery") { let persisted_query: PersistedQuery = from_value(value).map_err(|_| { ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None) })?; if persisted_query.version != 1 { return Err(ServerError::new( format!("Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".", persisted_query.version), None )); } if request.query.is_empty() { if let Some(doc) = self.storage.get(persisted_query.sha256_hash).await { Ok(Request { parsed_query: Some(doc), ..request }) } else { Err(ServerError::new("PersistedQueryNotFound", None)) } } else { let sha256_hash = format!("{:x}", Sha256::digest(request.query.as_bytes())); if persisted_query.sha256_hash != sha256_hash { Err(ServerError::new("provided sha does not match query", None)) } else { let doc = async_graphql_parser::parse_query(&request.query)?; self.storage.set(sha256_hash, doc.clone()).await; Ok(Request { query: String::new(), parsed_query: Some(doc), ..request }) } } } else { Ok(request) }; next.run(ctx, res?).await } } #[cfg(test)] mod tests { #[tokio::test] async fn test() { use super::*; use crate::*; struct Query; #[Object(internal)] impl Query { async fn value(&self) -> i32 { 100 } } let schema = Schema::build(Query, EmptyMutation, EmptySubscription) .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256))) .finish(); let mut request = Request::new("{ value }"); request.extensions.insert( "persistedQuery".to_string(), value!({ "version": 1, "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b", }), ); assert_eq!( schema.execute(request).await.into_result().unwrap().data, value!({ "value": 100 }) ); let mut request = Request::new(""); request.extensions.insert( "persistedQuery".to_string(), value!({ "version": 1, "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b", }), ); assert_eq!( schema.execute(request).await.into_result().unwrap().data, value!({ "value": 100 }) ); let mut request = Request::new(""); request.extensions.insert( "persistedQuery".to_string(), value!({ "version": 1, "sha256Hash": "def", }), ); assert_eq!( schema.execute(request).await.into_result().unwrap_err(), vec![ServerError::new("PersistedQueryNotFound", None)] ); } }