diff --git a/Cargo.toml b/Cargo.toml index 742759e1..0d67f801 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ opentelemetry = { version = "0.17.0", optional = true, default-features = false, "trace", ] } rust_decimal = { version = "1.14.3", optional = true } +bigdecimal = { version = "0.3.0", optional = true } secrecy = { version = "0.8.0", optional = true } smol_str = { version = "0.1.21", optional = true } time = { version = "0.3.5", optional = true, features = [ diff --git a/src/types/external/big_decimal.rs b/src/types/external/big_decimal.rs new file mode 100644 index 00000000..049ad0a0 --- /dev/null +++ b/src/types/external/big_decimal.rs @@ -0,0 +1,31 @@ +use std::str::FromStr; + +use bigdecimal::BigDecimal; + +use crate::{InputValueError, InputValueResult, Scalar, ScalarType, Value}; + +#[Scalar(internal, name = "BigDecimal")] +impl ScalarType for BigDecimal { + fn parse(value: Value) -> InputValueResult { + match &value { + Value::Number(n) => { + if let Some(f) = n.as_f64() { + return BigDecimal::try_from(f).map_err(InputValueError::custom); + } + + if let Some(f) = n.as_i64() { + return Ok(BigDecimal::from(f)); + } + + // unwrap safe here, because we have check the other possibility + Ok(BigDecimal::from(n.as_u64().unwrap())) + } + Value::String(s) => Ok(BigDecimal::from_str(s)?), + _ => Err(InputValueError::expected_type(value)), + } + } + + fn to_value(&self) -> Value { + Value::String(self.to_string()) + } +} diff --git a/src/types/external/decimal.rs b/src/types/external/decimal.rs index cbe65591..dca544b3 100644 --- a/src/types/external/decimal.rs +++ b/src/types/external/decimal.rs @@ -9,6 +9,18 @@ impl ScalarType for Decimal { fn parse(value: Value) -> InputValueResult { match &value { Value::String(s) => Ok(Decimal::from_str(s)?), + Value::Number(n) => { + if let Some(f) = n.as_f64() { + return Decimal::try_from(f).map_err(InputValueError::custom); + } + + if let Some(f) = n.as_i64() { + return Ok(Decimal::from(f)); + } + + // unwrap safe here, because we have check the other possibility + Ok(Decimal::from(n.as_u64().unwrap())) + } _ => Err(InputValueError::expected_type(value)), } } diff --git a/src/types/external/mod.rs b/src/types/external/mod.rs index 7a696fc9..f9086efb 100644 --- a/src/types/external/mod.rs +++ b/src/types/external/mod.rs @@ -15,6 +15,8 @@ mod string; #[cfg(feature = "tokio-sync")] mod tokio; +#[cfg(feature = "bigdecimal")] +mod big_decimal; #[cfg(feature = "bson")] mod bson; #[cfg(feature = "chrono-tz")]