diff --git a/kvbm/kvbm-config/src/cache.rs b/kvbm/kvbm-config/src/cache.rs index a2f33884..338147fe 100644 --- a/kvbm/kvbm-config/src/cache.rs +++ b/kvbm/kvbm-config/src/cache.rs @@ -10,9 +10,137 @@ use std::path::PathBuf; -use serde::{Deserialize, Serialize}; +use serde::{ + Deserialize, Deserializer, Serialize, + de::{self, MapAccess, Visitor}, +}; use validator::Validate; +const SERDE_JSON_ARBITRARY_PRECISION_NUMBER: &str = "$serde_json::private::Number"; + +fn deserialize_opt_f64<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + deserializer.deserialize_option(OptF64Visitor) +} + +struct OptF64Visitor; + +impl<'de> Visitor<'de> for OptF64Visitor { + type Value = Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("an optional f64") + } + + fn visit_none(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: de::Error, + { + Ok(None) + } + + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_any(F64Visitor).map(Some) + } +} + +struct F64Visitor; + +impl<'de> Visitor<'de> for F64Visitor { + type Value = f64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("an f64 or serde_json arbitrary_precision number map") + } + + fn visit_f64(self, value: f64) -> Result + where + E: de::Error, + { + Ok(value) + } + + fn visit_i64(self, value: i64) -> Result + where + E: de::Error, + { + Ok(value as f64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: de::Error, + { + Ok(value as f64) + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + parse_f64_str(value) + } + + fn visit_map(self, mut map: A) -> Result + where + A: MapAccess<'de>, + { + let Some(key) = map.next_key::()? else { + return Err(de::Error::custom( + "expected serde_json arbitrary_precision number map", + )); + }; + + if key != SERDE_JSON_ARBITRARY_PRECISION_NUMBER { + return Err(de::Error::unknown_field( + &key, + &[SERDE_JSON_ARBITRARY_PRECISION_NUMBER], + )); + } + + let value = map.next_value::()?; + if map.next_key::()?.is_some() { + return Err(de::Error::custom( + "expected serde_json arbitrary_precision number map with one entry", + )); + } + + parse_f64_str(&value) + } +} + +fn parse_f64_str(value: &str) -> Result +where + E: de::Error, +{ + let parsed = value + .parse::() + .map_err(|_| de::Error::invalid_value(de::Unexpected::Str(value), &"an f64 string"))?; + // serde_json's native f64 deserializer rejects out-of-range numbers; preserve + // that behavior so the arbitrary_precision string path can't smuggle in + // non-finite values (e.g. `1e10000` -> inf) that would later cast to usize::MAX + // in `compute_num_blocks`. + if !parsed.is_finite() { + return Err(de::Error::invalid_value( + de::Unexpected::Str(value), + &"a finite f64", + )); + } + Ok(parsed) +} + /// Parallelism strategy for KV cache across workers. /// /// This determines how KV blocks are distributed and transferred across @@ -51,6 +179,7 @@ pub enum ParallelismMode { pub struct HostCacheConfig { /// Cache size in gigabytes. /// Used to compute num_blocks if not explicitly set. + #[serde(default, deserialize_with = "deserialize_opt_f64")] pub cache_size_gb: Option, /// Explicit number of blocks for the host cache. @@ -96,6 +225,7 @@ impl HostCacheConfig { pub struct DiskCacheConfig { /// Cache size in gigabytes. /// Used to compute num_blocks if not explicitly set. + #[serde(default, deserialize_with = "deserialize_opt_f64")] pub cache_size_gb: Option, /// Explicit number of blocks for the disk cache. @@ -208,6 +338,42 @@ mod tests { assert!(config.is_enabled()); } + #[test] + fn test_cache_size_gb_arbitrary_precision_map() { + let config: HostCacheConfig = + serde_json::from_str(r#"{"cache_size_gb": {"$serde_json::private::Number": "2.0"}}"#) + .unwrap(); + assert_eq!(config.cache_size_gb, Some(2.0)); + + let config: DiskCacheConfig = + serde_json::from_str(r#"{"cache_size_gb": {"$serde_json::private::Number": "2.0"}}"#) + .unwrap(); + assert_eq!(config.cache_size_gb, Some(2.0)); + + let config: HostCacheConfig = serde_json::from_str(r#"{"cache_size_gb": 4}"#).unwrap(); + assert_eq!(config.cache_size_gb, Some(4.0)); + + let config: HostCacheConfig = serde_json::from_str(r#"{"cache_size_gb": 1.5}"#).unwrap(); + assert_eq!(config.cache_size_gb, Some(1.5)); + + let config: HostCacheConfig = serde_json::from_str("{}").unwrap(); + assert_eq!(config.cache_size_gb, None); + } + + #[test] + fn test_cache_size_gb_rejects_non_finite() { + // Out-of-range numbers arriving through the arbitrary_precision string path + // must be rejected, matching serde_json's native f64 deserializer behavior, + // rather than silently becoming Some(inf). + let result: Result = serde_json::from_str( + r#"{"cache_size_gb": {"$serde_json::private::Number": "1e10000"}}"#, + ); + assert!( + result.is_err(), + "non-finite cache_size_gb should be rejected, got {result:?}" + ); + } + #[test] fn test_disk_cache_default() { let config = DiskCacheConfig::default();