//! Batch loading support, used to solve the N+1 problem.
//!
//! # Examples
//!
//! ```rust
//! use async_graphql::*;
//! use async_graphql::dataloader::*;
//! use std::collections::{HashSet, HashMap};
//! use std::convert::Infallible;
//! use async_graphql::dataloader::Loader;
//!
//! /// This loader simply converts the integer key into a string value.
//! struct MyLoader;
//!
//! #[async_trait::async_trait]
//! impl Loader<i32> for MyLoader {
//!     type Value = String;
//!     type Error = Infallible;
//!
//!     async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
//!         // Use `MyLoader` to load data.
//!         Ok(keys.iter().copied().map(|n| (n, n.to_string())).collect())
//!     }
//! }
//!
//! struct Query;
//!
//! #[Object]
//! impl Query {
//!     async fn value(&self, ctx: &Context<'_>, n: i32) -> Option<String> {
//!         ctx.data_unchecked::<DataLoader<MyLoader>>().load_one(n).await.unwrap()
//!     }
//! }
//!
//! tokio::runtime::Runtime::new().unwrap().block_on(async move {
//!     let schema = Schema::new(Query, EmptyMutation, EmptySubscription);
//!     let query = r#"
//!         {
//!             v1: value(n: 1)
//!             v2: value(n: 2)
//!             v3: value(n: 3)
//!             v4: value(n: 4)
//!             v5: value(n: 5)
//!         }
//!     "#;
//!     let request = Request::new(query).data(DataLoader::new(MyLoader));
//!     let res = schema.execute(request).await.into_result().unwrap().data;
//!
//!     assert_eq!(res, value!({
//!         "v1": "1",
//!         "v2": "2",
//!         "v3": "3",
//!         "v4": "4",
//!         "v5": "5",
//!     }));
//! });
//! ```

use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::hash::Hash;
use std::time::Duration;

use futures_channel::oneshot;
use futures_timer::Delay;
use futures_util::lock::Mutex;
use fnv::FnvHashMap;

#[allow(clippy::type_complexity)]
struct ResSender<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>>(
    oneshot::Sender<Result<HashMap<K, T::Value>, T::Error>>,
);

struct Requests<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> {
    keys: HashSet<K>,
    pending: Vec<(HashSet<K>, ResSender<K, T>)>,
}

impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Default for Requests<K, T> {
    fn default() -> Self {
        Self {
            keys: Default::default(),
            pending: Default::default(),
        }
    }
}

impl<K: Send + Hash + Eq + Clone + 'static, T: Loader<K>> Requests<K, T> {
    /// Load all accumulated keys in one batch, then send each pending caller
    /// its requested subset of the results (or a clone of the error) through
    /// its oneshot channel.
    async fn load(self, loader: &T) {
        let keys = self.keys.into_iter().collect::<Vec<_>>();
        match loader.load(&keys).await {
            Ok(values) => {
                for (keys, tx) in self.pending {
                    let mut res = HashMap::new();
                    for key in &keys {
                        res.extend(values.get(key).map(|value| (key.clone(), value.clone())));
                    }
                    tx.0.send(Ok(res)).ok();
                }
            }
            Err(err) => {
                for (_, tx) in self.pending {
                    tx.0.send(Err(err.clone())).ok();
                }
            }
        }
    }
}

/// Trait for batch loading.
#[async_trait::async_trait]
pub trait Loader<K: Send + Hash + Eq + Clone + 'static>: Send + Sync + 'static {
    /// Type of value.
    type Value: Send + Clone + 'static;

    /// Type of error.
    type Error: Send + Clone + 'static;

    /// Load the data set specified by the `keys`.
    async fn load(&self, keys: &[K]) -> Result<HashMap<K, Self::Value>, Self::Error>;
}

/// Data loader.
///
/// Reference: <https://github.com/facebook/dataloader>
pub struct DataLoader<T> {
    requests: Mutex<FnvHashMap<TypeId, Box<dyn Any + Send>>>,
    delay: Duration,
    max_batch_size: usize,
    loader: T,
}

impl<T> DataLoader<T> {
    /// Create a `DataLoader` with the `Loader` trait.
    pub fn new(loader: T) -> Self {
        Self {
            requests: Default::default(),
            delay: Duration::from_millis(1),
            max_batch_size: 1000,
            loader,
        }
    }

    /// Specify the delay time for loading data, the default is `1ms`.
    pub fn delay(self, delay: Duration) -> Self {
        Self { delay, ..self }
    }

    /// Specify the max batch size for loading data, the default is `1000`.
    ///
    /// If the number of keys waiting to be loaded reaches this threshold,
    /// they are loaded immediately.
    pub fn max_batch_size(self, max_batch_size: usize) -> Self {
        Self {
            max_batch_size,
            ..self
        }
    }

    /// Use this `DataLoader` to load a single piece of data.
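    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the `MyLoader` from the module-level docs
    /// (keys of type `i32`, values of type `String`):
    ///
    /// ```ignore
    /// let loader = DataLoader::new(MyLoader);
    /// // The key is batched together with any other keys requested within
    /// // the delay window, then fetched in a single `Loader::load` call.
    /// let value: Option<String> = loader.load_one(1).await.unwrap();
    /// ```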
    pub async fn load_one<K>(&self, key: K) -> Result<Option<T::Value>, T::Error>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        T: Loader<K>,
    {
        let mut values = self.load_many(std::iter::once(key.clone())).await?;
        Ok(values.remove(&key))
    }

    /// Use this `DataLoader` to load some data.
    pub async fn load_many<K>(
        &self,
        keys: impl Iterator<Item = K>,
    ) -> Result<HashMap<K, T::Value>, T::Error>
    where
        K: Send + Sync + Hash + Eq + Clone + 'static,
        T: Loader<K>,
    {
        let tid = TypeId::of::<K>();

        let (start_fetch, rx) = {
            let mut requests = self.requests.lock().await;
            let typed_requests = requests
                .entry(tid)
                .or_insert_with(|| Box::new(Requests::<K, T>::default()))
                .downcast_mut::<Requests<K, T>>()
                .unwrap();
            let prev_count = typed_requests.keys.len();
            let keys = keys.collect::<HashSet<_>>();
            typed_requests.keys.extend(keys.clone());
            let (tx, rx) = oneshot::channel();
            typed_requests.pending.push((keys, ResSender(tx)));
            if typed_requests.keys.len() >= self.max_batch_size {
                // The batch is full: take it and load immediately.
                let r = std::mem::take(&mut *typed_requests);
                drop(requests);
                r.load(&self.loader).await;
                (false, rx)
            } else {
                // Only the caller that added the first key of a batch
                // schedules the delayed fetch.
                (!typed_requests.keys.is_empty() && prev_count == 0, rx)
            }
        };

        if start_fetch {
            Delay::new(self.delay).await;
            let mut requests = self.requests.lock().await;
            let typed_requests = requests
                .get_mut(&tid)
                .unwrap()
                .downcast_mut::<Requests<K, T>>()
                .unwrap();
            let typed_requests = std::mem::take(typed_requests);
            drop(requests);
            if !typed_requests.keys.is_empty() {
                typed_requests.load(&self.loader).await;
            }
        }

        rx.await.unwrap()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    struct MyLoader;

    #[async_trait::async_trait]
    impl Loader<i32> for MyLoader {
        type Value = i32;
        type Error = ();

        async fn load(&self, keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    #[async_trait::async_trait]
    impl Loader<i64> for MyLoader {
        type Value = i64;
        type Error = ();

        async fn load(&self, keys: &[i64]) -> Result<HashMap<i64, Self::Value>, Self::Error> {
            assert!(keys.len() <= 10);
            Ok(keys.iter().copied().map(|k| (k, k)).collect())
        }
    }

    #[tokio::test]
    async fn test_dataloader() {
        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
        assert_eq!(
            futures_util::future::try_join_all((0..100i32).map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            (0..100).map(Option::Some).collect::<Vec<_>>()
        );

        assert_eq!(
            futures_util::future::try_join_all((0..100i64).map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            (0..100).map(Option::Some).collect::<Vec<_>>()
        );
    }

    #[tokio::test]
    async fn test_duplicate_keys() {
        let loader = Arc::new(DataLoader::new(MyLoader).max_batch_size(10));
        assert_eq!(
            futures_util::future::try_join_all([1, 3, 5, 1, 7, 8, 3, 7].iter().copied().map({
                let loader = loader.clone();
                move |n| {
                    let loader = loader.clone();
                    async move { loader.load_one(n).await }
                }
            }))
            .await
            .unwrap(),
            [1, 3, 5, 1, 7, 8, 3, 7]
                .iter()
                .copied()
                .map(Option::Some)
                .collect::<Vec<_>>()
        );
    }
}
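
// A minimal sketch (not part of the original test suite) exercising the error
// path in `Requests::load`: when `Loader::load` fails, the error is cloned and
// delivered to every caller waiting on the batch. `FailingLoader` is a
// hypothetical loader written only for this illustration.
#[cfg(test)]
mod error_propagation_tests {
    use super::*;
    use std::sync::Arc;

    struct FailingLoader;

    #[async_trait::async_trait]
    impl Loader<i32> for FailingLoader {
        type Value = i32;
        type Error = String;

        async fn load(&self, _keys: &[i32]) -> Result<HashMap<i32, Self::Value>, Self::Error> {
            // Always fail, so every pending request should receive the error.
            Err("boom".to_string())
        }
    }

    #[tokio::test]
    async fn test_error_reaches_all_pending_callers() {
        let loader = Arc::new(DataLoader::new(FailingLoader));
        let results = futures_util::future::join_all((0..3i32).map({
            let loader = loader.clone();
            move |n| {
                let loader = loader.clone();
                async move { loader.load_one(n).await }
            }
        }))
        .await;
        // Each concurrent caller receives its own clone of the same error.
        for res in results {
            assert_eq!(res, Err("boom".to_string()));
        }
    }
}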