diff --git a/Cargo.toml b/Cargo.toml index 1f6282df903..0318a326e1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ members = [ # nebari = { path = "../nebari/nebari", version = "0.1" } # nebari = { git = "https://github.com/khonsulabs/nebari.git", branch = "main" } # arc-bytes = { path = "../shared-buffer" } +arc-bytes = { git = "https://github.com/khonsulabs/arc-bytes.git", branch = "main" } # [patch."https://github.com/khonsulabs/custodian.git"] # custodian-password = { path = "../custodian/password" } diff --git a/crates/bonsaidb-core/Cargo.toml b/crates/bonsaidb-core/Cargo.toml index 535a9f5098b..c1ae8c16d3d 100644 --- a/crates/bonsaidb-core/Cargo.toml +++ b/crates/bonsaidb-core/Cargo.toml @@ -52,6 +52,7 @@ tokio = { version = "1", features = ["full"] } futures = { version = "0.3" } num-derive = "0.3" anyhow = "1" +rand = "0.8" [package.metadata.docs.rs] all-features = true diff --git a/crates/bonsaidb-core/src/keyvalue.rs b/crates/bonsaidb-core/src/keyvalue.rs index 28be50682ee..eed2aeb0389 100644 --- a/crates/bonsaidb-core/src/keyvalue.rs +++ b/crates/bonsaidb-core/src/keyvalue.rs @@ -1,10 +1,11 @@ use arc_bytes::serde::Bytes; use serde::{Deserialize, Serialize}; +mod sorted_set; mod timestamp; pub use self::timestamp::Timestamp; -use crate::Error; +use crate::{keyvalue::sorted_set::SortedSet, Error}; mod implementation { use arc_bytes::serde::Bytes; @@ -286,12 +287,14 @@ pub struct SetCommand { } /// A value stored in a key. -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +#[derive(Serialize, Deserialize, Clone, Debug)] pub enum Value { /// A value stored as a byte array. Bytes(Bytes), /// A numeric value. Numeric(Numeric), + /// A set of values sorted by an associated value. + SortedSet(SortedSet), } impl Value { @@ -299,72 +302,79 @@ impl Value { pub fn validate(self) -> Result { match self { Self::Numeric(numeric) => numeric.validate().map(Self::Numeric), - Self::Bytes(vec) => Ok(Self::Bytes(vec)), + other => Ok(other), } } /// Deserializes the bytes contained inside of this value. Returns an error /// if this value doesn't contain bytes. pub fn deserialize Deserialize<'de>>(&self) -> Result { - match self { - Self::Bytes(bytes) => Ok(pot::from_slice(bytes)?), - Self::Numeric(_) => Err(Error::Database(String::from( - "key contains numeric value, not serialized data", - ))), + if let Self::Bytes(bytes) = self { + pot::from_slice(bytes).map_err(Error::from) + } else { + Err(Error::Database(String::from( + "key contains another type of data", + ))) } } /// Returns this value as an `i64`, allowing for precision to be lost if the type was not an `i64` originally. If saturating is true, the conversion will not allow overflows. Returns None if the value is bytes. #[must_use] pub fn as_i64_lossy(&self, saturating: bool) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => Some(value.as_i64_lossy(saturating)), + if let Self::Numeric(value) = self { + Some(value.as_i64_lossy(saturating)) + } else { + None } } /// Returns this value as an `u64`, allowing for precision to be lost if the type was not an `u64` originally. If saturating is true, the conversion will not allow overflows. Returns None if the value is bytes. #[must_use] pub fn as_u64_lossy(&self, saturating: bool) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => Some(value.as_u64_lossy(saturating)), + if let Self::Numeric(value) = self { + Some(value.as_u64_lossy(saturating)) + } else { + None } } /// Returns this value as an `f64`, allowing for precision to be lost if the type was not an `f64` originally. Returns None if the value is bytes. #[must_use] pub const fn as_f64_lossy(&self) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => Some(value.as_f64_lossy()), + if let Self::Numeric(value) = self { + Some(value.as_f64_lossy()) + } else { + None } } /// Returns this numeric as an `i64`, allowing for precision to be lost if the type was not an `i64` originally. Returns None if the value is bytes. #[must_use] pub fn as_i64(&self) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => value.as_i64(), + if let Self::Numeric(value) = self { + value.as_i64() + } else { + None } } /// Returns this numeric as an `u64`, allowing for precision to be lost if the type was not an `u64` originally. Returns None if the value is bytes. #[must_use] pub fn as_u64(&self) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => value.as_u64(), + if let Self::Numeric(value) = self { + value.as_u64() + } else { + None } } /// Returns this numeric as an `f64`, allowing for precision to be lost if the type was not an `f64` originally. Returns None if the value is bytes. #[must_use] pub const fn as_f64(&self) -> Option { - match self { - Self::Bytes(_) => None, - Self::Numeric(value) => value.as_f64(), + if let Self::Numeric(value) = self { + value.as_f64() + } else { + None } } } diff --git a/crates/bonsaidb-core/src/keyvalue/sorted_set.rs b/crates/bonsaidb-core/src/keyvalue/sorted_set.rs new file mode 100644 index 00000000000..a48a62badf6 --- /dev/null +++ b/crates/bonsaidb-core/src/keyvalue/sorted_set.rs @@ -0,0 +1,268 @@ +use std::{cmp::Ordering, collections::HashMap, ops::Deref}; + +use arc_bytes::{ArcBytes, OwnedBytes}; +use serde::{ser::SerializeMap, Deserialize, Serialize}; + +#[derive(Default, Clone, Debug)] +pub struct SortedSet { + members: HashMap, + sorted_members: Vec, +} + +impl SortedSet { + pub fn insert(&mut self, value: OwnedBytes, score: Score) -> Option { + let entry = Entry { score, value }; + let existing_score = self + .members + .insert(entry.value.clone(), entry.score.clone()); + + if existing_score.is_some() { + let (remove_index, _) = self + .sorted_members + .iter() + .enumerate() + .find(|(_, member)| member.value == entry.value) + .unwrap(); + self.sorted_members.remove(remove_index); + } + + let insert_at = self + .sorted_members + .binary_search(&entry) + .unwrap_or_else(|i| i); + self.sorted_members.insert(insert_at, entry); + + existing_score + } + + pub fn score(&self, value: &[u8]) -> Option<&Score> { + self.members.get(value) + } + + pub fn remove(&mut self, value: &[u8]) -> Option { + let existing_score = self.members.remove(value); + if existing_score.is_some() { + let (remove_index, _) = self + .sorted_members + .iter() + .enumerate() + .find(|(_index, member)| member.value == value) + .unwrap(); + self.sorted_members.remove(remove_index); + } + existing_score + } +} + +impl Deref for SortedSet { + type Target = Vec; + + fn deref(&self) -> &Self::Target { + &self.sorted_members + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd)] +pub struct Entry { + score: Score, + value: OwnedBytes, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub enum Score { + Signed(i64), + Unsigned(u64), + Float(f64), + Bytes(OwnedBytes), +} + +// We check that the float value on input is not a NaN. +impl Eq for Score {} + +impl PartialEq for Score { + fn eq(&self, other: &Self) -> bool { + self.cmp(other) == Ordering::Equal + } +} + +#[allow(clippy::cast_precision_loss)] +impl Ord for Score { + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (Self::Signed(a), Self::Signed(b)) => a.cmp(b), + (Self::Signed(a), Self::Unsigned(b)) => { + if let Ok(a) = u64::try_from(*a) { + a.cmp(b) + } else { + Ordering::Less + } + } + (Self::Unsigned(a), Self::Signed(b)) => { + if let Ok(b) = u64::try_from(*b) { + a.cmp(&b) + } else { + Ordering::Greater + } + } + (Self::Unsigned(a), Self::Unsigned(b)) => a.cmp(b), + (Self::Float(a), Self::Float(b)) => float_cmp(*a, *b), + (Self::Float(a), Self::Signed(b)) => float_cmp(*a, *b as f64), + (Self::Float(a), Self::Unsigned(b)) => float_cmp(*a, *b as f64), + (Self::Signed(a), Self::Float(b)) => float_cmp(*a as f64, *b), + (Self::Unsigned(a), Self::Float(b)) => float_cmp(*a as f64, *b), + (Self::Bytes(a), Self::Bytes(b)) => a.cmp(b), + (_, Self::Bytes(_)) => Ordering::Less, + (Self::Bytes(_), _) => Ordering::Greater, + } + } +} + +fn float_cmp(f1: f64, f2: f64) -> Ordering { + let abs_diff = f1 - f2; + if abs_diff < f64::EPSILON && abs_diff > -f64::EPSILON { + Ordering::Equal + } else if abs_diff < 0. { + Ordering::Less + } else { + Ordering::Greater + } +} + +#[test] +fn ord_tests() { + use rand::seq::SliceRandom; + let mut set = SortedSet::default(); + let mut rng = rand::thread_rng(); + let originals = vec![ + Score::Signed(-1), + Score::Unsigned(0), + Score::Signed(1), + Score::Float(1.), + Score::Float(1.5), + Score::Unsigned(2), + Score::Bytes(OwnedBytes::from(b"\x00")), + Score::Bytes(OwnedBytes::from(b"\x01")), + ]; + let mut shuffled = originals.clone(); + shuffled.shuffle(&mut rng); + for (i, score) in shuffled.into_iter().enumerate() { + set.insert(OwnedBytes::from(i.to_string().into_bytes()), score); + } + for (i, score) in originals.iter().enumerate() { + assert_eq!(&set[i].score, score); + } +} + +impl PartialOrd for Score { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Serialize for SortedSet { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut map = serializer.serialize_map(Some(self.members.len()))?; + for member in &self.sorted_members { + map.serialize_entry(&member.value, &member.score)?; + } + map.end() + } +} + +impl<'de> Deserialize<'de> for SortedSet { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_map(Visitor) + } +} + +struct Visitor; + +impl<'de> serde::de::Visitor<'de> for Visitor { + type Value = SortedSet; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("sorted set entries") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let (mut members, mut sorted_members) = if let Some(size) = map.size_hint() { + (HashMap::with_capacity(size), Vec::with_capacity(size)) + } else { + (HashMap::default(), Vec::default()) + }; + + while let Some((value, score)) = map.next_entry::, Score>()? { + let entry = Entry { + value: OwnedBytes(value.into_owned()), + score, + }; + members.insert(entry.value.clone(), entry.score.clone()); + sorted_members.push(entry); + } + + sorted_members.sort(); + + Ok(SortedSet { + members, + sorted_members, + }) + } +} + +#[test] +fn basics() { + let mut set = SortedSet::default(); + assert_eq!( + set.insert(OwnedBytes::from(b"first"), Score::Unsigned(2)), + None + ); + assert_eq!(set.score(b"first"), Some(&Score::Unsigned(2))); + assert_eq!( + set.insert(OwnedBytes::from(b"first"), Score::Unsigned(1)), + Some(Score::Unsigned(2)) + ); + assert_eq!(set.score(b"first"), Some(&Score::Unsigned(1))); + assert_eq!( + set.insert(OwnedBytes::from(b"second"), Score::Unsigned(1)), + None + ); + assert_eq!(set.score(b"second"), Some(&Score::Unsigned(1))); + + println!("With 2: {:?}", set); + + assert_eq!(set.insert(OwnedBytes::from(b"a"), Score::Unsigned(2)), None); + assert_eq!(set.len(), 3); + assert_eq!(set.score(b"a"), Some(&Score::Unsigned(2))); + assert_eq!(set[0].value, b"first"); + assert_eq!(set[1].value, b"second"); + assert_eq!(set[2].value, b"a"); + assert_eq!(set.remove(b"first"), Some(Score::Unsigned(1))); + assert_eq!(set.remove(b"first"), None); + assert_eq!(set[0].value, b"second"); + assert_eq!(set[1].value, b"a"); +} + +#[test] +fn serialization() { + let mut set = SortedSet::default(); + set.insert(OwnedBytes::from(b"a"), Score::Signed(2)); + set.insert(OwnedBytes::from(b"b"), Score::Unsigned(1)); + set.insert(OwnedBytes::from(b"c"), Score::Float(0.)); + let as_bytes = pot::to_vec(&set).unwrap(); + let deserialized = pot::from_slice::(&as_bytes).unwrap(); + assert_eq!(deserialized.score(b"a"), set.score(b"a")); + assert_eq!(deserialized.score(b"b"), set.score(b"b")); + assert_eq!(deserialized.score(b"c"), set.score(b"c")); + assert_eq!(deserialized[0].value, b"c"); + assert_eq!(deserialized[1].value, b"b"); + assert_eq!(deserialized[2].value, b"a"); +} diff --git a/crates/bonsaidb-core/src/test_util.rs b/crates/bonsaidb-core/src/test_util.rs index 91691bff320..a7ecf9921ee 100644 --- a/crates/bonsaidb-core/src/test_util.rs +++ b/crates/bonsaidb-core/src/test_util.rs @@ -1630,13 +1630,12 @@ macro_rules! define_kv_test_suite { db.get_key("akey").and_delete().into().await?, Some(String::from("new_value")) ); - assert_eq!(db.get_key("akey").await?, None); - assert_eq!( - db.set_key("akey", &String::from("new_value")) - .returning_previous() - .await?, - None - ); + assert!(db.get_key("akey").await?.is_none()); + assert!(db + .set_key("akey", &String::from("new_value")) + .returning_previous() + .await? + .is_none()); assert_eq!(db.delete_key("akey").await?, KeyStatus::Deleted); assert_eq!(db.delete_key("akey").await?, KeyStatus::NotChanged); @@ -1951,10 +1950,10 @@ macro_rules! define_kv_test_suite { continue; } - assert_eq!(kv.get_key("b").await?, None, "b never expired"); + assert!(kv.get_key("b").await?.is_none(), "b never expired"); timing.wait_until(Duration::from_secs_f32(5.)).await; - assert_eq!(kv.get_key("a").await?, None, "a never expired"); + assert!(kv.get_key("a").await?.is_none(), "a never expired"); break; } harness.shutdown().await?; diff --git a/crates/bonsaidb-local/src/database/keyvalue.rs b/crates/bonsaidb-local/src/database/keyvalue.rs index ecd5c285490..ef25ab8922b 100644 --- a/crates/bonsaidb-local/src/database/keyvalue.rs +++ b/crates/bonsaidb-local/src/database/keyvalue.rs @@ -487,17 +487,16 @@ impl KeyValueState { last_updated: now, }); - match entry.value { - Value::Numeric(existing) => { - let value = Value::Numeric(op(&existing, amount, saturating).validate()?); - entry.value = value.clone(); + if let Value::Numeric(existing) = entry.value { + let value = Value::Numeric(op(&existing, amount, saturating).validate()?); + entry.value = value.clone(); - self.set(full_key, entry); - Ok(Output::Value(Some(value))) - } - Value::Bytes(_) => Err(bonsaidb_core::Error::Database(String::from( + self.set(full_key, entry); + Ok(Output::Value(Some(value))) + } else { + Err(bonsaidb_core::Error::Database(String::from( "type of stored `Value` is not `Numeric`", - ))), + ))) } }