Add cache support for DataLoader. #455

This commit is contained in:
Sunli 2021-04-01 15:46:46 +08:00
parent e8c0bfc743
commit 46927c3fc9
3 changed files with 440 additions and 66 deletions

View File

@ -18,7 +18,7 @@ apollo_tracing = ["chrono"]
apollo_persisted_queries = ["lru", "sha2"]
unblock = ["blocking"]
string_number = ["num-traits"]
dataloader = ["futures-timer", "futures-channel"]
dataloader = ["futures-timer", "futures-channel", "lru"]
[dependencies]
async-graphql-derive = { path = "derive", version = "=2.7.2" }

177
src/dataloader/cache.rs Normal file
View File

@ -0,0 +1,177 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::hash::Hash;
use std::marker::PhantomData;
/// Factory for creating cache storage.
pub trait CacheFactory: Send + Sync + 'static {
/// Create a cache storage.
///
/// TODO: When GAT is stable, this memory allocation can be optimized away.
fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static;
}
/// Cache storage for [DataLoader].
pub trait CacheStorage: Send + Sync + 'static {
/// The key type of the record.
type Key: Send + Sync + Clone + Eq + Hash + 'static;
/// The value type of the record.
type Value: Send + Sync + Clone + 'static;
/// Returns a reference to the value of the key in the cache or None if it is not present in the cache.
fn get(&mut self, key: &Self::Key) -> Option<&Self::Value>;
/// Puts a key-value pair into the cache. If the key already exists in the cache, then it updates the key's value.
fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>);
/// Removes the value corresponding to the key from the cache.
fn remove(&mut self, key: &Self::Key);
/// Clears the cache, removing all key-value pairs.
fn clear(&mut self);
}
/// No cache.
pub struct NoCache;
impl CacheFactory for NoCache {
fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
Box::new(NoCacheImpl {
_mark1: PhantomData,
_mark2: PhantomData,
})
}
}
struct NoCacheImpl<K, V> {
_mark1: PhantomData<K>,
_mark2: PhantomData<V>,
}
impl<K, V> CacheStorage for NoCacheImpl<K, V>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
type Key = K;
type Value = V;
#[inline]
fn get(&mut self, _key: &K) -> Option<&V> {
None
}
#[inline]
fn insert(&mut self, _key: Cow<'_, Self::Key>, _val: Cow<'_, Self::Value>) {}
#[inline]
fn remove(&mut self, _key: &K) {}
#[inline]
fn clear(&mut self) {}
}
/// [std::collections::HashMap] cache.
pub struct HashMapCache;
impl CacheFactory for HashMapCache {
fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
Box::new(HashMapCacheImpl(Default::default()))
}
}
struct HashMapCacheImpl<K, V>(HashMap<K, V>);
impl<K, V> CacheStorage for HashMapCacheImpl<K, V>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
type Key = K;
type Value = V;
#[inline]
fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> {
self.0.get(key)
}
#[inline]
fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) {
self.0.insert(key.into_owned(), val.into_owned());
}
#[inline]
fn remove(&mut self, key: &Self::Key) {
self.0.remove(key);
}
#[inline]
fn clear(&mut self) {
self.0.clear();
}
}
/// LRU cache.
pub struct LruCache {
cap: usize,
}
impl LruCache {
/// Creates a new LRU Cache that holds at most `cap` items.
pub fn new(cap: usize) -> Self {
Self { cap }
}
}
impl CacheFactory for LruCache {
fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
Box::new(LruCacheImpl(lru::LruCache::new(self.cap)))
}
}
struct LruCacheImpl<K, V>(lru::LruCache<K, V>);
impl<K, V> CacheStorage for LruCacheImpl<K, V>
where
K: Send + Sync + Clone + Eq + Hash + 'static,
V: Send + Sync + Clone + 'static,
{
type Key = K;
type Value = V;
#[inline]
fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> {
self.0.get(key)
}
#[inline]
fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) {
self.0.put(key.into_owned(), val.into_owned());
}
#[inline]
fn remove(&mut self, key: &Self::Key) {
self.0.pop(key);
}
#[inline]
fn clear(&mut self) {
self.0.clear();
}
}

View File

