Add DataLoader::enable_all_cache and DataLoader::enable_cache methods. #642

This commit is contained in:
Sunli 2021-09-24 09:46:19 +08:00
parent 5d03592d98
commit 7a464c15eb

View File

@ -63,6 +63,7 @@ use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Mutex;
use std::time::Duration;
@ -82,6 +83,7 @@ struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
keys: HashSet<K>,
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
disable_cache: bool,
}
type KeysAndSender<K, T> = (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>);
@ -92,6 +94,7 @@ impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T>
keys: Default::default(),
pending: Vec::new(),
cache_storage: cache_factory.create::<K, T::Value>(),
disable_cache: false,
}
}
@ -156,6 +159,7 @@ pub struct DataLoader<T, C = NoCache> {
delay: Duration,
max_batch_size: usize,
loader: T,
disable_cache: AtomicBool,
}
impl<T> DataLoader<T, NoCache> {
@ -167,6 +171,7 @@ impl<T> DataLoader<T, NoCache> {
delay: Duration::from_millis(1),
max_batch_size: 1000,
loader,
disable_cache: false.into(),
}
}
}
@ -180,6 +185,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
delay: Duration::from_millis(1),
max_batch_size: 1000,
loader,
disable_cache: false.into(),
}
}
@ -204,6 +210,27 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
&self.loader
}
/// Enable/Disable cache of all loaders.
pub fn enable_all_cache(&self, enable: bool) {
self.disable_cache.store(!enable, Ordering::SeqCst);
}
/// Enable/Disable cache of specified loader.
pub fn enable_cache<K>(&self, enable: bool)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.get_mut(&tid)
.unwrap()
.downcast_mut::<Requests<K, T>>()
.unwrap();
typed_requests.disable_cache = !enable;
}
/// Use this `DataLoader` load a data.
pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
where
@ -239,19 +266,22 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let (keys, pending) = {
let (disable_cache, (keys, pending)) = {
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.get_mut(&tid)
.unwrap()
.downcast_mut::<Requests<K, T>>()
.unwrap();
typed_requests.take()
(typed_requests.disable_cache, typed_requests.take())
};
if !keys.is_empty() {
let disable_cache = disable_cache || self.disable_cache.load(Ordering::SeqCst);
do_load(&self.loader, keys, pending, |values| {
// Update cache
self.update_cache(values);
if !disable_cache {
self.update_cache(values);
}
})
.await;
}
@ -283,12 +313,16 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
let mut keys_set = HashSet::new();
let mut use_cache_values = HashMap::new();
for key in keys {
if let Some(value) = typed_requests.cache_storage.get(&key) {
// Already in cache
use_cache_values.insert(key.clone(), value.clone());
} else {
keys_set.insert(key);
if typed_requests.disable_cache || self.disable_cache.load(Ordering::SeqCst) {
keys_set = keys.into_iter().collect();
} else {
for key in keys {
if let Some(value) = typed_requests.cache_storage.get(&key) {
// Already in cache
use_cache_values.insert(key.clone(), value.clone());
} else {
keys_set.insert(key);
}
}
}
@ -535,4 +569,44 @@ mod tests {
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
}
#[tokio::test]
async fn test_dataloader_disable_all_cache() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
// All from the loader
loader.enable_all_cache(false);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
// All from the cache
loader.enable_all_cache(true);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
);
}
#[tokio::test]
async fn test_dataloader_disable_cache() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
// All from the loader
loader.enable_cache::<i32>(false);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
// All from the cache
loader.enable_cache::<i32>(true);
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
);
}
}