diff --git a/src/dataloader/mod.rs b/src/dataloader/mod.rs
index 414808fb..b9be879f 100644
--- a/src/dataloader/mod.rs
+++ b/src/dataloader/mod.rs
@@ -43,7 +43,7 @@
 //!     v5: value(n: 5)
 //! }
 //! "#;
-//! let request = Request::new(query).data(DataLoader::new(MyLoader));
+//! let request = Request::new(query).data(DataLoader::new(MyLoader, tokio::spawn));
 //! let res = schema.execute(request).await.into_result().unwrap().data;
 //!
 //! assert_eq!(res, value!({
@@ -63,12 +63,13 @@
 use std::borrow::Cow;
 use std::collections::{HashMap, HashSet};
 use std::hash::Hash;
 use std::sync::atomic::{AtomicBool, Ordering};
-use std::sync::Mutex;
+use std::sync::{Arc, Mutex};
 use std::time::Duration;
 
 use fnv::FnvHashMap;
 use futures_channel::oneshot;
 use futures_timer::Delay;
+use futures_util::future::BoxFuture;
 
 pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
@@ -105,37 +106,6 @@ impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T>
     }
 }
 
-async fn do_load<K, T, F>(
-    loader: &T,
-    keys: HashSet<K>,
-    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
-    f: F,
-) where
-    K: Send + Sync + Hash + Eq + Clone + 'static,
-    T: Loader<K>,
-    F: FnOnce(&HashMap<K, T::Value>),
-{
-    let keys = keys.into_iter().collect::<Vec<_>>();
-    match loader.load(&keys).await {
-        Ok(values) => {
-            f(&values);
-            for (keys, sender) in pending {
-                let mut res = HashMap::new();
-                res.extend(sender.use_cache_values);
-                for key in &keys {
-                    res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
-                }
-                sender.tx.send(Ok(res)).ok();
-            }
-        }
-        Err(err) => {
-            for (_, sender) in pending {
-                sender.tx.send(Err(err.clone())).ok();
-            }
-        }
-    }
-}
-
 /// Trait for batch loading.
 #[async_trait::async_trait]
 pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
@@ -149,42 +119,109 @@ pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 's
     async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;
 }
 