@ -57,53 +57,77 @@
//!
//! ```
mod cache;
use std::any::{Any, TypeId};
use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::sync::Mutex;
use std::time::Duration;
use futures_channel::oneshot;
use futures_timer::Delay;
use futures_util::lock::Mutex;
use fnv::FnvHashMap;
#[allow(clippy::type_complexity)]
struct ResSender<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>>(
oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
);
pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
struct Requests<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> {
keys: HashSet<K>,
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
#[allow(clippy::type_complexity)]
struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
use_cache_values: HashMap<K, T::Value>,
tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
}
impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Default for Requests<K, T> {
fn default() -> Self {
struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
keys: HashSet<K>,
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
}
impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
fn new<C: CacheFactory>(cache_factory: &C) -> Self {
Self {
keys: Default::default(),
pending: Default::default(),
pending: Vec::new(),
cache_storage: cache_factory.create::<K, T::Value>(),
}
}
fn take(&mut self) -> (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>) {
(
std::mem::take(&mut self.keys),
std::mem::take(&mut self.pending),
)
}
}
impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
async fn load(self, loader: &T) {
let keys = self.keys.into_iter().collect::<Vec<_>>();
match loader.load(&keys).await {
Ok(values) => {
for (keys, tx) in self.pending {
let mut res = HashMap::new();
for key in &keys {
res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
}
tx.0.send(Ok(res)).ok();
async fn do_load<K, T, F>(
loader: &T,
keys: HashSet<K>,
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
f: F,
) where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
F: FnOnce(&HashMap<K, T::Value>),
{
let keys = keys.into_iter().collect::<Vec<_>>();
match loader.load(&keys).await {
Ok(values) => {
f(&values);
for (keys, sender) in pending {
let mut res = HashMap::new();
res.extend(sender.use_cache_values);
for key in &keys {
res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
}
sender.tx.send(Ok(res)).ok();
}
Err(err) => {
for (_, tx) in self.pending {
tx.0.send(Err(err.clone())).ok();
}
}
Err(err) => {
for (_, sender) in pending {
sender.tx.send(Err(err.clone())).ok();
}
}
}
@ -111,9 +135,9 @@ impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
/// Trait for batch loading.
#[async_trait::async_trait]
pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
/// type of value.
type Value: Send + Clone + 'static;
type Value: Send + Sync + Clone + 'static;
/// Type of error.
type Error: Send + Clone + 'static;
@ -125,17 +149,32 @@ pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
/// Data loader.
///
/// Reference: <https://github.com/facebook/dataloader>
pub struct DataLoader<T> {
pub struct DataLoader<T, C = NoCache> {
requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
cache_factory: C,
delay: Duration,
max_batch_size: usize,
loader: T,
}
impl<T> DataLoader<T> {
/// Create a DataLoader with the `Loader` trait.
impl<T> DataLoader<T, NoCache> {
/// Use `Loader` to create a [DataLoader] that does not cache records.
pub fn new(loader: T) -> Self {
Self {
cache_factory: NoCache,
requests: Default::default(),
delay: Duration::from_millis(1),
max_batch_size: 1000,
loader,
}
}
}
impl<T, C: CacheFactory> DataLoader<T, C> {
/// Use `Loader` to create a [DataLoader] with a cache factory.
pub fn with_cache(loader: T, cache_factory: C) -> Self {
Self {
cache_factory,
requests: Default::default(),
delay: Duration::from_millis(1),
max_batch_size: 1000,
@ -174,56 +213,177 @@ impl<T> DataLoader<T> {
Ok(values.remove(&key))
}
/// Use this `DataLoader` to load some data.
pub async fn load_many<K>(
&self,
keys: impl Iterator<Item = K>,
) -> Result<HashMap<K, T::Value>, T::Error>
fn update_cache<K>(&self, values: &HashMap<K, T::Value>)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.get_mut(&tid)
.unwrap()
.downcast_mut::<Requests<K, T>>()
.unwrap();
for (key, value) in values {
typed_requests
.cache_storage
.insert(Cow::Borrowed(key), Cow::Borrowed(value));
}
}
let (start_fetch, rx) = {
let mut requests = self.requests.lock().await;
let typed_requests = requests
.entry(tid)
.or_insert_with(|| Box::new(Requests::<K, T>::default()))
.downcast_mut::<Requests<K, T>>()
.unwrap();
let prev_count = typed_requests.keys.len();
let keys = keys.collect::<HashSet<_>>();
typed_requests.keys.extend(keys.clone());
let (tx, rx) = oneshot::channel();
typed_requests.pending.push((keys, ResSender(tx)));
if typed_requests.keys.len() >= self.max_batch_size {
let r = std::mem::take(&mut *typed_requests);
drop(requests);
r.load(&self.loader).await;
(false, rx)
} else {
(!typed_requests.keys.is_empty() && prev_count == 0, rx)
}
};
if start_fetch {
Delay::new(self.delay).await;
let mut requests = self.requests.lock().await;
async fn immediate_load<K>(&self)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let (keys, pending) = {
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.get_mut(&tid)
.unwrap()
.downcast_mut::<Requests<K, T>>()
.unwrap();
let typed_requests = std::mem::take(typed_requests);
drop(requests);
if !typed_requests.keys.is_empty() {
typed_requests.load(&self.loader).await;
typed_requests.take()
};
if !keys.is_empty() {
do_load(&self.loader, keys, pending, |values| {
// Update cache
self.update_cache(values);
})
.await;
}
}
/// Use this `DataLoader` to load some data.
pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
where
K: Send + Sync + Hash + Eq + Clone + 'static,
I: IntoIterator<Item = K>,
T: Loader<K>,
{
enum Action {
ImmediateLoad,
StartFetch,
Delay,
}
let tid = TypeId::of::<K>();
let (action, rx) = {
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.entry(tid)
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
.downcast_mut::<Requests<K, T>>()
.unwrap();
let prev_count = typed_requests.keys.len();
let mut keys_set = HashSet::new();
let mut use_cache_values = HashMap::new();
for key in keys {
if let Some(value) = typed_requests.cache_storage.get(&key) {
// Already in cache
use_cache_values.insert(key.clone(), value.clone());
} else {
keys_set.insert(key);
}
}
if !use_cache_values.is_empty() && keys_set.is_empty() {
return Ok(use_cache_values);
} else if use_cache_values.is_empty() && keys_set.is_empty() {
return Ok(Default::default());
}
typed_requests.keys.extend(keys_set.clone());
let (tx, rx) = oneshot::channel();
typed_requests.pending.push((
keys_set,
ResSender {
use_cache_values,
tx,
},
));
if typed_requests.keys.len() >= self.max_batch_size {
(Action::ImmediateLoad, rx)
} else {
(
if !typed_requests.keys.is_empty() && prev_count == 0 {
Action::StartFetch
} else {
Action::Delay
},
rx,
)
}
};
match action {
Action::ImmediateLoad => self.immediate_load::<K>().await,
Action::StartFetch => {
Delay::new(self.delay).await;
self.immediate_load::<K>().await;
}
Action::Delay => {}
}
rx.await.unwrap()
}
/// Feed some data into the cache.
///
/// **NOTE: If the cache type is [NoCache], this function will not take effect. **
pub async fn feed_many<K, I>(&self, values: I)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
I: IntoIterator<Item = (K, T::Value)>,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.entry(tid)
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
.downcast_mut::<Requests<K, T>>()
.unwrap();
for (key, value) in values {
typed_requests
.cache_storage
.insert(Cow::Owned(key), Cow::Owned(value));
}
}
/// Feed some data into the cache.
///
/// **NOTE: If the cache type is [NoCache], this function will not take effect. **
pub async fn feed_one<K>(&self, key: K, value: T::Value)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
self.feed_many(std::iter::once((key, value))).await;
}
/// Clears the cache.
///
/// **NOTE: If the cache type is [NoCache], this function will not take effect. **
pub fn clear<K>(&self)
where
K: Send + Sync + Hash + Eq + Clone + 'static,
T: Loader<K>,
{
let tid = TypeId::of::<K>();
let mut requests = self.requests.lock().unwrap();
let typed_requests = requests
.entry(tid)
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
.downcast_mut::<Requests<K, T>>()
.unwrap();
typed_requests.cache_storage.clear();
}
}
#[cfg(test)]
@ -305,4 +465,41 @@ mod tests {
.collect::<Vec<_>>()
);
}
#[tokio::test]
async fn test_dataloader_load_empty() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache);
assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
}
#[tokio::test]
async fn test_dataloader_with_cache() {
let loader = DataLoader::with_cache(MyLoader, HashMapCache);
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
// All from the cache
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
);
// Part from the cache
assert_eq!(
loader.load_many(vec![1, 5, 6]).await.unwrap(),
vec![(1, 10), (5, 5), (6, 6)].into_iter().collect()
);
// All from the loader
assert_eq!(
loader.load_many(vec![8, 9, 10]).await.unwrap(),
vec![(8, 8), (9, 9), (10, 10)].into_iter().collect()
);
// Clear cache
loader.clear::<i32>();
assert_eq!(
loader.load_many(vec![1, 2, 3]).await.unwrap(),
vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
);
}
}