Add dataloader
This commit is contained in:
parent
f8b8099209
commit
af1d501d75
|
@ -26,12 +26,14 @@ default = [
|
||||||
"unblock",
|
"unblock",
|
||||||
"url",
|
"url",
|
||||||
"uuid",
|
"uuid",
|
||||||
|
"dataloader"
|
||||||
]
|
]
|
||||||
apollo_tracing = ["chrono"]
|
apollo_tracing = ["chrono"]
|
||||||
apollo_persisted_queries = ["lru", "sha2"]
|
apollo_persisted_queries = ["lru", "sha2"]
|
||||||
multipart = ["multer", "tempfile"]
|
multipart = ["multer", "tempfile"]
|
||||||
unblock = ["blocking"]
|
unblock = ["blocking"]
|
||||||
string_number = ["num-traits"]
|
string_number = ["num-traits"]
|
||||||
|
dataloader = ["futures-timer", "futures-channel", "lru"]
|
||||||
# Used for doc(cfg())
|
# Used for doc(cfg())
|
||||||
nightly = []
|
nightly = []
|
||||||
|
|
||||||
|
@ -70,6 +72,8 @@ multer = { version = "1.2.2", optional = true }
|
||||||
num-traits = { version = "0.2.12", optional = true }
|
num-traits = { version = "0.2.12", optional = true }
|
||||||
sha2 = { version = "0.9.1", optional = true }
|
sha2 = { version = "0.9.1", optional = true }
|
||||||
tempfile = { version = "3.1.0", optional = true }
|
tempfile = { version = "3.1.0", optional = true }
|
||||||
|
futures-timer = { version = "3.0.2", optional = true }
|
||||||
|
futures-channel = { version = "0.3.8", optional = true }
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
async-std = { version = "1.6.5", features = ["attributes"] }
|
async-std = { version = "1.6.5", features = ["attributes"] }
|
||||||
|
|
2
examples
2
examples
|
@ -1 +1 @@
|
||||||
Subproject commit b51536fb0a933d25d9fcc91a08c2a53daf74daf8
|
Subproject commit 7ca8aad9f056dff072537e27ae2ea6db01831957
|
|
@ -19,5 +19,5 @@ pub fn client() -> Client {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn wait_server_ready() {
|
pub async fn wait_server_ready() {
|
||||||
async_std::task::sleep(Duration::from_millis(500)).await;
|
async_std::task::sleep(Duration::from_secs(1)).await;
|
||||||
}
|
}
|
||||||
|
|
92
src/dataloader/cache.rs
Normal file
92
src/dataloader/cache.rs
Normal file
|
@ -0,0 +1,92 @@
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::sync::Mutex;
|
||||||
|
|
||||||
|
use super::Loader;
|
||||||
|
|
||||||
|
/// Cache storage for a loader.
|
||||||
|
pub trait CacheStorage: Send + Sync + 'static {
|
||||||
|
/// Type of key.
|
||||||
|
type Key: Send + Hash + Eq + Clone + 'static;
|
||||||
|
|
||||||
|
/// Type of value.
|
||||||
|
type Value: Send + Clone + 'static;
|
||||||
|
|
||||||
|
/// Load `value` from cache by `key`.
|
||||||
|
fn get(&self, key: &Self::Key) -> Option<Self::Value>;
|
||||||
|
|
||||||
|
/// Put `value` to cache by `key`.
|
||||||
|
fn set(&self, key: Self::Key, value: Self::Value);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Loader for the [cached](trait.LoaderExt.html#method.cached) method.
|
||||||
|
pub struct CachedLoader<C, T> {
|
||||||
|
loader: T,
|
||||||
|
cache: C,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<C, T> CachedLoader<C, T>
|
||||||
|
where
|
||||||
|
C: CacheStorage<Key = T::Key, Value = T::Value>,
|
||||||
|
T: Loader,
|
||||||
|
{
|
||||||
|
/// Create a loader that can cache data.
|
||||||
|
pub fn new(loader: T, cache: C) -> Self {
|
||||||
|
Self { cache, loader }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<C, T> Loader for CachedLoader<C, T>
|
||||||
|
where
|
||||||
|
C: CacheStorage<Key = T::Key, Value = T::Value>,
|
||||||
|
T: Loader,
|
||||||
|
{
|
||||||
|
type Key = T::Key;
|
||||||
|
type Value = T::Value;
|
||||||
|
type Error = T::Error;
|
||||||
|
|
||||||
|
async fn load(
|
||||||
|
&self,
|
||||||
|
mut keys: HashSet<Self::Key>,
|
||||||
|
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||||
|
let mut res = HashMap::new();
|
||||||
|
for key in &keys {
|
||||||
|
if let Some(value) = self.cache.get(key) {
|
||||||
|
res.insert(key.clone(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for key in res.keys() {
|
||||||
|
keys.remove(key);
|
||||||
|
}
|
||||||
|
let values = self.loader.load(keys).await?;
|
||||||
|
for (key, value) in &values {
|
||||||
|
self.cache.set(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
res.extend(values);
|
||||||
|
Ok(res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Memory-based LRU cache.
|
||||||
|
pub struct LruCache<T: CacheStorage>(Mutex<lru::LruCache<T::Key, T::Value>>);
|
||||||
|
|
||||||
|
impl<T: CacheStorage> LruCache<T> {
|
||||||
|
/// Creates a new LRU Cache that holds at most `cap` items.
|
||||||
|
pub fn new(cap: usize) -> Self {
|
||||||
|
Self(Mutex::new(lru::LruCache::new(cap)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: CacheStorage> CacheStorage for LruCache<T> {
|
||||||
|
type Key = T::Key;
|
||||||
|
type Value = T::Value;
|
||||||
|
|
||||||
|
fn get(&self, key: &Self::Key) -> Option<Self::Value> {
|
||||||
|
self.0.lock().unwrap().get(key).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set(&self, key: Self::Key, value: Self::Value) {
|
||||||
|
self.0.lock().unwrap().put(key, value);
|
||||||
|
}
|
||||||
|
}
|
249
src/dataloader/mod.rs
Normal file
249
src/dataloader/mod.rs
Normal file
|
@ -0,0 +1,249 @@
|
||||||
|
//! Batch loading support, used to solve N+1 problem.
|
||||||
|
//!
|
||||||
|
//! # Examples
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use async_graphql::*;
|
||||||
|
//! use async_graphql::dataloader::*;
|
||||||
|
//! use std::collections::{HashSet, HashMap};
|
||||||
|
//! use std::convert::Infallible;
|
||||||
|
//! use async_graphql::dataloader::Loader;
|
||||||
|
//!
|
||||||
|
//! /// This loader simply converts the integer key into a string value.
|
||||||
|
//! struct MyLoader;
|
||||||
|
//!
|
||||||
|
//! #[async_trait::async_trait]
|
||||||
|
//! impl Loader for MyLoader {
|
||||||
|
//! type Key = i32;
|
||||||
|
//! type Value = String;
|
||||||
|
//! type Error = Infallible;
|
||||||
|
//!
|
||||||
|
//! async fn load(&self, keys: HashSet<Self::Key>) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||||
|
//! // Use `MyLoader` to load data.
|
||||||
|
//! Ok(keys.iter().copied().map(|n| (n, n.to_string())).collect())
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! struct Query;
|
||||||
|
//!
|
||||||
|
//! #[Object]
|
||||||
|
//! impl Query {
|
||||||
|
//! async fn value(&self, ctx: &Context<'_>, n: i32) -> Option<String> {
|
||||||
|
//! ctx.data_unchecked::<DataLoader<MyLoader>>().load_one(n).await.unwrap()
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//!
|
||||||
|
//! async_std::task::block_on(async move {
|
||||||
|
//! let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
|
||||||
|
//! .data(DataLoader::new(MyLoader))
|
||||||
|
//! .finish();
|
||||||
|
//!
|
||||||
|
//! let res = schema.execute(r#"
|
||||||
|
//! {
|
||||||
|
//! v1: value(n: 1)
|
||||||
|
//! v2: value(n: 2)
|
||||||
|
//! v3: value(n: 3)
|
||||||
|
//! v4: value(n: 4)
|
||||||
|
//! v5: value(n: 5)
|
||||||
|
//! }
|
||||||
|
//! "#).await.into_result().unwrap().data;
|
||||||
|
//!
|
||||||
|
//! assert_eq!(res, value!({
|
||||||
|
//! "v1": "1",
|
||||||
|
//! "v2": "2",
|
||||||
|
//! "v3": "3",
|
||||||
|
//! "v4": "4",
|
||||||
|
//! "v5": "5",
|
||||||
|
//! }));
|
||||||
|
//! });
|
||||||
|
//!
|
||||||
|
//! ```
|
||||||
|
|
||||||
|
mod cache;
|
||||||
|
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
use std::hash::Hash;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use futures_channel::oneshot;
|
||||||
|
use futures_timer::Delay;
|
||||||
|
use futures_util::lock::Mutex;
|
||||||
|
|
||||||
|
pub use cache::{CacheStorage, CachedLoader, LruCache};
|
||||||
|
|
||||||
|
type ResSender<T> = oneshot::Sender<
|
||||||
|
Result<HashMap<<T as Loader>::Key, <T as Loader>::Value>, <T as Loader>::Error>,
|
||||||
|
>;
|
||||||
|
|
||||||
|
struct Requests<T: Loader> {
|
||||||
|
keys: HashSet<T::Key>,
|
||||||
|
pending: Vec<(HashSet<T::Key>, ResSender<T>)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Loader> Default for Requests<T> {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
keys: Default::default(),
|
||||||
|
pending: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Loader> Requests<T> {
|
||||||
|
async fn load(self, loader: &T) {
|
||||||
|
match loader.load(self.keys).await {
|
||||||
|
Ok(values) => {
|
||||||
|
for (keys, tx) in self.pending {
|
||||||
|
let mut res = HashMap::new();
|
||||||
|
for key in &keys {
|
||||||
|
res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
|
||||||
|
}
|
||||||
|
tx.send(Ok(res)).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(err) => {
|
||||||
|
for (_, tx) in self.pending {
|
||||||
|
tx.send(Err(err.clone())).ok();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Trait for batch loading.
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
pub trait Loader: Send + Sync + 'static {
|
||||||
|
/// Type of key.
|
||||||
|
type Key: Send + Hash + Eq + Clone + 'static;
|
||||||
|
|
||||||
|
/// type of value.
|
||||||
|
type Value: Send + Clone + 'static;
|
||||||
|
|
||||||
|
/// Type of error.
|
||||||
|
type Error: Send + Clone + 'static;
|
||||||
|
|
||||||
|
/// Load the data set specified by the `keys`.
|
||||||
|
async fn load(
|
||||||
|
&self,
|
||||||
|
keys: HashSet<Self::Key>,
|
||||||
|
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Data loader.
|
||||||
|
///
|
||||||
|
/// Reference: https://github.com/facebook/dataloader
|
||||||
|
pub struct DataLoader<T: Loader> {
|
||||||
|
requests: Mutex<Requests<T>>,
|
||||||
|
delay: Duration,
|
||||||
|
max_batch_size: usize,
|
||||||
|
loader: T,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Loader> DataLoader<T> {
|
||||||
|
/// Create a DataLoader with the `Loader` trait.
|
||||||
|
pub fn new(loader: T) -> Self {
|
||||||
|
Self {
|
||||||
|
requests: Default::default(),
|
||||||
|
delay: Duration::from_millis(20),
|
||||||
|
max_batch_size: 1000,
|
||||||
|
loader,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Specify the delay time for loading data, the default is `20ms`.
|
||||||
|
pub fn delay(self, delay: Duration) -> Self {
|
||||||
|
Self { delay, ..self }
|
||||||
|
}
|
||||||
|
|
||||||
|
/// pub fn Specify the max batch size for loading data, the default is `1000`.
|
||||||
|
///
|
||||||
|
/// If the keys waiting to be loaded reach the threshold, they are loaded immediately.
|
||||||
|
pub fn max_batch_size(self, max_batch_size: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
max_batch_size,
|
||||||
|
..self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use this `DataLoader` load a data.
|
||||||
|
pub async fn load_one(&self, key: T::Key) -> Result<Option<T::Value>, T::Error> {
|
||||||
|
let mut values = self.load_many(std::iter::once(key.clone())).await?;
|
||||||
|
Ok(values.remove(&key))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Use this `DataLoader` to load some data.
|
||||||
|
pub async fn load_many(
|
||||||
|
&self,
|
||||||
|
keys: impl Iterator<Item = T::Key>,
|
||||||
|
) -> Result<HashMap<T::Key, T::Value>, T::Error> {
|
||||||
|
let (start_fetch, rx) = {
|
||||||
|
let mut requests = self.requests.lock().await;
|
||||||
|
let prev_count = requests.keys.len();
|
||||||
|
let keys = keys.collect::<HashSet<_>>();
|
||||||
|
requests.keys.extend(keys.clone());
|
||||||
|
if requests.keys.len() == prev_count {
|
||||||
|
return Ok(Default::default());
|
||||||
|
}
|
||||||
|
let (tx, rx) = oneshot::channel();
|
||||||
|
requests.pending.push((keys, tx));
|
||||||
|
if requests.keys.len() >= self.max_batch_size {
|
||||||
|
let r = std::mem::take(&mut *requests);
|
||||||
|
drop(requests);
|
||||||
|
r.load(&self.loader).await;
|
||||||
|
(false, rx)
|
||||||
|
} else {
|
||||||
|
(!requests.keys.is_empty() && prev_count == 0, rx)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if start_fetch {
|
||||||
|
Delay::new(self.delay).await;
|
||||||
|
let requests = std::mem::take(&mut *self.requests.lock().await);
|
||||||
|
if !requests.keys.is_empty() {
|
||||||
|
requests.load(&self.loader).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rx.await.unwrap()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use async_std::task;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[async_std::test]
|
||||||
|
async fn test_dataloader() {
|
||||||
|
struct MyLoader;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Loader for MyLoader {
|
||||||
|
type Key = i32;
|
||||||
|
type Value = i32;
|
||||||
|
type Error = ();
|
||||||
|
|
||||||
|
async fn load(
|
||||||
|
&self,
|
||||||
|
keys: HashSet<Self::Key>,
|
||||||
|
) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
|
||||||
|
assert!(keys.len() <= 10);
|
||||||
|
Ok(keys.into_iter().map(|k| (k, k)).collect())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
|
||||||
|
let mut handles = Vec::new();
|
||||||
|
for i in 0..100 {
|
||||||
|
handles.push(task::spawn({
|
||||||
|
let loader = loader.clone();
|
||||||
|
async move { loader.load_one(i).await }
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
assert_eq!(
|
||||||
|
futures_util::future::try_join_all(handles).await.unwrap(),
|
||||||
|
(0..100).map(Option::Some).collect::<Vec<_>>()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -176,6 +176,9 @@ mod subscription;
|
||||||
mod validation;
|
mod validation;
|
||||||
|
|
||||||
pub mod context;
|
pub mod context;
|
||||||
|
#[cfg(feature = "dataloader")]
|
||||||
|
#[cfg_attr(feature = "nightly", doc(cfg(feature = "dataloader")))]
|
||||||
|
pub mod dataloader;
|
||||||
pub mod extensions;
|
pub mod extensions;
|
||||||
pub mod guard;
|
pub mod guard;
|
||||||
pub mod http;
|
pub mod http;
|
||||||
|
|
|
@ -51,13 +51,13 @@ impl<Query, Mutation, Subscription> SchemaBuilder<Query, Mutation, Subscription>
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the maximum complexity a query can have. By default there is no limit.
|
/// Set the maximum complexity a query can have. By default, there is no limit.
|
||||||
pub fn limit_complexity(mut self, complexity: usize) -> Self {
|
pub fn limit_complexity(mut self, complexity: usize) -> Self {
|
||||||
self.complexity = Some(complexity);
|
self.complexity = Some(complexity);
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the maximum depth a query can have. By default there is no limit.
|
/// Set the maximum depth a query can have. By default, there is no limit.
|
||||||
pub fn limit_depth(mut self, depth: usize) -> Self {
|
pub fn limit_depth(mut self, depth: usize) -> Self {
|
||||||
self.depth = Some(depth);
|
self.depth = Some(depth);
|
||||||
self
|
self
|
||||||
|
@ -336,7 +336,7 @@ where
|
||||||
/// Get all names in this schema
|
/// Get all names in this schema
|
||||||
///
|
///
|
||||||
/// Maybe you want to serialize a custom binary protocol. In order to minimize message size, a dictionary
|
/// Maybe you want to serialize a custom binary protocol. In order to minimize message size, a dictionary
|
||||||
/// is usually used to compress type names, field names, directive names, and parameter names. This function gets all the names
|
/// is usually used to compress type names, field names, directive names, and parameter names. This function gets all the names,
|
||||||
/// so you can create this dictionary.
|
/// so you can create this dictionary.
|
||||||
pub fn names(&self) -> Vec<String> {
|
pub fn names(&self) -> Vec<String> {
|
||||||
self.0.env.registry.names()
|
self.0.env.registry.names()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user