Fix deadlock in dataloader. #555
This commit is contained in:
parent
214f7b297a
commit
3aecc43365
|
@ -43,7 +43,7 @@
|
|||
//! v5: value(n: 5)
|
||||
//! }
|
||||
//! "#;
|
||||
//! let request = Request::new(query).data(DataLoader::new(MyLoader));
|
||||
//! let request = Request::new(query).data(DataLoader::new(MyLoader, tokio::spawn));
|
||||
//! let res = schema.execute(request).await.into_result().unwrap().data;
|
||||
//!
|
||||
//! assert_eq!(res, value!({
|
||||
|
@ -63,12 +63,13 @@ use std::borrow::Cow;
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::Hash;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Mutex;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use fnv::FnvHashMap;
|
||||
use futures_channel::oneshot;
|
||||
use futures_timer::Delay;
|
||||
use futures_util::future::BoxFuture;
|
||||
|
||||
pub use cache::{CacheFactory, CacheStorage, HashMapCache, LruCache, NoCache};
|
||||
|
||||
|
@ -105,37 +106,6 @@ impl<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T>
|
|||
}
|
||||
}
|
||||
|
||||
async fn do_load<K, T, F>(
|
||||
loader: &T,
|
||||
keys: HashSet<K>,
|
||||
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
|
||||
f: F,
|
||||
) where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
F: FnOnce(&HashMap<K, T::Value>),
|
||||
{
|
||||
let keys = keys.into_iter().collect::<Vec<_>>();
|
||||
match loader.load(&keys).await {
|
||||
Ok(values) => {
|
||||
f(&values);
|
||||
for (keys, sender) in pending {
|
||||
let mut res = HashMap::new();
|
||||
res.extend(sender.use_cache_values);
|
||||
for key in &keys {
|
||||
res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
|
||||
}
|
||||
sender.tx.send(Ok(res)).ok();
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
for (_, sender) in pending {
|
||||
sender.tx.send(Err(err.clone())).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for batch loading.
|
||||
#[async_trait::async_trait]
|
||||
pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
|
||||
|
@ -149,42 +119,109 @@ pub trait Loader<K: Send + Sync + Hash + Eq + Clone + 'static>: Send + Sync + 's
|
|||
async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;
|
||||
}
|
||||
|
||||
struct DataLoaderInner<T> {
|
||||
requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
|
||||
loader: T,
|
||||
}
|
||||
|
||||
impl<T> DataLoaderInner<T> {
|
||||
async fn do_load<K>(&self, disable_cache: bool, (keys, senders): KeysAndSender<K, T>)
|
||||
where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let keys = keys.into_iter().collect::<Vec<_>>();
|
||||
|
||||
match self.loader.load(&keys).await {
|
||||
Ok(values) => {
|
||||
// update cache
|
||||
let mut request = self.requests.lock().unwrap();
|
||||
let typed_requests = request
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
let disable_cache = typed_requests.disable_cache || disable_cache;
|
||||
if !disable_cache {
|
||||
for (key, value) in &values {
|
||||
typed_requests
|
||||
.cache_storage
|
||||
.insert(Cow::Borrowed(key), Cow::Borrowed(value));
|
||||
}
|
||||
}
|
||||
|
||||
// send response
|
||||
for (keys, sender) in senders {
|
||||
let mut res = HashMap::new();
|
||||
res.extend(sender.use_cache_values);
|
||||
for key in &keys {
|
||||
res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
|
||||
}
|
||||
sender.tx.send(Ok(res)).ok();
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
for (_, sender) in senders {
|
||||
sender.tx.send(Err(err.clone())).ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Data loader.
|
||||
///
|
||||
/// Reference: <https://github.com/facebook/dataloader>
|
||||
pub struct DataLoader<T, C = NoCache> {
|
||||
requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
|
||||
inner: Arc<DataLoaderInner<T>>,
|
||||
cache_factory: C,
|
||||
delay: Duration,
|
||||
max_batch_size: usize,
|
||||
loader: T,
|
||||
disable_cache: AtomicBool,
|
||||
spawner: Box<dyn Fn(BoxFuture<'static, ()>) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl<T> DataLoader<T, NoCache> {
|
||||
/// Use `Loader` to create a [DataLoader] that does not cache records.
|
||||
pub fn new(loader: T) -> Self {
|
||||
pub fn new<S, R>(loader: T, spawner: S) -> Self
|
||||
where
|
||||
S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
inner: Arc::new(DataLoaderInner {
|
||||
requests: Mutex::new(Default::default()),
|
||||
loader,
|
||||
}),
|
||||
cache_factory: NoCache,
|
||||
requests: Default::default(),
|
||||
delay: Duration::from_millis(1),
|
||||
max_batch_size: 1000,
|
||||
loader,
|
||||
disable_cache: false.into(),
|
||||
spawner: Box::new(move |fut| {
|
||||
spawner(fut);
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, C: CacheFactory> DataLoader<T, C> {
|
||||
/// Use `Loader` to create a [DataLoader] with a cache factory.
|
||||
pub fn with_cache(loader: T, cache_factory: C) -> Self {
|
||||
pub fn with_cache<S, R>(loader: T, spawner: S, cache_factory: C) -> Self
|
||||
where
|
||||
S: Fn(BoxFuture<'static, ()>) -> R + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
inner: Arc::new(DataLoaderInner {
|
||||
requests: Mutex::new(Default::default()),
|
||||
loader,
|
||||
}),
|
||||
cache_factory,
|
||||
requests: Default::default(),
|
||||
delay: Duration::from_millis(1),
|
||||
max_batch_size: 1000,
|
||||
loader,
|
||||
disable_cache: false.into(),
|
||||
spawner: Box::new(move |fut| {
|
||||
spawner(fut);
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -206,7 +243,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
/// Get the loader.
|
||||
#[inline]
|
||||
pub fn loader(&self) -> &T {
|
||||
&self.loader
|
||||
&self.inner.loader
|
||||
}
|
||||
|
||||
/// Enable/Disable cache of all loaders.
|
||||
|
@ -221,7 +258,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = self.inner.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
|
@ -240,52 +277,6 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
Ok(values.remove(&key))
|
||||
}
|
||||
|
||||
fn update_cache<K>(&self, values: &HashMap<K, T::Value>)
|
||||
where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
for (key, value) in values {
|
||||
typed_requests
|
||||
.cache_storage
|
||||
.insert(Cow::Borrowed(key), Cow::Borrowed(value));
|
||||
}
|
||||
}
|
||||
|
||||
async fn immediate_load<K>(&self)
|
||||
where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let (disable_cache, (keys, pending)) = {
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
(typed_requests.disable_cache, typed_requests.take())
|
||||
};
|
||||
if !keys.is_empty() {
|
||||
let disable_cache = disable_cache || self.disable_cache.load(Ordering::SeqCst);
|
||||
do_load(&self.loader, keys, pending, |values| {
|
||||
// Update cache
|
||||
if !disable_cache {
|
||||
self.update_cache(values);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Use this `DataLoader` to load some data.
|
||||
pub async fn load_many<K, I>(&self, keys: I) -> Result<HashMap<K, T::Value>, T::Error>
|
||||
where
|
||||
|
@ -293,8 +284,8 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
I: IntoIterator<Item = K>,
|
||||
T: Loader<K>,
|
||||
{
|
||||
enum Action {
|
||||
ImmediateLoad,
|
||||
enum Action<K: Send + Sync + Hash + Eq + Clone + 'static, T: Loader<K>> {
|
||||
ImmediateLoad(KeysAndSender<K, T>),
|
||||
StartFetch,
|
||||
Delay,
|
||||
}
|
||||
|
@ -302,7 +293,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
let tid = TypeId::of::<K>();
|
||||
|
||||
let (action, rx) = {
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = self.inner.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.entry(tid)
|
||||
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
|
||||
|
@ -342,7 +333,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
));
|
||||
|
||||
if typed_requests.keys.len() >= self.max_batch_size {
|
||||
(Action::ImmediateLoad, rx)
|
||||
(Action::ImmediateLoad(typed_requests.take()), rx)
|
||||
} else {
|
||||
(
|
||||
if !typed_requests.keys.is_empty() && prev_count == 0 {
|
||||
|
@ -356,10 +347,35 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
};
|
||||
|
||||
match action {
|
||||
Action::ImmediateLoad => self.immediate_load::<K>().await,
|
||||
Action::ImmediateLoad(keys) => {
|
||||
let inner = self.inner.clone();
|
||||
let disable_cache = self.disable_cache.load(Ordering::SeqCst);
|
||||
(self.spawner)(Box::pin(
|
||||
async move { inner.do_load(disable_cache, keys).await },
|
||||
));
|
||||
}
|
||||
Action::StartFetch => {
|
||||
Delay::new(self.delay).await;
|
||||
self.immediate_load::<K>().await;
|
||||
let inner = self.inner.clone();
|
||||
let disable_cache = self.disable_cache.load(Ordering::SeqCst);
|
||||
let delay = self.delay;
|
||||
|
||||
(self.spawner)(Box::pin(async move {
|
||||
Delay::new(delay).await;
|
||||
|
||||
let keys = {
|
||||
let mut request = inner.requests.lock().unwrap();
|
||||
let typed_requests = request
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
typed_requests.take()
|
||||
};
|
||||
|
||||
if !keys.0.is_empty() {
|
||||
inner.do_load(disable_cache, keys).await
|
||||
}
|
||||
}))
|
||||
}
|
||||
Action::Delay => {}
|
||||
}
|
||||
|
@ -377,7 +393,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = self.inner.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.entry(tid)
|
||||
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
|
||||
|
@ -410,7 +426,7 @@ impl<T, C: CacheFactory> DataLoader<T, C> {
|
|||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
let mut requests = self.requests.lock().unwrap();
|
||||
let mut requests = self.inner.requests.lock().unwrap();
|
||||
let typed_requests = requests
|
||||
.entry(tid)
|
||||
.or_insert_with(|| Box::new(Requests::<K, T>::new(&self.cache_factory)))
|
||||
|
@ -452,7 +468,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader() {
|
||||
let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
|
||||
let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
|
||||
assert_eq!(
|
||||
futures_util::future::try_join_all((0..100i32).map({
|
||||
let loader = loader.clone();
|
||||
|
@ -482,7 +498,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_duplicate_keys() {
|
||||
let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
|
||||
let loader = Arc::new(DataLoader::new(MyLoader, tokio::spawn).max_batch_size(10));
|
||||
assert_eq!(
|
||||
futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({
|
||||
let loader = loader.clone();
|
||||
|
@ -503,13 +519,13 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_load_empty() {
|
||||
let loader = DataLoader::new(MyLoader);
|
||||
let loader = DataLoader::new(MyLoader, tokio::spawn);
|
||||
assert!(loader.load_many::<i32, _>(vec![]).await.unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_with_cache() {
|
||||
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
|
||||
let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
|
||||
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
|
||||
|
||||
// All from the cache
|
||||
|
@ -540,7 +556,11 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_with_cache_hashmap_fnv() {
|
||||
let loader = DataLoader::with_cache(MyLoader, HashMapCache::<FnvBuildHasher>::new());
|
||||
let loader = DataLoader::with_cache(
|
||||
MyLoader,
|
||||
tokio::spawn,
|
||||
HashMapCache::<FnvBuildHasher>::new(),
|
||||
);
|
||||
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
|
||||
|
||||
// All from the cache
|
||||
|
@ -571,7 +591,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_disable_all_cache() {
|
||||
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
|
||||
let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
|
||||
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
|
||||
|
||||
// All from the loader
|
||||
|
@ -591,7 +611,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_disable_cache() {
|
||||
let loader = DataLoader::with_cache(MyLoader, HashMapCache::default());
|
||||
let loader = DataLoader::with_cache(MyLoader, tokio::spawn, HashMapCache::default());
|
||||
loader.feed_many(vec![(1, 10), (2, 20), (3, 30)]).await;
|
||||
|
||||
// All from the loader
|
||||
|
@ -608,4 +628,35 @@ mod tests {
|
|||
vec![(1, 10), (2, 20), (3, 30)].into_iter().collect()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_dataloader_dead_lock() {
|
||||
struct MyDelayLoader;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<i32> for MyDelayLoader {
|
||||
type Value = i32;
|
||||
type Error = ();
|
||||
|
||||
async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
|
||||
tokio::time::sleep(Duration::from_secs(1)).await;
|
||||
Ok(keys.iter().copied().map(|k| (k, k)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
let loader = Arc::new(
|
||||
DataLoader::with_cache(MyDelayLoader, tokio::spawn, NoCache)
|
||||
.delay(Duration::from_secs(1)),
|
||||
);
|
||||
let handle = tokio::spawn({
|
||||
let loader = loader.clone();
|
||||
async move {
|
||||
loader.load_many(vec![1, 2, 3]).await.unwrap();
|
||||
}
|
||||
});
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
handle.abort();
|
||||
loader.load_many(vec![4, 5, 6]).await.unwrap();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -199,7 +199,7 @@ pub async fn test_find_entity_with_context() {
|
|||
}
|
||||
|
||||
let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
|
||||
.data(DataLoader::new(MyLoader))
|
||||
.data(DataLoader::new(MyLoader, tokio::spawn))
|
||||
.finish();
|
||||
let query = r#"{
|
||||
_entities(representations: [
|
||||
|
|
Loading…
Reference in New Issue
Block a user