From 7a464c15eb28a11eefacef951f0b869018775b44 Mon Sep 17 00:00:00 2001 From: Sunli Date: Fri, 24 Sep 2021 09:46:19 +0800 Subject: [PATCH] Add `DataLoader::enable_all_cache` and `DataLoader::enable_cache` methods. #642 --- src/dataloader/mod.rs | 92 ++++++++++++++++++++++++++++++++++++++----- 1 file changed, 83 insertions(+), 9 deletions(-) diff --git a/src/dataloader/mod.rs b/src/dataloader/mod.rs index f44d8949..f4c95a89 100644 --- a/src/dataloader/mod.rs +++ b/src/dataloader/mod.rs @@ -63,6 +63,7 @@ use std::any::{Any, TypeId}; use std::borrow::Cow; use std::collections::{HashMap, HashSet}; use std::hash::Hash; +use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Mutex; use std::time::Duration; @@ -82,6 +83,7 @@ struct Requests> { keys: HashSet, pending: Vec<(HashSet, ResSender)>, cache_storage: Box>, + disable_cache: bool, } type KeysAndSender = (HashSet, Vec<(HashSet, ResSender)>); @@ -92,6 +94,7 @@ impl> Requests keys: Default::default(), pending: Vec::new(), cache_storage: cache_factory.create::(), + disable_cache: false, } } @@ -156,6 +159,7 @@ pub struct DataLoader { delay: Duration, max_batch_size: usize, loader: T, + disable_cache: AtomicBool, } impl DataLoader { @@ -167,6 +171,7 @@ impl DataLoader { delay: Duration::from_millis(1), max_batch_size: 1000, loader, + disable_cache: false.into(), } } } @@ -180,6 +185,7 @@ impl DataLoader { delay: Duration::from_millis(1), max_batch_size: 1000, loader, + disable_cache: false.into(), } } @@ -204,6 +210,27 @@ impl DataLoader { &self.loader } + /// Enable/Disable cache of all loaders. + pub fn enable_all_cache(&self, enable: bool) { + self.disable_cache.store(!enable, Ordering::SeqCst); + } + + /// Enable/Disable cache of specified loader. + pub fn enable_cache(&self, enable: bool) + where + K: Send + Sync + Hash + Eq + Clone + 'static, + T: Loader, + { + let tid = TypeId::of::(); + let mut requests = self.requests.lock().unwrap(); + let typed_requests = requests + .get_mut(&tid) + .unwrap() + .downcast_mut::>() + .unwrap(); + typed_requests.disable_cache = !enable; + } + /// Use this `DataLoader` load a data. pub async fn load_one(&self, key: K) -> Result, T::Error> where @@ -239,19 +266,22 @@ impl DataLoader { T: Loader, { let tid = TypeId::of::(); - let (keys, pending) = { + let (disable_cache, (keys, pending)) = { let mut requests = self.requests.lock().unwrap(); let typed_requests = requests .get_mut(&tid) .unwrap() .downcast_mut::>() .unwrap(); - typed_requests.take() + (typed_requests.disable_cache, typed_requests.take()) }; if !keys.is_empty() { + let disable_cache = disable_cache || self.disable_cache.load(Ordering::SeqCst); do_load(&self.loader, keys, pending, |values| { // Update cache - self.update_cache(values); + if !disable_cache { + self.update_cache(values); + } }) .await; } @@ -283,12 +313,16 @@ impl DataLoader { let mut keys_set = HashSet::new(); let mut use_cache_values = HashMap::new(); - for key in keys { - if let Some(value) = typed_requests.cache_storage.get(&key) { - // Already in cache - use_cache_values.insert(key.clone(), value.clone()); - } else { - keys_set.insert(key); + if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) { + keys_set = keys.into_iter().collect(); + } else { + for key in keys { + if let Some(value) = typed_requests.cache_storage.get(&key) { + // Already in cache + use_cache_values.insert(key.clone(), value.clone()); + } else { + keys_set.insert(key); + } } } @@ -535,4 +569,44 @@ mod tests { vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() ); } + + #[tokio::test] + async fn test_dataloader_disable_all_cache() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the loader + loader.enable_all_cache(false); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); + + // All from the cache + loader.enable_all_cache(true); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() + ); + } + + #[tokio::test] + async fn test_dataloader_disable_cache() { + let loader = DataLoader::with_cache(MyLoader, HashMapCache::default()); + loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await; + + // All from the loader + loader.enable_cache::(false); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() + ); + + // All from the cache + loader.enable_cache::(true); + assert_eq!( + loader.load_many(vec![1, 2, 3]).await.unwrap(), + vec![(1, 10), (2, 20), (3, 30)].into_iter().collect() + ); + } }