Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 167 additions & 1 deletion kvbm/kvbm-config/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,137 @@

use std::path::PathBuf;

use serde::{Deserialize, Serialize};
use serde::{
Deserialize, Deserializer, Serialize,
de::{self, MapAccess, Visitor},
};
use validator::Validate;

const SERDE_JSON_ARBITRARY_PRECISION_NUMBER: &str = "$serde_json::private::Number";

fn deserialize_opt_f64<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_option(OptF64Visitor)
}

struct OptF64Visitor;

impl<'de> Visitor<'de> for OptF64Visitor {
type Value = Option<f64>;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("an optional f64")
}

fn visit_none<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_unit<E>(self) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(None)
}

fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_any(F64Visitor).map(Some)
}
}

struct F64Visitor;

impl<'de> Visitor<'de> for F64Visitor {
type Value = f64;

fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("an f64 or serde_json arbitrary_precision number map")
}

fn visit_f64<E>(self, value: f64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value)
}

fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value as f64)
}

fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
where
E: de::Error,
{
Ok(value as f64)
}

fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
parse_f64_str(value)
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let Some(key) = map.next_key::<String>()? else {
return Err(de::Error::custom(
"expected serde_json arbitrary_precision number map",
));
};

if key != SERDE_JSON_ARBITRARY_PRECISION_NUMBER {
return Err(de::Error::unknown_field(
&key,
&[SERDE_JSON_ARBITRARY_PRECISION_NUMBER],
));
}

let value = map.next_value::<String>()?;
if map.next_key::<de::IgnoredAny>()?.is_some() {
return Err(de::Error::custom(
"expected serde_json arbitrary_precision number map with one entry",
));
}

parse_f64_str(&value)
}
}

fn parse_f64_str<E>(value: &str) -> Result<f64, E>
where
E: de::Error,
{
let parsed = value
.parse::<f64>()
.map_err(|_| de::Error::invalid_value(de::Unexpected::Str(value), &"an f64 string"))?;
// serde_json's native f64 deserializer rejects out-of-range numbers; preserve
// that behavior so the arbitrary_precision string path can't smuggle in
// non-finite values (e.g. `1e10000` -> inf) that would later cast to usize::MAX
// in `compute_num_blocks`.
if !parsed.is_finite() {
return Err(de::Error::invalid_value(
de::Unexpected::Str(value),
&"a finite f64",
));
}
Ok(parsed)
}

/// Parallelism strategy for KV cache across workers.
///
/// This determines how KV blocks are distributed and transferred across
Expand Down Expand Up @@ -51,6 +179,7 @@ pub enum ParallelismMode {
pub struct HostCacheConfig {
/// Cache size in gigabytes.
/// Used to compute num_blocks if not explicitly set.
#[serde(default, deserialize_with = "deserialize_opt_f64")]
pub cache_size_gb: Option<f64>,

/// Explicit number of blocks for the host cache.
Expand Down Expand Up @@ -96,6 +225,7 @@ impl HostCacheConfig {
pub struct DiskCacheConfig {
/// Cache size in gigabytes.
/// Used to compute num_blocks if not explicitly set.
#[serde(default, deserialize_with = "deserialize_opt_f64")]
pub cache_size_gb: Option<f64>,

/// Explicit number of blocks for the disk cache.
Expand Down Expand Up @@ -208,6 +338,42 @@ mod tests {
assert!(config.is_enabled());
}

#[test]
fn test_cache_size_gb_arbitrary_precision_map() {
let config: HostCacheConfig =
serde_json::from_str(r#"{"cache_size_gb": {"$serde_json::private::Number": "2.0"}}"#)
.unwrap();
assert_eq!(config.cache_size_gb, Some(2.0));

let config: DiskCacheConfig =
serde_json::from_str(r#"{"cache_size_gb": {"$serde_json::private::Number": "2.0"}}"#)
.unwrap();
assert_eq!(config.cache_size_gb, Some(2.0));

let config: HostCacheConfig = serde_json::from_str(r#"{"cache_size_gb": 4}"#).unwrap();
assert_eq!(config.cache_size_gb, Some(4.0));

let config: HostCacheConfig = serde_json::from_str(r#"{"cache_size_gb": 1.5}"#).unwrap();
assert_eq!(config.cache_size_gb, Some(1.5));

let config: HostCacheConfig = serde_json::from_str("{}").unwrap();
assert_eq!(config.cache_size_gb, None);
}

#[test]
fn test_cache_size_gb_rejects_non_finite() {
// Out-of-range numbers arriving through the arbitrary_precision string path
// must be rejected, matching serde_json's native f64 deserializer behavior,
// rather than silently becoming Some(inf).
let result: Result<HostCacheConfig, _> = serde_json::from_str(
r#"{"cache_size_gb": {"$serde_json::private::Number": "1e10000"}}"#,
);
assert!(
result.is_err(),
"non-finite cache_size_gb should be rejected, got {result:?}"
);
}

#[test]
fn test_disk_cache_default() {
let config = DiskCacheConfig::default();
Expand Down