+struct DataLoaderInner<T> {
+    requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
+    loader: T,
+}
+
+impl<T> DataLoaderInner<T> {
+    async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
+    where
+        K: Send + Sync + Hash + Eq + Clone + 'static,
+        T: Loader<K>,
+    {
+        let tid = TypeId::of::<K>();
+        let keys = keys.into_iter().collect::<Vec<_>>();
+
+        match self.loader.load(&keys).await {
+            Ok(values) => {
+                // update cache
+                let mut request = self.requests.lock().unwrap();
+                let typed_requests = request
+                    .get_mut(&tid)
+                    .unwrap()
+                    .downcast_mut::<Requests<K, T>>()
+                    .unwrap();
+                let disable_cache = typed_requests.disable_cache || disable_cache;
+                if !disable_cache {
+                    for (key, value) in &values {
+                        typed_requests
+                            .cache_storage
+                            .insert(Cow::Borrowed(key), Cow::Borrowed(value));
+                    }
+                }
+
+                // send response
+                for (keys, sender) in senders {
+                    let mut res = HashMap::new();
+                    res.extend(sender.use_cache_values);
+                    for key in &keys {
+                        res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
+                    }
+                    sender.tx.send(Ok(res)).ok();
+                }
+            }
+            Err(err) => {
+                for (_, sender) in senders {
+                    sender.tx.send(Err(err.clone())).ok();
+                }
+            }
+        }
+    }
+}
+
 /// Data loader.
 ///
 /// Reference: <https://github.com/facebook/dataloader>
 pub struct DataLoader<T, C = NoCache> {
-    requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
+    inner: Arc<DataLoaderInner<T>>,
     cache_factory: C,
     delay: Duration,
     max_batch_size: usize,
-    loader: T,
     disable_cache: AtomicBool,
+    spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>,
 }
 
 impl<T> DataLoader<T, NoCache> {
     /// Use `Loader` to create a [DataLoader] that does not cache records.
-    pub fn new(loader: T) -> Self {
+    pub fn new<S, R>(loader: T, spawner: S) -> Self
+    where
+        S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
+    {
         Self {
+            inner: Arc::new(DataLoaderInner {
+                requests: Mutex::new(Default::default()),
+                loader,
+            }),
             cache_factory: NoCache,
-            requests: Default::default(),
             delay: Duration::from_millis(1),
             max_batch_size: 1000,
-            loader,
             disable_cache: false.into(),
+            spawner: Box::new(move |fut| {
+                spawner(fut);
+            }),
         }
     }
 }
 
 impl<T, C: CacheFactory> DataLoader<T, C> {
     /// Use `Loader` to create a [DataLoader] with a cache factory.
-    pub fn with_cache(loader: T, cache_factory: C) -> Self {
+    pub fn with_cache<S, R>(loader: T, spawner: S, cache_factory: C) -> Self
+    where
+        S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
+    {
         Self {
+            inner: Arc::new(DataLoaderInner {
+                requests: Mutex::new(Default::default()),
+                loader,
+            }),
             cache_factory,
-            requests: Default::default(),
             delay: Duration::from_millis(1),
             max_batch_size: 1000,
-            loader,
             disable_cache: false.into(),
+            spawner: Box::new(move |fut| {
+                spawner(fut);
+            }),
         }
     }
 
@@ -206,7 +243,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
     /// Get the loader.
     #[inline]
     pub fn loader(&self) -> &T {
-        &self.loader
+        &self.inner.loader
     }
 
     /// Enable/Disable cache of all loaders.
@@ -221,7 +258,7 @@
         T: Loader<K>,
     {
         let tid = TypeId::of::<K>();
-        let mut requests = self.requests.lock().unwrap();
+        let mut requests = self.inner.requests.lock().unwrap();
         let typed_requests = requests
             .get_mut(&tid)
             .unwrap()
@@ -240,52 +277,6 @@
         Ok(values.remove(&key))
     }
 
-    fn update_cache<K>(&self, values: &HashMap<K, T::Value>)
-    where
-        K: Send + Sync + Hash + Eq + Clone + 'static,
-        T: Loader<K>,
-    {
-        let tid = TypeId::of::<K>();
-        let mut requests = self.requests.lock().unwrap();
-        let typed_requests = requests
-            .get_mut(&tid)
-            .unwrap()
-            .downcast_mut::<Requests<K, T>>()
-            .unwrap();
-        for (key, value) in values {
-            typed_requests
-                .cache_storage
-                .insert(Cow::Borrowed(key), Cow::Borrowed(value));
-        }
-    }
-
-    async fn immediate_load<K>(&self)
-    where
-        K: Send + Sync + Hash + Eq + Clone + 'static,
-        T: Loader<K>,
-    {
-        let tid = TypeId::of::<K>();
-        let (disable_cache, (keys, pending)) = {
-            let mut requests = self.requests.lock().unwrap();
-            let typed_requests = requests
-                .get_mut(&tid)
-                .unwrap()
-                .downcast_mut::<Requests<K, T>>()
-                .unwrap();
-            (typed_requests.disable_cache, typed_requests.take())
-        };
-        if !keys.is_empty() {
-            let disable_cache = disable_cache || self.disable_cache.load(Ordering::SeqCst);
-            do_load(&self.loader, keys, pending, |values| {
-                // Update cache
-                if !disable_cache {
-                    self.update_cache(values);
-                }
-            })
-            .await;
-        }
-    }
-
     /// Use this `DataLoader` to load some data.
     pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
     where
@@ -293,8 +284,8 @@
         I: IntoIterator<Item = K>,
         T: Loader<K>,
     {
-        enum Action {
-            ImmediateLoad,
+        enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
+            ImmediateLoad(KeysAndSender<K, T>),
             StartFetch,
             Delay,
         }
@@ -302,7 +293,7 @@
         let tid = TypeId::of::<K>();
 
         let (action, rx) = {
-            let mut requests = self.requests.lock().unwrap();
+            let mut requests = self.inner.requests.lock().unwrap();
             let typed_requests = requests
                 .entry(tid)
                 .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
@@ -342,7 +333,7 @@
             ));
 
             if typed_requests.keys.len() >= self.max_batch_size {
-                (Action::ImmediateLoad, rx)
+                (Action::ImmediateLoad(typed_requests.take()), rx)
             } else {
                 (
                     if !typed_requests.keys.is_empty() && prev_count == 0 {
@@ -356,10 +347,35 @@
         };
 
         match action {
-            Action::ImmediateLoad => self.immediate_load::<K>().await,
+            Action::ImmediateLoad(keys) => {
+                let inner = self.inner.clone();
+                let disable_cache = self.disable_cache.load(Ordering::SeqCst);
+                (self.spawner)(Box::pin(
+                    async move { inner.do_load(disable_cache, keys).await },
+                ));
+            }
             Action::StartFetch => {
-                Delay::new(self.delay).await;
-                self.immediate_load::<K>().await;
+                let inner = self.inner.clone();
+                let disable_cache = self.disable_cache.load(Ordering::SeqCst);
+                let delay = self.delay;
+
+                (self.spawner)(Box::pin(async move {
+                    Delay::new(delay).await;
+
+                    let keys = {
+                        let mut request = inner.requests.lock().unwrap();
+                        let typed_requests = request
+                            .get_mut(&tid)
+                            .unwrap()
+                            .downcast_mut::<Requests<K, T>>()
+                            .unwrap();
+                        typed_requests.take()
+                    };
+
+                    if !keys.0.is_empty() {
+                        inner.do_load(disable_cache, keys).await
+                    }
+                }))
             }
             Action::Delay => {}
         }
@@ -377,7 +393,7 @@
         T: Loader<K>,
     {
         let tid = TypeId::of::<K>();
-        let mut requests = self.requests.lock().unwrap();
+        let mut requests = self.inner.requests.lock().unwrap();
         let typed_requests = requests
             .entry(tid)
             .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
@@ -410,7 +426,7 @@
         T: Loader<K>,
     {
         let tid = TypeId::of::<K>();
-        let mut requests = self.requests.lock().unwrap();
+        let mut requests = self.inner.requests.lock().unwrap();
         let typed_requests = requests
             .entry(tid)
             .or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
@@ -452,7 +468,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_dataloader() {
-        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
+        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
         assert_eq!(
             futures_util::future::try_join_all((0..100i32).map({
                 let loader = loader.clone();
@@ -482,7 +498,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_duplicate_keys() {
-        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
+        let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
         assert_eq!(
             futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({
                 let loader = loader.clone();
@@ -503,13 +519,13 @@ mod tests {
 
     #[tokio::test]
     async fn test_dataloader_load_empty() {
-        let loader = DataLoader::new(MyLoader);
+        let loader = DataLoader::new(MyLoader, tokio::spawn);
         assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
     }
 
     #[tokio::test]
     async fn test_dataloader_with_cache() {
-        let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
+        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
         loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
 
         // All from the cache
@@ -540,7 +556,11 @@ mod tests {
 
     #[tokio::test]
     async fn test_dataloader_with_cache_hashmap_fnv() {
-        let loader = DataLoader::with_cache(MyLoader, HashMapCache::<FnvBuildHasher>::new());
+        let loader = DataLoader::with_cache(
+            MyLoader,
+            tokio::spawn,
+            HashMapCache::<FnvBuildHasher>::new(),
+        );
         loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
 
         // All from the cache
@@ -571,7 +591,7 @@ mod tests {
 
     #[tokio::test]
     async fn test_dataloader_disable_all_cache() {
-        let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
+        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
         loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
 
         // All from the loader
@@ -591,7 +611,7 @@ mod tests {
 
    #[tokio::test]
    async fn test_dataloader_disable_cache() {
-        let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
+        let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
         loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
 
         // All from the loader
@@ -608,4 +628,35 @@ mod tests {
             vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
         );
     }
+
+    #[tokio::test]
+    async fn test_dataloader_dead_lock() {
+        struct MyDelayLoader;
+
+        #[async_trait::async_trait]
+        impl Loader<i32> for MyDelayLoader {
+            type Value = i32;
+            type Error = ();
+
+            async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
+                tokio::time::sleep(Duration::from_secs(1)).await;
+                Ok(keys.iter().copied().map(|k| (k, k)).collect())
+            }
+        }
+
+        let loader = Arc::new(
+            DataLoader::with_cache(MyDelayLoader, tokio::spawn, NoCache)
+                .delay(Duration::from_secs(1)),
+        );
+        let handle = tokio::spawn({
+            let loader = loader.clone();
+            async move {
+                loader.load_many(vec![1, 2, 3]).await.unwrap();
+            }
+        });
+
+        tokio::time::sleep(Duration::from_millis(500)).await;
+        handle.abort();
+        loader.load_many(vec![4, 5, 6]).await.unwrap();
+    }
 }
diff --git a/tests/federation.rs b/tests/federation.rs
index 82de3365..ef2dc53d 100644
--- a/tests/federation.rs
+++ b/tests/federation.rs
@@ -199,7 +199,7 @@ pub async fn test_find_entity_with_context() {
     }
 
     let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
-        .data(DataLoader::new(MyLoader))
+        .data(DataLoader::new(MyLoader, tokio::spawn))
         .finish();
 
     let query = r#"{
         _entities(representations: [