Improve dataloader
This commit is contained in:
parent
1f93298164
commit
3e6339e630
|
@ -75,12 +75,11 @@ struct UserNameLoader {
|
|||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader for UserNameLoader {
|
||||
type Key = u64;
|
||||
impl Loader<u64> for UserNameLoader {
|
||||
type Value = String;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: HashSet<Self::Key>) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||
|
||||
async fn load(&self, keys: &[u64]) -> Result<HashMap<u64, Self::Value>, Self::Error> {
|
||||
let pool = ctx.data_unchecked::<Pool<Postgres>>();
|
||||
let query = format!("SELECT name FROM user WHERE id IN ({})", keys.iter().join(","));
|
||||
Ok(sqlx::query_as(query)
|
||||
|
@ -110,3 +109,33 @@ In the end, only two SQLs are needed to query the results we want!
|
|||
SELECT id, todo, user_id FROM todo
|
||||
SELECT name FROM user WHERE id IN (1, 2, 3, 4)
|
||||
```
|
||||
|
||||
## Implement multiple data types
|
||||
|
||||
You can implement multiple data types for the same `Loader`, like this:
|
||||
|
||||
```rust
|
||||
struct PostgresLoader {
|
||||
pool: sqlx::Pool<Postgres>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<UserId> for PostgresLoader {
|
||||
type Value = User;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: &[UserId]) -> Result<HashMap<UserId, Self::Value>, Self::Error> {
|
||||
// Load users from database
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<TodoId> for PostgresLoader {
|
||||
type Value = Todo;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: &[TodoId]) -> Result<HashMap<TodoId, Self::Value>, Self::Error> {
|
||||
// Load todos from database
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -73,12 +73,11 @@ struct UserNameLoader {
|
|||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader for UserNameLoader {
|
||||
type Key = u64;
|
||||
impl Loader<u64> for UserNameLoader {
|
||||
type Value = String;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: HashSet<Self::Key>) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||
async fn load(&self, keys: &[u64]) -> Result<HashMap<u64, Self::Value>, Self::Error> {
|
||||
let pool = ctx.data_unchecked::<Pool<Postgres>>();
|
||||
let query = format!("SELECT name FROM user WHERE id IN ({})", keys.iter().join(","));
|
||||
Ok(sqlx::query_as(query)
|
||||
|
@ -108,3 +107,33 @@ impl User {
|
|||
SELECT id, todo, user_id FROM todo
|
||||
SELECT name FROM user WHERE id IN (1, 2, 3, 4)
|
||||
```
|
||||
|
||||
## 同一个Loader支持多种数据类型
|
||||
|
||||
你可以为同一个`Loader`实现多种数据类型,就像下面这样:
|
||||
|
||||
```rust
|
||||
struct PostgresLoader {
|
||||
pool: sqlx::Pool<Postgres>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<UserId> for PostgresLoader {
|
||||
type Value = User;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: &[UserId]) -> Result<HashMap<UserId, Self::Value>, Self::Error> {
|
||||
// 从数据库中加载User
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<TodoId> for PostgresLoader {
|
||||
type Value = Todo;
|
||||
type Error = sqlx::Error;
|
||||
|
||||
async fn load(&self, keys: &[TodoId]) -> Result<HashMap<TodoId, Self::Value>, Self::Error> {
|
||||
// 从数据库中加载Todo
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::Hash;
|
||||
use std::sync::Mutex;
|
||||
|
||||
use super::Loader;
|
||||
|
||||
/// Cache storage for a loader.
|
||||
pub trait CacheStorage: Send + Sync + 'static {
|
||||
/// Type of key.
|
||||
type Key: Send + Hash + Eq + Clone + 'static;
|
||||
|
||||
/// Type of value.
|
||||
type Value: Send + Clone + 'static;
|
||||
|
||||
/// Load `value` from cache by `key`.
|
||||
fn get(&self, key: &Self::Key) -> Option<Self::Value>;
|
||||
|
||||
/// Put `value` to cache by `key`.
|
||||
fn set(&self, key: Self::Key, value: Self::Value);
|
||||
}
|
||||
|
||||
/// Loader for the [cached](trait.LoaderExt.html#method.cached) method.
|
||||
pub struct CachedLoader<C, T> {
|
||||
loader: T,
|
||||
cache: C,
|
||||
}
|
||||
|
||||
impl<C, T> CachedLoader<C, T>
|
||||
where
|
||||
C: CacheStorage<Key = T::Key, Value = T::Value>,
|
||||
T: Loader,
|
||||
{
|
||||
/// Create a loader that can cache data.
|
||||
pub fn new(loader: T, cache: C) -> Self {
|
||||
Self { cache, loader }
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<C, T> Loader for CachedLoader<C, T>
|
||||
where
|
||||
C: CacheStorage<Key = T::Key, Value = T::Value>,
|
||||
T: Loader,
|
||||
{
|
||||
type Key = T::Key;
|
||||
type Value = T::Value;
|
||||
type Error = T::Error;
|
||||
|
||||
async fn load(
|
||||
&self,
|
||||
mut keys: HashSet<Self::Key>,
|
||||
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||
let mut res = HashMap::new();
|
||||
for key in &keys {
|
||||
if let Some(value) = self.cache.get(key) {
|
||||
res.insert(key.clone(), value);
|
||||
}
|
||||
}
|
||||
for key in res.keys() {
|
||||
keys.remove(key);
|
||||
}
|
||||
let values = self.loader.load(keys).await?;
|
||||
for (key, value) in &values {
|
||||
self.cache.set(key.clone(), value.clone());
|
||||
}
|
||||
res.extend(values);
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory-based LRU cache.
|
||||
pub struct LruCache<T: CacheStorage>(Mutex<lru::LruCache<T::Key, T::Value>>);
|
||||
|
||||
impl<T: CacheStorage> LruCache<T> {
|
||||
/// Creates a new LRU Cache that holds at most `cap` items.
|
||||
pub fn new(cap: usize) -> Self {
|
||||
Self(Mutex::new(lru::LruCache::new(cap)))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: CacheStorage> CacheStorage for LruCache<T> {
|
||||
type Key = T::Key;
|
||||
type Value = T::Value;
|
||||
|
||||
fn get(&self, key: &Self::Key) -> Option<Self::Value> {
|
||||
self.0.lock().unwrap().get(key).cloned()
|
||||
}
|
||||
|
||||
fn set(&self, key: Self::Key, value: Self::Value) {
|
||||
self.0.lock().unwrap().put(key, value);
|
||||
}
|
||||
}
|
|
@ -13,12 +13,11 @@
|
|||
//! struct MyLoader;
|
||||
//!
|
||||
//! #[async_trait::async_trait]
|
||||
//! impl Loader for MyLoader {
|
||||
//! type Key = i32;
|
||||
//! impl Loader<i32> for MyLoader {
|
||||
//! type Value = String;
|
||||
//! type Error = Infallible;
|
||||
//!
|
||||
//! async fn load(&self, keys: HashSet<Self::Key>) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||
//! async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
|
||||
//! // Use `MyLoader` to load data.
|
||||
//! Ok(keys.iter().copied().map(|n| (n, n.to_string())).collect())
|
||||
//! }
|
||||
|
@ -34,11 +33,8 @@
|
|||
//! }
|
||||
//!
|
||||
//! async_std::task::block_on(async move {
|
||||
//! let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
|
||||
//! .data(DataLoader::new(MyLoader))
|
||||
//! .finish();
|
||||
//!
|
||||
//! let res = schema.execute(r#"
|
||||
//! let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
|
||||
//! let query = r#"
|
||||
//! {
|
||||
//! v1: value(n: 1)
|
||||
//! v2: value(n: 2)
|
||||
|
@ -46,7 +42,9 @@
|
|||
//! v4: value(n: 4)
|
||||
//! v5: value(n: 5)
|
||||
//! }
|
||||
//! "#).await.into_result().unwrap().data;
|
||||
//! "#;
|
||||
//! let request = Request::new(query).data(DataLoader::new(MyLoader));
|
||||
//! let res = schema.execute(request).await.into_result().unwrap().data;
|
||||
//!
|
||||
//! assert_eq!(res, value!({
|
||||
//! "v1": "1",
|
||||
|
@ -59,8 +57,7 @@
|
|||
//!
|
||||
//! ```
|
||||
|
||||
mod cache;
|
||||
|
||||
use std::any::{Any, TypeId};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::hash::Hash;
|
||||
use std::time::Duration;
|
||||
|
@ -69,18 +66,17 @@ use futures_channel::oneshot;
|
|||
use futures_timer::Delay;
|
||||
use futures_util::lock::Mutex;
|
||||
|
||||
pub use cache::{CacheStorage, CachedLoader, LruCache};
|
||||
use fnv::FnvHashMap;
|
||||
|
||||
type ResSender<T> = oneshot::Sender<
|
||||
Result<HashMap<<T as Loader>::Key, <T as Loader>::Value>, <T as Loader>::Error>,
|
||||
>;
|
||||
type ResSender<K, T> =
|
||||
oneshot::Sender<Result<HashMap<K, <T as Loader<K>>::Value>, <T as Loader<K>>::Error>>;
|
||||
|
||||
struct Requests<T: Loader> {
|
||||
keys: HashSet<T::Key>,
|
||||
pending: Vec<(HashSet<T::Key>, ResSender<T>)>,
|
||||
struct Requests<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> {
|
||||
keys: HashSet<K>,
|
||||
pending: Vec<(HashSet<K>, ResSender<K, T>)>,
|
||||
}
|
||||
|
||||
impl<T: Loader> Default for Requests<T> {
|
||||
impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Default for Requests<K, T> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
keys: Default::default(),
|
||||
|
@ -89,9 +85,10 @@ impl<T: Loader> Default for Requests<T> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<T: Loader> Requests<T> {
|
||||
impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
|
||||
async fn load(self, loader: &T) {
|
||||
match loader.load(self.keys).await {
|
||||
let keys = self.keys.into_iter().collect::<Vec<_>>();
|
||||
match loader.load(&keys).await {
|
||||
Ok(values) => {
|
||||
for (keys, tx) in self.pending {
|
||||
let mut res = HashMap::new();
|
||||
|
@ -112,10 +109,7 @@ impl<T: Loader> Requests<T> {
|
|||
|
||||
/// Trait for batch loading.
|
||||
#[async_trait::async_trait]
|
||||
pub trait Loader: Send + Sync + 'static {
|
||||
/// Type of key.
|
||||
type Key: Send + Hash + Eq + Clone + 'static;
|
||||
|
||||
pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
|
||||
/// type of value.
|
||||
type Value: Send + Clone + 'static;
|
||||
|
||||
|
@ -123,23 +117,20 @@ pub trait Loader: Send + Sync + 'static {
|
|||
type Error: Send + Clone + 'static;
|
||||
|
||||
/// Load the data set specified by the `keys`.
|
||||
async fn load(
|
||||
&self,
|
||||
keys: HashSet<Self::Key>,
|
||||
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error>;
|
||||
async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;
|
||||
}
|
||||
|
||||
/// Data loader.
|
||||
///
|
||||
/// Reference: https://github.com/facebook/dataloader
|
||||
pub struct DataLoader<T: Loader> {
|
||||
requests: Mutex<Requests<T>>,
|
||||
pub struct DataLoader<T> {
|
||||
requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Sync + Send>>>,
|
||||
delay: Duration,
|
||||
max_batch_size: usize,
|
||||
loader: T,
|
||||
}
|
||||
|
||||
impl<T: Loader> DataLoader<T> {
|
||||
impl<T> DataLoader<T> {
|
||||
/// Create a DataLoader with the `Loader` trait.
|
||||
pub fn new(loader: T) -> Self {
|
||||
Self {
|
||||
|
@ -166,41 +157,63 @@ impl<T: Loader> DataLoader<T> {
|
|||
}
|
||||
|
||||
/// Use this `DataLoader` load a data.
|
||||
pub async fn load_one(&self, key: T::Key) -> Result<Option<T::Value>, T::Error> {
|
||||
pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
|
||||
where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
{
|
||||
let mut values = self.load_many(std::iter::once(key.clone())).await?;
|
||||
Ok(values.remove(&key))
|
||||
}
|
||||
|
||||
/// Use this `DataLoader` to load some data.
|
||||
pub async fn load_many(
|
||||
pub async fn load_many<K>(
|
||||
&self,
|
||||
keys: impl Iterator<Item = T::Key>,
|
||||
) -> Result<HashMap<T::Key, T::Value>, T::Error> {
|
||||
keys: impl Iterator<Item = K>,
|
||||
) -> Result<HashMap<K, T::Value>, T::Error>
|
||||
where
|
||||
K: Send + Sync + Hash + Eq + Clone + 'static,
|
||||
T: Loader<K>,
|
||||
{
|
||||
let tid = TypeId::of::<K>();
|
||||
|
||||
let (start_fetch, rx) = {
|
||||
let mut requests = self.requests.lock().await;
|
||||
let prev_count = requests.keys.len();
|
||||
let typed_requests = requests
|
||||
.entry(tid)
|
||||
.or_insert_with(|| Box::new(Requests::<K, T>::default()))
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
let prev_count = typed_requests.keys.len();
|
||||
let keys = keys.collect::<HashSet<_>>();
|
||||
requests.keys.extend(keys.clone());
|
||||
if requests.keys.len() == prev_count {
|
||||
typed_requests.keys.extend(keys.clone());
|
||||
if typed_requests.keys.len() == prev_count {
|
||||
return Ok(Default::default());
|
||||
}
|
||||
let (tx, rx) = oneshot::channel();
|
||||
requests.pending.push((keys, tx));
|
||||
if requests.keys.len() >= self.max_batch_size {
|
||||
let r = std::mem::take(&mut *requests);
|
||||
typed_requests.pending.push((keys, tx));
|
||||
if typed_requests.keys.len() >= self.max_batch_size {
|
||||
let r = std::mem::take(&mut *typed_requests);
|
||||
drop(requests);
|
||||
r.load(&self.loader).await;
|
||||
(false, rx)
|
||||
} else {
|
||||
(!requests.keys.is_empty() && prev_count == 0, rx)
|
||||
(!typed_requests.keys.is_empty() && prev_count == 0, rx)
|
||||
}
|
||||
};
|
||||
|
||||
if start_fetch {
|
||||
Delay::new(self.delay).await;
|
||||
let requests = std::mem::take(&mut *self.requests.lock().await);
|
||||
if !requests.keys.is_empty() {
|
||||
requests.load(&self.loader).await;
|
||||
let mut requests = self.requests.lock().await;
|
||||
let typed_requests = requests
|
||||
.get_mut(&tid)
|
||||
.unwrap()
|
||||
.downcast_mut::<Requests<K, T>>()
|
||||
.unwrap();
|
||||
let typed_requests = std::mem::take(typed_requests);
|
||||
drop(requests);
|
||||
if !typed_requests.keys.is_empty() {
|
||||
typed_requests.load(&self.loader).await;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,7 +224,6 @@ impl<T: Loader> DataLoader<T> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_std::task;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[async_std::test]
|
||||
|
@ -219,30 +231,51 @@ mod tests {
|
|||
struct MyLoader;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader for MyLoader {
|
||||
type Key = i32;
|
||||
impl Loader<i32> for MyLoader {
|
||||
type Value = i32;
|
||||
type Error = ();
|
||||
|
||||
async fn load(
|
||||
&self,
|
||||
keys: HashSet<Self::Key>,
|
||||
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||
async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
|
||||
assert!(keys.len() <= 10);
|
||||
Ok(keys.into_iter().map(|k| (k, k)).collect())
|
||||
Ok(keys.iter().copied().map(|k| (k, k)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Loader<i64> for MyLoader {
|
||||
type Value = i64;
|
||||
type Error = ();
|
||||
|
||||
async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
|
||||
assert!(keys.len() <= 10);
|
||||
Ok(keys.iter().copied().map(|k| (k, k)).collect())
|
||||
}
|
||||
}
|
||||
|
||||
let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
|
||||
let mut handles = Vec::new();
|
||||
for i in 0..100 {
|
||||
handles.push(task::spawn({
|
||||
let loader = loader.clone();
|
||||
async move { loader.load_one(i).await }
|
||||
}));
|
||||
}
|
||||
assert_eq!(
|
||||
futures_util::future::try_join_all(handles).await.unwrap(),
|
||||
futures_util::future::try_join_all((0..100i32).map({
|
||||
let loader = loader.clone();
|
||||
move |n| {
|
||||
let loader = loader.clone();
|
||||
async move { loader.load_one(n).await }
|
||||
}
|
||||
}))
|
||||
.await
|
||||
.unwrap(),
|
||||
(0..100).map(Option::Some).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
futures_util::future::try_join_all((0..100i64).map({
|
||||
let loader = loader.clone();
|
||||
move |n| {
|
||||
let loader = loader.clone();
|
||||
async move { loader.load_one(n).await }
|
||||
}
|
||||
}))
|
||||
.await
|
||||
.unwrap(),
|
||||
(0..100).map(Option::Some).collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue