diff --git a/Cargo.toml b/Cargo.toml
index cb01a1cb..8f99fe82 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -26,12 +26,14 @@ default = [
     "unblock",
     "url",
     "uuid",
+    "dataloader"
 ]
 apollo_tracing = ["chrono"]
 apollo_persisted_queries = ["lru", "sha2"]
 multipart = ["multer", "tempfile"]
 unblock = ["blocking"]
 string_number = ["num-traits"]
+dataloader = ["futures-timer", "futures-channel", "lru"]
 
 # Used for doc(cfg())
 nightly = []
@@ -70,6 +72,8 @@ multer = { version = "1.2.2", optional = true }
 num-traits = { version = "0.2.12", optional = true }
 sha2 = { version = "0.9.1", optional = true }
 tempfile = { version = "3.1.0", optional = true }
+futures-timer = { version = "3.0.2", optional = true }
+futures-channel = { version = "0.3.8", optional = true }
 
 [dev-dependencies]
 async-std = { version = "1.6.5", features = ["attributes"] }
diff --git a/examples b/examples
index b51536fb..7ca8aad9 160000
--- a/examples
+++ b/examples
@@ -1 +1 @@
-Subproject commit b51536fb0a933d25d9fcc91a08c2a53daf74daf8
+Subproject commit 7ca8aad9f056dff072537e27ae2ea6db01831957
diff --git a/integrations/tide/tests/test_utils.rs b/integrations/tide/tests/test_utils.rs
index d482899d..0786a68b 100644
--- a/integrations/tide/tests/test_utils.rs
+++ b/integrations/tide/tests/test_utils.rs
@@ -19,5 +19,5 @@ pub fn client() -> Client {
 }
 
 pub async fn wait_server_ready() {
-    async_std::task::sleep(Duration::from_millis(500)).await;
+    async_std::task::sleep(Duration::from_secs(1)).await;
 }
diff --git a/src/dataloader/cache.rs b/src/dataloader/cache.rs
new file mode 100644
index 00000000..8c14befd
--- /dev/null
+++ b/src/dataloader/cache.rs
@@ -0,0 +1,92 @@
+use std::collections::{HashMap, HashSet};
+use std::hash::Hash;
+use std::sync::Mutex;
+
+use super::Loader;
+
+/// Cache storage for a loader.
+pub trait CacheStorage: Send + Sync + 'static {
+    /// Type of key.
+    type Key: Send + Hash + Eq + Clone + 'static;
+
+    /// Type of value.
+    type Value: Send + Clone + 'static;
+
+    /// Load `value` from cache by `key`.
+    fn get(&self, key: &Self::Key) -> Option<Self::Value>;
+
+    /// Put `value` to cache by `key`.
+    fn set(&self, key: Self::Key, value: Self::Value);
+}
+
+/// Loader for the [cached](trait.LoaderExt.html#method.cached) method.
+pub struct CachedLoader<T, C> {
+    loader: T,
+    cache: C,
+}
+
+impl<T, C> CachedLoader<T, C>
+where
+    C: CacheStorage<Key = T::Key, Value = T::Value>,
+    T: Loader,
+{
+    /// Create a loader that can cache data.
+    pub fn new(loader: T, cache: C) -> Self {
+        Self { cache, loader }
+    }
+}
+
+#[async_trait::async_trait]
+impl<T, C> Loader for CachedLoader<T, C>
+where
+    C: CacheStorage<Key = T::Key, Value = T::Value>,
+    T: Loader,
+{
+    type Key = T::Key;
+    type Value = T::Value;
+    type Error = T::Error;
+
+    async fn load(
+        &self,
+        mut keys: HashSet<Self::Key>,
+    ) -> Result<HashMap<Self::Key, Self::Value>, Self::Error> {
+        let mut res = HashMap::new();
+        for key in &keys {
+            if let Some(value) = self.cache.get(key) {
+                res.insert(key.clone(), value);
+            }
+        }
+        for key in res.keys() {
+            keys.remove(key);
+        }
+        let values = self.loader.load(keys).await?;
+        for (key, value) in &values {
+            self.cache.set(key.clone(), value.clone());
+        }
+        res.extend(values);
+        Ok(res)
+    }
+}
+
+/// Memory-based LRU cache.
+pub struct LruCache<T: Loader>(Mutex<lru::LruCache<T::Key, T::Value>>);
+
+impl<T: Loader> LruCache<T> {
+    /// Creates a new LRU Cache that holds at most `cap` items.
+    pub fn new(cap: usize) -> Self {
+        Self(Mutex::new(lru::LruCache::new(cap)))
+    }
+}
+
+impl<T: Loader> CacheStorage for LruCache<T> {
+    type Key = T::Key;
+    type Value = T::Value;
+
+    fn get(&self, key: &Self::Key) -> Option<Self::Value> {
+        self.0.lock().unwrap().get(key).cloned()
+    }
+
+    fn set(&self, key: Self::Key, value: Self::Value) {
+        self.0.lock().unwrap().put(key, value);
+    }
+}
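
The new `cache.rs` above ships without its own usage example, so here is a minimal sketch of how its pieces compose: a hand-written loader wrapped in `CachedLoader` with the `LruCache` storage. The `UserNameLoader` name and its fake data are illustrative only (not part of this change), and the snippet assumes the `dataloader` feature plus the `async-trait` and `async-std` crates already used elsewhere in this repository.

```rust
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;

use async_graphql::dataloader::{CachedLoader, DataLoader, Loader, LruCache};

/// Hypothetical loader, for illustration only.
struct UserNameLoader;

#[async_trait::async_trait]
impl Loader for UserNameLoader {
    type Key = u64;
    type Value = String;
    type Error = Infallible;

    async fn load(&self, keys: HashSet<u64>) -> Result<HashMap<u64, String>, Self::Error> {
        // A real loader would issue a single batched query here.
        Ok(keys.into_iter().map(|id| (id, format!("user-{}", id))).collect())
    }
}

fn main() {
    async_std::task::block_on(async {
        // Keep at most 256 entries; keys already present in the cache are
        // removed from the set handed to the inner loader on the next batch.
        let cached = CachedLoader::new(UserNameLoader, LruCache::<UserNameLoader>::new(256));
        let loader = DataLoader::new(cached);
        assert_eq!(loader.load_one(7).await, Ok(Some("user-7".to_string())));
    });
}
```

Because `LruCache` is generic over the loader type rather than over the key and value types, the turbofish annotation is needed when the cache is built inline.
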
diff --git a/src/dataloader/mod.rs b/src/dataloader/mod.rs
new file mode 100644
index 00000000..1936b2e7
--- /dev/null
+++ b/src/dataloader/mod.rs
@@ -0,0 +1,249 @@
+//! Batch loading support, used to solve the N+1 problem.
+//!
+//! # Examples
+//!
+//! ```rust
+//! use async_graphql::*;
+//! use async_graphql::dataloader::*;
+//! use std::collections::{HashSet, HashMap};
+//! use std::convert::Infallible;
+//! use async_graphql::dataloader::Loader;
+//!
+//! /// This loader simply converts the integer key into a string value.
+//! struct MyLoader;
+//!
+//! #[async_trait::async_trait]
+//! impl Loader for MyLoader {
+//!     type Key = i32;
+//!     type Value = String;
+//!     type Error = Infallible;
+//!
+//!     async fn load(&self, keys: HashSet<i32>) -> Result<HashMap<i32, String>, Self::Error> {
+//!         // Use `MyLoader` to load data.
+//!         Ok(keys.iter().copied().map(|n| (n, n.to_string())).collect())
+//!     }
+//! }
+//!
+//! struct Query;
+//!
+//! #[Object]
+//! impl Query {
+//!     async fn value(&self, ctx: &Context<'_>, n: i32) -> Option<String> {
+//!         ctx.data_unchecked::<DataLoader<MyLoader>>().load_one(n).await.unwrap()
+//!     }
+//! }
+//!
+//! async_std::task::block_on(async move {
+//!     let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
+//!         .data(DataLoader::new(MyLoader))
+//!         .finish();
+//!
+//!     let res = schema.execute(r#"
+//!         {
+//!             v1: value(n: 1)
+//!             v2: value(n: 2)
+//!             v3: value(n: 3)
+//!             v4: value(n: 4)
+//!             v5: value(n: 5)
+//!         }
+//!     "#).await.into_result().unwrap().data;
+//!
+//!     assert_eq!(res, value!({
+//!         "v1": "1",
+//!         "v2": "2",
+//!         "v3": "3",
+//!         "v4": "4",
+//!         "v5": "5",
+//!     }));
+//! });
+//!
+//! ```
+
+mod cache;
+
+use std::collections::{HashMap, HashSet};
+use std::hash::Hash;
+use std::time::Duration;
+
+use futures_channel::oneshot;
+use futures_timer::Delay;
+use futures_util::lock::Mutex;
+
+pub use cache::{CacheStorage, CachedLoader, LruCache};
+
+type ResSender<T> = oneshot::Sender<
+    Result<HashMap<<T as Loader>::Key, <T as Loader>::Value>, <T as Loader>::Error>,
+>;
+
+struct Requests<T: Loader> {
+    keys: HashSet<T::Key>,
+    pending: Vec<(HashSet<T::Key>, ResSender<T>)>,
+}
+
+impl<T: Loader> Default for Requests<T> {
+    fn default() -> Self {
+        Self {
+            keys: Default::default(),
+            pending: Default::default(),
+        }
+    }
+}
+
+impl<T: Loader> Requests<T> {
+    async fn load(self, loader: &T) {
+        match loader.load(self.keys).await {
+            Ok(values) => {
+                for (keys, tx) in self.pending {
+                    let mut res = HashMap::new();
+                    for key in &keys {
+                        res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
+                    }
+                    tx.send(Ok(res)).ok();
+                }
+            }
+            Err(err) => {
+                for (_, tx) in self.pending {
+                    tx.send(Err(err.clone())).ok();
+                }
+            }
+        }
+    }
+}
+
+/// Trait for batch loading.
+#[async_trait::async_trait]
+pub trait Loader: Send + Sync + 'static {
+    /// Type of key.
+    type Key: Send + Hash + Eq + Clone + 'static;
+
+    /// Type of value.
+    type Value: Send + Clone + 'static;
+
+    /// Type of error.
+    type Error: Send + Clone + 'static;
+
+    /// Load the data set specified by the `keys`.
+    async fn load(
+        &self,
+        keys: HashSet<Self::Key>,
+    ) -> Result<HashMap<Self::Key, Self::Value>, Self::Error>;
+}
+
+/// Data loader.
+///
+/// Reference: https://github.com/facebook/dataloader
+pub struct DataLoader<T: Loader> {
+    requests: Mutex<Requests<T>>,
+    delay: Duration,
+    max_batch_size: usize,
+    loader: T,
+}
+
+impl<T: Loader> DataLoader<T> {
+    /// Create a `DataLoader` with the given `Loader`.
+    pub fn new(loader: T) -> Self {
+        Self {
+            requests: Default::default(),
+            delay: Duration::from_millis(20),
+            max_batch_size: 1000,
+            loader,
+        }
+    }
+
+    /// Specify the delay time for loading data. The default is `20ms`.
+    pub fn delay(self, delay: Duration) -> Self {
+        Self { delay, ..self }
+    }
+
+    /// Specify the max batch size for loading data. The default is `1000`.
+    ///
+    /// If the keys waiting to be loaded reach the threshold, they are loaded immediately.
+    pub fn max_batch_size(self, max_batch_size: usize) -> Self {
+        Self {
+            max_batch_size,
+            ..self
+        }
+    }
+
+    /// Use this `DataLoader` to load a single value.
+    pub async fn load_one(&self, key: T::Key) -> Result<Option<T::Value>, T::Error> {
+        let mut values = self.load_many(std::iter::once(key.clone())).await?;
+        Ok(values.remove(&key))
+    }
+
+    /// Use this `DataLoader` to load multiple values.
+    pub async fn load_many(
+        &self,
+        keys: impl Iterator<Item = T::Key>,
+    ) -> Result<HashMap<T::Key, T::Value>, T::Error> {
+        let (start_fetch, rx) = {
+            let mut requests = self.requests.lock().await;
+            let prev_count = requests.keys.len();
+            let keys = keys.collect::<HashSet<_>>();
+            requests.keys.extend(keys.clone());
+            if requests.keys.len() == prev_count {
+                return Ok(Default::default());
+            }
+            let (tx, rx) = oneshot::channel();
+            requests.pending.push((keys, tx));
+            if requests.keys.len() >= self.max_batch_size {
+                let r = std::mem::take(&mut *requests);
+                drop(requests);
+                r.load(&self.loader).await;
+                (false, rx)
+            } else {
+                (!requests.keys.is_empty() && prev_count == 0, rx)
+            }
+        };
+
+        if start_fetch {
+            Delay::new(self.delay).await;
+            let requests = std::mem::take(&mut *self.requests.lock().await);
+            if !requests.keys.is_empty() {
+                requests.load(&self.loader).await;
+            }
+        }
+
+        rx.await.unwrap()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use async_std::task;
+    use std::sync::Arc;
+
+    #[async_std::test]
+    async fn test_dataloader() {
+        struct MyLoader;
+
+        #[async_trait::async_trait]
+        impl Loader for MyLoader {
+            type Key = i32;
+            type Value = i32;
+            type Error = ();
+
+            async fn load(
+                &self,
+                keys: HashSet<i32>,
+            ) -> Result<HashMap<i32, i32>, Self::Error> {
+                assert!(keys.len() <= 10);
+                Ok(keys.into_iter().map(|k| (k, k)).collect())
+            }
+        }
+
+        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
+        let mut handles = Vec::new();
+        for i in 0..100 {
+            handles.push(task::spawn({
+                let loader = loader.clone();
+                async move { loader.load_one(i).await }
+            }));
+        }
+        assert_eq!(
+            futures_util::future::try_join_all(handles).await.unwrap(),
+            (0..100).map(Option::Some).collect::<Vec<_>>()
+        );
+    }
+}
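
To make the `delay`/`max_batch_size` knobs and `load_many` introduced above concrete, here is a rough, self-contained sketch; `DoubleLoader` is a made-up type and the numbers are arbitrary, chosen only to contrast with the defaults of 20ms and 1000.

```rust
use std::collections::{HashMap, HashSet};
use std::convert::Infallible;
use std::time::Duration;

use async_graphql::dataloader::{DataLoader, Loader};

/// Hypothetical loader that doubles each key, for illustration only.
struct DoubleLoader;

#[async_trait::async_trait]
impl Loader for DoubleLoader {
    type Key = i32;
    type Value = i32;
    type Error = Infallible;

    async fn load(&self, keys: HashSet<i32>) -> Result<HashMap<i32, i32>, Self::Error> {
        Ok(keys.into_iter().map(|k| (k, k * 2)).collect())
    }
}

fn main() {
    async_std::task::block_on(async {
        // Flush a batch after 5ms, or as soon as 100 distinct keys are pending.
        let loader = DataLoader::new(DoubleLoader)
            .delay(Duration::from_millis(5))
            .max_batch_size(100);
        let values = loader.load_many(1..=3).await.unwrap();
        assert_eq!(values.get(&2), Some(&4));
    });
}
```
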
diff --git a/src/lib.rs b/src/lib.rs
index d31556a6..5a330922 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -176,6 +176,9 @@ mod subscription;
 mod validation;
 
 pub mod context;
+#[cfg(feature = "dataloader")]
+#[cfg_attr(feature = "nightly", doc(cfg(feature = "dataloader")))]
+pub mod dataloader;
 pub mod extensions;
 pub mod guard;
 pub mod http;
diff --git a/src/schema.rs b/src/schema.rs
index 5461119b..58baa44d 100644
--- a/src/schema.rs
+++ b/src/schema.rs
@@ -51,13 +51,13 @@ impl SchemaBuilder
         self
     }
 
-    /// Set the maximum complexity a query can have. By default there is no limit.
+    /// Set the maximum complexity a query can have. By default, there is no limit.
     pub fn limit_complexity(mut self, complexity: usize) -> Self {
         self.complexity = Some(complexity);
         self
     }
 
-    /// Set the maximum depth a query can have. By default there is no limit.
+    /// Set the maximum depth a query can have. By default, there is no limit.
     pub fn limit_depth(mut self, depth: usize) -> Self {
         self.depth = Some(depth);
         self
@@ -336,7 +336,7 @@ where
     /// Get all names in this schema
     ///
     /// Maybe you want to serialize a custom binary protocol. In order to minimize message size, a dictionary
-    /// is usually used to compress type names, field names, directive names, and parameter names. This function gets all the names
+    /// is usually used to compress type names, field names, directive names, and parameter names. This function gets all the names,
     /// so you can create this dictionary.
     pub fn names(&self) -> Vec<String> {
         self.0.env.registry.names()
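
One property of the new module that is easy to miss in the diff: `Loader::Error` is required to be `Clone` because a single failed batch is broadcast to every pending caller (`tx.send(Err(err.clone()))` in `Requests::load`). A small sketch of that behaviour, with a made-up `FailingLoader`:

```rust
use std::collections::{HashMap, HashSet};

use async_graphql::dataloader::{DataLoader, Loader};

/// Hypothetical loader that always fails, to illustrate error broadcast.
struct FailingLoader;

#[async_trait::async_trait]
impl Loader for FailingLoader {
    type Key = i32;
    type Value = i32;
    // Must be `Clone`: one failure is delivered to every waiter in the batch.
    type Error = String;

    async fn load(&self, _keys: HashSet<i32>) -> Result<HashMap<i32, i32>, Self::Error> {
        Err("backend unavailable".to_string())
    }
}

fn main() {
    async_std::task::block_on(async {
        let loader = DataLoader::new(FailingLoader);
        // Every caller whose key ended up in the failed batch sees a clone of the error.
        assert_eq!(loader.load_one(1).await, Err("backend unavailable".to_string()));
    });
}
```
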