diff --git a/Cargo.toml b/Cargo.toml
index 5f765f4b..6f9ae4c7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,7 +18,7 @@ apollo_tracing = ["chrono"]
 apollo_persisted_queries = ["lru", "sha2"]
 unblock = ["blocking"]
 string_number = ["num-traits"]
-dataloader = ["futures-timer", "futures-channel"]
+dataloader = ["futures-timer", "futures-channel", "lru"]
 
 [dependencies]
 async-graphql-derive = { path = "derive", version = "=2.7.2" }
diff --git a/src/dataloader/cache.rs b/src/dataloader/cache.rs
new file mode 100644
index 00000000..c9a92ad9
--- /dev/null
+++ b/src/dataloader/cache.rs
@@ -0,0 +1,177 @@
+use std::borrow::Cow;
+use std::collections::HashMap;
+use std::hash::Hash;
+use std::marker::PhantomData;
+
+/// Factory for creating cache storage.
+pub trait CacheFactory: Send + Sync + 'static {
+    /// Create a cache storage.
+    ///
+    /// TODO: When GAT is stable, this memory allocation can be optimized away.
+    fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
+    where
+        K: Send + Sync + Clone + Eq + Hash + 'static,
+        V: Send + Sync + Clone + 'static;
+}
+
+/// Cache storage for [DataLoader].
+pub trait CacheStorage: Send + Sync + 'static {
+    /// The key type of the record.
+    type Key: Send + Sync + Clone + Eq + Hash + 'static;
+
+    /// The value type of the record.
+    type Value: Send + Sync + Clone + 'static;
+
+    /// Returns a reference to the value of the key in the cache or None if it is not present in the cache.
+    fn get(&mut self, key: &Self::Key) -> Option<&Self::Value>;
+
+    /// Puts a key-value pair into the cache. If the key already exists in the cache, then it updates the key's value.
+    fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>);
+
+    /// Removes the value corresponding to the key from the cache.
+    fn remove(&mut self, key: &Self::Key);
+
+    /// Clears the cache, removing all key-value pairs.
+    fn clear(&mut self);
+}
+
+/// No cache.
+pub struct NoCache;
+
+impl CacheFactory for NoCache {
+    fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
+    where
+        K: Send + Sync + Clone + Eq + Hash + 'static,
+        V: Send + Sync + Clone + 'static,
+    {
+        Box::new(NoCacheImpl {
+            _mark1: PhantomData,
+            _mark2: PhantomData,
+        })
+    }
+}
+
+struct NoCacheImpl<K, V> {
+    _mark1: PhantomData<K>,
+    _mark2: PhantomData<V>,
+}
+
+impl<K, V> CacheStorage for NoCacheImpl<K, V>
+where
+    K: Send + Sync + Clone + Eq + Hash + 'static,
+    V: Send + Sync + Clone + 'static,
+{
+    type Key = K;
+    type Value = V;
+
+    #[inline]
+    fn get(&mut self, _key: &K) -> Option<&V> {
+        None
+    }
+
+    #[inline]
+    fn insert(&mut self, _key: Cow<'_, Self::Key>, _val: Cow<'_, Self::Value>) {}
+
+    #[inline]
+    fn remove(&mut self, _key: &K) {}
+
+    #[inline]
+    fn clear(&mut self) {}
+}
+
+/// [std::collections::HashMap] cache.
+pub struct HashMapCache;
+
+impl CacheFactory for HashMapCache {
+    fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
+    where
+        K: Send + Sync + Clone + Eq + Hash + 'static,
+        V: Send + Sync + Clone + 'static,
+    {
+        Box::new(HashMapCacheImpl(Default::default()))
+    }
+}
+
+struct HashMapCacheImpl<K, V>(HashMap<K, V>);
+
+impl<K, V> CacheStorage for HashMapCacheImpl<K, V>
+where
+    K: Send + Sync + Clone + Eq + Hash + 'static,
+    V: Send + Sync + Clone + 'static,
+{
+    type Key = K;
+    type Value = V;
+
+    #[inline]
+    fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> {
+        self.0.get(key)
+    }
+
+    #[inline]
+    fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) {
+        self.0.insert(key.into_owned(), val.into_owned());
+    }
+
+    #[inline]
+    fn remove(&mut self, key: &Self::Key) {
+        self.0.remove(key);
+    }
+
+    #[inline]
+    fn clear(&mut self) {
+        self.0.clear();
+    }
+}
+
+/// LRU cache.
+pub struct LruCache {
+    cap: usize,
+}
+
+impl LruCache {
+    /// Creates a new LRU Cache that holds at most `cap` items.
+    pub fn new(cap: usize) -> Self {
+        Self { cap }
+    }
+}
+
+impl CacheFactory for LruCache {
+    fn create<K, V>(&self) -> Box<dyn CacheStorage<Key = K, Value = V>>
+    where
+        K: Send + Sync + Clone + Eq + Hash + 'static,
+        V: Send + Sync + Clone + 'static,
+    {
+        Box::new(LruCacheImpl(lru::LruCache::new(self.cap)))
+    }
+}
+
+struct LruCacheImpl<K, V>(lru::LruCache<K, V>);
+
+impl<K, V> CacheStorage for LruCacheImpl<K, V>
+where
+    K: Send + Sync + Clone + Eq + Hash + 'static,
+    V: Send + Sync + Clone + 'static,
+{
+    type Key = K;
+    type Value = V;
+
+    #[inline]
+    fn get(&mut self, key: &Self::Key) -> Option<&Self::Value> {
+        self.0.get(key)
+    }
+
+    #[inline]
+    fn insert(&mut self, key: Cow<'_, Self::Key>, val: Cow<'_, Self::Value>) {
+        self.0.put(key.into_owned(), val.into_owned());
+    }
+
+    #[inline]
+    fn remove(&mut self, key: &Self::Key) {
+        self.0.pop(key);
+    }
+
+    #[inline]
+    fn clear(&mut self) {
+        self.0.clear();
+    }
+}
diff --git a/src/dataloader/mod.rs b/src/dataloader/mod.rs
index 4ef350a7..349d7bf6 100644
--- a/src/dataloader/mod.rs
+++ b/src/dataloader/mod.rs
@@ -57,53 +57,77 @@
 //!
 //! ```
 
+mod cache;
+
 use std::any::{Any, TypeId};
+use std::borrow::Cow;
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
+use std::sync::Mutex;
 use std::time::Duration;
 
 use futures_channel::oneshot;
 use futures_timer::Delay;
-use futures_util::lock::Mutex;
 
 use fnv::FnvHashMap;
 
-#[allow(clippy::type_complexity)]
-struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>>(
-    oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
-);
+pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
 
-struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
-    keys: HashSet<K>,
-    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
+#[allow(clippy::type_complexity)]
+struct ResSender<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
+    use_cache_values: HashMap<K, T::Value>,
+    tx: oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
 }
 
-impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Default for Requests<K, T> {
-    fn default() -> Self {
+struct Requests<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
+    keys: HashSet<K>,
+    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
+    cache_storage: Box<dyn CacheStorage<Key = K, Value = T::Value>>,
+}
+
+impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
+    fn new<C: CacheFactory>(cache_factory: &C) -> Self {
         Self {
             keys: Default::default(),
-            pending: Default::default(),
+            pending: Vec::new(),
+            cache_storage: cache_factory.create::<K, T::Value>(),
         }
     }
+
+    fn take(&mut self) -> (HashSet<K>, Vec<(HashSet<K>, ResSender<K, T>)>) {
+        (
+            std::mem::take(&mut self.keys),
+            std::mem::take(&mut self.pending),
+        )
+    }
 }
 
-impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
-    async fn load(self, loader: &T) {
-        let keys = self.keys.into_iter().collect::<Vec<_>>();
-        match loader.load(&keys).await {
-            Ok(values) => {
-                for (keys, tx) in self.pending {
-                    let mut res = HashMap::new();
-                    for key in &keys {
-                        res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
-                    }
-                    tx.0.send(Ok(res)).ok();
+async fn do_load<K, T, F>(
+    loader: &T,
+    keys: HashSet<K>,
+    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
+    f: F,
+) where
+    K: Send + Sync + Hash + Eq + Clone + 'static,
+    T: Loader<K>,
+    F: FnOnce(&HashMap<K, T::Value>),
+{
+    let keys = keys.into_iter().collect::<Vec<_>>();
+    match loader.load(&keys).await {
+        Ok(values) => {
+            f(&values);
+            for (keys, sender) in pending {
+                let mut res = HashMap::new();
+                res.extend(sender.use_cache_values);
+                for key in &keys {
+                    res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
                 }
+                sender.tx.send(Ok(res)).ok();
             }
-            Err(err) => {
-                for (_, tx) in self.pending {
-                    tx.0.send(Err(err.clone())).ok();
-                }
+        }
+        Err(err) => {
+            for (_, sender) in pending {
+                sender.tx.send(Err(err.clone())).ok();
             }
         }
     }
@@ -111,9 +135,9 @@ impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
 
 /// Trait for batch loading.
 #[async_trait::async_trait]
-pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
+pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
     /// type of value.
-    type Value: Send + Clone + 'static;
+    type Value: Send + Sync + Clone + 'static;
 
     /// Type of error.
     type Error: Send + Clone + 'static;
@@ -125,17 +149,32 @@ pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
 /// Data loader.
 ///
 /// Reference: <https://github.com/facebook/dataloader>
-pub struct DataLoader<T> {
+pub struct DataLoader<T, C = NoCache> {
     requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Send>>>,
+    cache_factory: C,
     delay: Duration,
     max_batch_size: usize,
     loader: T,
 }
 
-impl<T> DataLoader<T> {
-    /// Create a DataLoader with the `Loader` trait.
+impl<T> DataLoader<T, NoCache> {
+    /// Use `Loader` to create a [DataLoader] that does not cache records.
     pub fn new(loader: T) -> Self {
         Self {
+            cache_factory: NoCache,
+            requests: Default::default(),
+            delay: Duration::from_millis(1),
+            max_batch_size: 1000,
+            loader,
+        }
+    }
+}
+
+impl<T, C: CacheFactory> DataLoader<T, C> {
+    /// Use `Loader` to create a [DataLoader] with a cache factory.
+    pub fn with_cache(loader: T, cache_factory: C) -> Self {
+        Self {
+            cache_factory,
             requests: Default::default(),
             delay: Duration::from_millis(1),
             max_batch_size: 1000,
@@ -174,56 +213,177 @@ impl<T> DataLoader<T> {
         Ok(values.remove(&key))
     }
 
-    /// Use this `DataLoader` to load some data.
-    pub async fn load_many<K>(
-        &self,
-        keys: impl Iterator<Item = K>,
-    ) -> Result<HashMap<K, T::Value>, T::Error>
+    fn update_cache<K>(&self, values: &HashMap<K, T::Value>)
     where
         K: Send + Sync + Hash + Eq + Clone + 'static,
         T: Loader<K>,
     {
         let tid = TypeId::of::<K>();
+        let mut requests = self.requests.lock().unwrap();
+        let typed_requests = requests
+            .get_mut(&tid)
+            .unwrap()
+            .downcast_mut::<Requests<K, T>>()
+            .unwrap();
+        for (key, value) in values {
+            typed_requests
+                .cache_storage
+                .insert(Cow::Borrowed(key), Cow::Borrowed(value));
+        }
+    }
 
-        let (start_fetch, rx) = {
-            let mut requests = self.requests.lock().await;
-            let typed_requests = requests
-                .entry(tid)
-                .or_insert_with(|| Box::new(Requests::<K, T>::default()))
-                .downcast_mut::<Requests<K, T>>()
-                .unwrap();
-            let prev_count = typed_requests.keys.len();
-            let keys = keys.collect::<HashSet<_>>();
-            typed_requests.keys.extend(keys.clone());
-            let (tx, rx) = oneshot::channel();
-            typed_requests.pending.push((keys, ResSender(tx)));
-            if typed_requests.keys.len() >= self.max_batch_size {
-                let r = std::mem::take(&mut *typed_requests);
-                drop(requests);
-                r.load(&self.loader).await;
-                (false, rx)
-            } else {
-                (!typed_requests.keys.is_empty() && prev_count == 0, rx)
-            }
-        };
-
-        if start_fetch {
-            Delay::new(self.delay).await;
-            let mut requests = self.requests.lock().await;
+    async fn immediate_load<K>(&self)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader<K>,
+    {
+        let tid = TypeId::of::<K>();
+        let (keys, pending) = {
+            let mut requests = self.requests.lock().unwrap();
             let typed_requests = requests
                 .get_mut(&tid)
                 .unwrap()
                 .downcast_mut::<Requests<K, T>>()
                 .unwrap();
-            let typed_requests = std::mem::take(typed_requests);
-            drop(requests);
-            if !typed_requests.keys.is_empty() {
-                typed_requests.load(&self.loader).await;
+            typed_requests.take()
+        };
+        if !keys.is_empty() {
+            do_load(&self.loader, keys, pending, |values| {
+                // Update cache
+                self.update_cache(values);
+            })
+            .await;
+        }
+    }
+
+    /// Use this `DataLoader` to load some data.
+    pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        I: IntoIterator<Item = K>,
+        T: Loader<K>,
+    {
+        enum Action {
+            ImmediateLoad,
+            StartFetch,
+            Delay,
+        }
+
+        let tid = TypeId::of::<K>();
+
+        let (action, rx) = {
+            let mut requests = self.requests.lock().unwrap();
+            let typed_requests = requests
+                .entry(tid)
+                .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
+                .downcast_mut::<Requests<K, T>>()
+                .unwrap();
+            let prev_count = typed_requests.keys.len();
+            let mut keys_set = HashSet::new();
+            let mut use_cache_values = HashMap::new();
+
+            for key in keys {
+                if let Some(value) = typed_requests.cache_storage.get(&key) {
+                    // Already in cache
+                    use_cache_values.insert(key.clone(), value.clone());
+                } else {
+                    keys_set.insert(key);
+                }
             }
+
+            if !use_cache_values.is_empty() && keys_set.is_empty() {
+                return Ok(use_cache_values);
+            } else if use_cache_values.is_empty() && keys_set.is_empty() {
+                return Ok(Default::default());
+            }
+
+            typed_requests.keys.extend(keys_set.clone());
+            let (tx, rx) = oneshot::channel();
+            typed_requests.pending.push((
+                keys_set,
+                ResSender {
+                    use_cache_values,
+                    tx,
+                },
+            ));
+
+            if typed_requests.keys.len() >= self.max_batch_size {
+                (Action::ImmediateLoad, rx)
+            } else {
+                (
+                    if !typed_requests.keys.is_empty() && prev_count == 0 {
+                        Action::StartFetch
+                    } else {
+                        Action::Delay
+                    },
+                    rx,
+                )
+            }
+        };
+
+        match action {
+            Action::ImmediateLoad => self.immediate_load::<K>().await,
+            Action::StartFetch => {
+                Delay::new(self.delay).await;
+                self.immediate_load::<K>().await;
+            }
+            Action::Delay => {}
         }
 
         rx.await.unwrap()
     }
+
+    /// Feed some data into the cache.
+    ///
+    /// **NOTE: If the cache type is [NoCache], this function will not take effect. **
+    pub async fn feed_many<K, I>(&self, values: I)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        I: IntoIterator<Item = (K, T::Value)>,
+        T: Loader<K>,
+    {
+        let tid = TypeId::of::<K>();
+        let mut requests = self.requests.lock().unwrap();
+        let typed_requests = requests
+            .entry(tid)
+            .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
+            .downcast_mut::<Requests<K, T>>()
+            .unwrap();
+        for (key, value) in values {
+            typed_requests
+                .cache_storage
+                .insert(Cow::Owned(key), Cow::Owned(value));
+        }
+    }
+
+    /// Feed some data into the cache.
+    ///
+    /// **NOTE: If the cache type is [NoCache], this function will not take effect. **
+    pub async fn feed_one<K>(&self, key: K, value: T::Value)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader<K>,
+    {
+        self.feed_many(std::iter::once((key, value))).await;
+    }
+
+    /// Clears the cache.
+    ///
+    /// **NOTE: If the cache type is [NoCache], this function will not take effect. **
+    pub fn clear<K>(&self)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader<K>,
+    {
+        let tid = TypeId::of::<K>();
+        let mut requests = self.requests.lock().unwrap();
+        let typed_requests = requests
+            .entry(tid)
+            .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
+            .downcast_mut::<Requests<K, T>>()
+            .unwrap();
+        typed_requests.cache_storage.clear();
+    }
 }
 
 #[cfg(test)]
@@ -305,4 +465,41 @@ mod tests {
                 .collect::<Vec<_>>()
         );
     }
+
+    #[tokio::test]
+    async fn test_dataloader_load_empty() {
+        let loader = DataLoader::with_cache(MyLoader, HashMapCache);
+        assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
+    }
+
+    #[tokio::test]
+    async fn test_dataloader_with_cache() {
+        let loader = DataLoader::with_cache(MyLoader, HashMapCache);
+        loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
+
+        // All from the cache
+        assert_eq!(
+            loader.load_many(vec![1, 2, 3]).await.unwrap(),
+            vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
+        );
+
+        // Part from the cache
+        assert_eq!(
+            loader.load_many(vec![1, 5, 6]).await.unwrap(),
+            vec![(1, 10), (5, 5), (6, 6)].into_iter().collect()
+        );
+
+        // All from the loader
+        assert_eq!(
+            loader.load_many(vec![8, 9, 10]).await.unwrap(),
+            vec![(8, 8), (9, 9), (10, 10)].into_iter().collect()
+        );
+
+        // Clear cache
+        loader.clear::<i32>();
+        assert_eq!(
+            loader.load_many(vec![1, 2, 3]).await.unwrap(),
+            vec![(1, 1), (2, 2), (3, 3)].into_iter().collect()
+        );
+    }
 }
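For reference, a minimal sketch of how the cache-aware API introduced by this diff fits together end to end. It requires the `dataloader` feature; `UserNameLoader`, its key/value types, and the `tokio` runtime are illustrative assumptions, while `with_cache`, `LruCache`, `feed_one`, `load_one`, and `clear` come from this change and the existing module (the `Loader::load(&self, keys: &[K])` signature is the crate's existing one, unchanged here).

```rust
use std::collections::HashMap;
use std::convert::Infallible;

use async_graphql::dataloader::{DataLoader, Loader, LruCache};

// Hypothetical batch loader, used only to illustrate the cache API.
struct UserNameLoader;

#[async_trait::async_trait]
impl Loader<u64> for UserNameLoader {
    type Value = String;
    type Error = Infallible;

    async fn load(&self, keys: &[u64]) -> Result<HashMap<u64, Self::Value>, Self::Error> {
        // A real implementation would issue one batched query here.
        Ok(keys.iter().map(|id| (*id, format!("user-{}", id))).collect())
    }
}

#[tokio::main]
async fn main() {
    // Keep at most 256 records; HashMapCache or NoCache can be passed the same way.
    let loader = DataLoader::with_cache(UserNameLoader, LruCache::new(256));

    // Prime the cache: later load_one/load_many calls for this key skip the loader entirely.
    loader.feed_one(1u64, "alice".to_string()).await;
    assert_eq!(loader.load_one(1u64).await.unwrap(), Some("alice".to_string()));

    // Drop every cached record for this key type.
    loader.clear::<u64>();
}
```

Plain `DataLoader::new` keeps the previous behavior: it uses `NoCache`, so loads are only deduplicated within a batch window and nothing is retained between requests.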