Add DataLoader::enable_all_cache and DataLoader::enable_cache methods. #642

This commit is contained in:
Sunli 2021-09-24 09:46:19 +08:00
parent 5d03592d98
commit 7a464c15eb

View File

@ -63,6 +63,7 @@ use std::any::{Any, TypeId};
use std::borrow::Cow; use std::borrow::Cow;
use std::collections::{HashMap, HashSet}; use std::collections::{HashMap, HashSet};
use std::hash::Hash; use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex; use std::sync::Mutex;
use std::time::Duration; use std::time::Duration;
@ -82,6 +83,7 @@ struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
keys: HashSet<K>, keys: HashSet<K>,
pending: Vec<(HashSet<K>, ResSender<K, T>)>, pending: Vec<(HashSet<K>, ResSender<K, T>)>,
cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>, cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
disable_cache: bool,
} }
type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>); type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
@ -92,6 +94,7 @@ impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T>
keys: Default::default(), keys: Default::default(),
pending: Vec::new(), pending: Vec::new(),
cache_storage: cache_factory.create::<K, T::Value>(), cache_storage: cache_factory.create::<K, T::Value>(),
disable_cache: false,
} }
} }
@ -156,6 +159,7 @@ pub struct DataLoader<T, C = NoCache> {
delay: Duration, delay: Duration,
max_batch_size: usize, max_batch_size: usize,
loader: T, loader: T,
disable_cache: AtomicBool,
} }
impl<T> DataLoader<T, NoCache> { impl<T> DataLoader<T, NoCache> {
@ -167,6 +171,7 @@ impl<T> DataLoader<T, NoCache> {
delay: Duration::from_millis(1), delay: Duration::from_millis(1),
max_batch_size: 1000, max_batch_size: 1000,
loader, loader,
disable_cache: false.into(),
} }
} }
} }
@ -180,6 +185,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
delay: Duration::from_millis(1), delay: Duration::from_millis(1),
max_batch_size: 1000, max_batch_size: 1000,
loader, loader,
disable_cache: false.into(),
} }
} }
@ -204,6 +210,27 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
&self.loader &self.loader
} }
/// Enable/Disable cache of all loaders.
pub fn enable_all_cache(&self, enable: bool) {
self.disable_cache.store(!enable, Ordering::SeqCst);
}
/// Enable/Disable cache of specified loader.
pub fn enable_cache<K>(&self, enable: bool)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.get_mut(&tid)
.unwrap()
.downcast_mut::<Requests<K, T>>()
.unwrap();
typed_requests.disable_cache = !enable;
}
/// Use this `DataLoader` load a data. /// Use this `DataLoader` load a data.
pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error> pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
where where
@ -239,19 +266,22 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
T: Loader<K>, T: Loader<K>,
{ {
let tid = TypeId::of::<K>(); let tid = TypeId::of::<K>();
let (keys, pending) = { let (disable_cache, (keys, pending)) = {
let mut requests = self.requests.lock().unwrap(); let mut requests = self.requests.lock().unwrap();
let typed_requests = requests let typed_requests = requests
.get_mut(&tid) .get_mut(&tid)
.unwrap() .unwrap()
.downcast_mut::<Requests<K, T>>() .downcast_mut::<Requests<K, T>>()
.unwrap(); .unwrap();
typed_requests.take() (typed_requests.disable_cache, typed_requests.take())
}; };
if !keys.is_empty() { if !keys.is_empty() {
let disable_cache = disable_cache || self.disable_cache.load(Ordering::SeqCst);
do_load(&self.loader, keys, pending, |values| { do_load(&self.loader, keys, pending, |values| {
// Update cache // Update cache
self.update_cache(values); if !disable_cache {
self.update_cache(values);
}
}) })
.await; .await;
} }
@ -283,12 +313,16 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
let mut keys_set = HashSet::new(); let mut keys_set = HashSet::new();
let mut use_cache_values = HashMap::new(); let mut use_cache_values = HashMap::new();
for key in keys { if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
if let Some(value) = typed_requests.cache_storage.get(&key) { keys_set = keys.into_iter().collect();
// Already in cache } else {
use_cache_values.insert(key.clone(), value.clone()); for key in keys {
} else { if let Some(value) = typed_requests.cache_storage.get(&key) {
keys_set.insert(key); // Already in cache
use_cache_values.insert(key.clone(), value.clone());
} else {
keys_set.insert(key);
}
} }
} }
@ -535,4 +569,44 @@ mod tests {
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect() vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
); );
} }
#[tokio::test]
async fn test_dataloader_disable_all_cache() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
// All from the loader
loader.enable_all_cache(false);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
// All from the cache
loader.enable_all_cache(true);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
);
}
#[tokio::test]
async fn test_dataloader_disable_cache() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
// All from the loader
loader.enable_cache::<i32>(false);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
// All from the cache
loader.enable_cache::<i32>(true);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
);
}
} }