diff --git a/packages/storage/src/de.rs b/packages/storage/src/de.rs new file mode 100644 index 0000000..e5c3673 --- /dev/null +++ b/packages/storage/src/de.rs @@ -0,0 +1,249 @@ +use std::array::TryFromSliceError; +use std::convert::TryInto; + +use cosmwasm_std::{Addr, StdError, StdResult}; +use serde::{Serialize, de::DeserializeOwned}; + +pub trait KeyDeserialize { + type Output: Sized + DeserializeOwned + Serialize; + + fn from_vec(value: Vec) -> StdResult; + + fn from_slice(value: &[u8]) -> StdResult { + Self::from_vec(value.to_vec()) + } +} + +impl KeyDeserialize for () { + type Output = (); + + #[inline(always)] + fn from_vec(_value: Vec) -> StdResult { + Ok(()) + } +} + +impl KeyDeserialize for Vec { + type Output = Vec; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Ok(value) + } +} + +impl KeyDeserialize for &Vec { + type Output = Vec; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Ok(value) + } +} + +impl KeyDeserialize for &[u8] { + type Output = Vec; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Ok(value) + } +} + +impl KeyDeserialize for String { + type Output = String; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + String::from_utf8(value).map_err(StdError::invalid_utf8) + } +} + +impl KeyDeserialize for &String { + type Output = String; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Self::Output::from_vec(value) + } +} + +impl KeyDeserialize for &str { + type Output = String; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Self::Output::from_vec(value) + } +} + +impl KeyDeserialize for Addr { + type Output = Addr; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Ok(Addr::unchecked(String::from_vec(value)?)) + } +} + +impl KeyDeserialize for &Addr { + type Output = Addr; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Self::Output::from_vec(value) + } +} + +macro_rules! integer_de { + (for $($t:ty),+) => { + $(impl KeyDeserialize for $t { + type Output = $t; + + #[inline(always)] + fn from_vec(value: Vec) -> StdResult { + Ok(<$t>::from_be_bytes(value.as_slice().try_into() + .map_err(|err: TryFromSliceError| StdError::generic_err(err.to_string()))?)) + } + })* + } +} + +integer_de!(for i8, u8, i16, u16, i32, u32, i64, u64, i128, u128); + +fn parse_length(value: &[u8]) -> StdResult { + Ok(u16::from_be_bytes( + value + .try_into() + .map_err(|_| StdError::generic_err("Could not read 2 byte length"))?, + ) + .into()) +} + +impl KeyDeserialize for (T, U) { + type Output = (T::Output, U::Output); + + #[inline(always)] + fn from_vec(mut value: Vec) -> StdResult { + let mut tu = value.split_off(2); + let t_len = parse_length(&value)?; + let u = tu.split_off(t_len); + + Ok((T::from_vec(tu)?, U::from_vec(u)?)) + } +} + +impl KeyDeserialize for (T, U, V) { + type Output = (T::Output, U::Output, V::Output); + + #[inline(always)] + fn from_vec(mut value: Vec) -> StdResult { + let mut tuv = value.split_off(2); + let t_len = parse_length(&value)?; + let mut len_uv = tuv.split_off(t_len); + + let mut uv = len_uv.split_off(2); + let u_len = parse_length(&len_uv)?; + let v = uv.split_off(u_len); + + Ok((T::from_vec(tuv)?, U::from_vec(uv)?, V::from_vec(v)?)) + } +} + +#[cfg(test)] +mod test { + use super::*; + + const BYTES: &[u8] = b"Hello"; + const STRING: &str = "Hello"; + + #[test] + #[allow(clippy::unit_cmp)] + fn deserialize_empty_works() { + assert_eq!(<()>::from_slice(BYTES).unwrap(), ()); + } + + #[test] + fn deserialize_bytes_works() { + assert_eq!(>::from_slice(BYTES).unwrap(), BYTES); + assert_eq!(<&Vec>::from_slice(BYTES).unwrap(), BYTES); + assert_eq!(<&[u8]>::from_slice(BYTES).unwrap(), BYTES); + } + + #[test] + fn deserialize_string_works() { + assert_eq!(::from_slice(BYTES).unwrap(), STRING); + assert_eq!(<&String>::from_slice(BYTES).unwrap(), STRING); + assert_eq!(<&str>::from_slice(BYTES).unwrap(), STRING); + } + + #[test] + fn deserialize_broken_string_errs() { + assert!(matches!( + ::from_slice(b"\xc3").err(), + Some(StdError::InvalidUtf8 { .. }) + )); + } + + #[test] + fn deserialize_addr_works() { + assert_eq!(::from_slice(BYTES).unwrap(), Addr::unchecked(STRING)); + assert_eq!(<&Addr>::from_slice(BYTES).unwrap(), Addr::unchecked(STRING)); + } + + #[test] + fn deserialize_broken_addr_errs() { + assert!(matches!( + ::from_slice(b"\xc3").err(), + Some(StdError::InvalidUtf8 { .. }) + )); + } + + #[test] + fn deserialize_naked_integer_works() { + assert_eq!(u8::from_slice(&[1]).unwrap(), 1u8); + assert_eq!(i8::from_slice(&[127]).unwrap(), -1i8); + assert_eq!(i8::from_slice(&[128]).unwrap(), 0i8); + + assert_eq!(u16::from_slice(&[1, 0]).unwrap(), 256u16); + assert_eq!(i16::from_slice(&[128, 0]).unwrap(), 0i16); + assert_eq!(i16::from_slice(&[127, 255]).unwrap(), -1i16); + + assert_eq!(u32::from_slice(&[1, 0, 0, 0]).unwrap(), 16777216u32); + assert_eq!(i32::from_slice(&[128, 0, 0, 0]).unwrap(), 0i32); + assert_eq!(i32::from_slice(&[127, 255, 255, 255]).unwrap(), -1i32); + + assert_eq!( + u64::from_slice(&[1, 0, 0, 0, 0, 0, 0, 0]).unwrap(), + 72057594037927936u64 + ); + assert_eq!(i64::from_slice(&[128, 0, 0, 0, 0, 0, 0, 0]).unwrap(), 0i64); + assert_eq!( + i64::from_slice(&[127, 255, 255, 255, 255, 255, 255, 255]).unwrap(), + -1i64 + ); + + assert_eq!( + u128::from_slice(&[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), + 1329227995784915872903807060280344576u128 + ); + assert_eq!( + i128::from_slice(&[128, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap(), + 0i128 + ); + assert_eq!( + i128::from_slice(&[ + 127, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 + ]) + .unwrap(), + -1i128 + ); + assert_eq!( + i128::from_slice(&[ + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255 + ]) + .unwrap(), + 170141183460469231731687303715884105727i128, + ); + } +} diff --git a/packages/storage/src/keymap.rs b/packages/storage/src/keymap.rs index df39c12..02e9895 100644 --- a/packages/storage/src/keymap.rs +++ b/packages/storage/src/keymap.rs @@ -13,6 +13,7 @@ use cosmwasm_storage::to_length_prefixed; use secret_toolkit_serialization::{Bincode2, Serde}; use crate::{IterOption, WithIter, WithoutIter}; +use crate::de::KeyDeserialize; const INDEXES: &[u8] = b"indexes"; const MAP_LENGTH: &[u8] = b"length"; @@ -59,7 +60,7 @@ pub struct KeymapBuilder<'a, K, T, Ser = Bincode2, I = WithIter> { impl<'a, K, T, Ser> KeymapBuilder<'a, K, T, Ser, WithIter> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -117,7 +118,7 @@ where // This enables writing `.iter().skip(n).rev()` impl<'a, K, T, Ser> KeymapBuilder<'a, K, T, Ser, WithoutIter> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -137,7 +138,7 @@ where pub struct Keymap<'a, K, T, Ser = Bincode2, I = WithIter> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, I: IterOption, @@ -154,7 +155,7 @@ where serialization_type: PhantomData, } -impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: Serde> +impl<'a, K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde> Keymap<'a, K, T, Ser> { /// constructor @@ -190,7 +191,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: } } -impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: Serde> +impl<'a, K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde> Keymap<'a, K, T, Ser, WithoutIter> { /// Serialize key @@ -232,7 +233,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: } } -impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: Serde> +impl<'a, K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde> Keymap<'a, K, T, Ser, WithIter> { /// Serialize key @@ -241,7 +242,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: } /// Deserialize key - fn deserialize_key(&self, key_data: &[u8]) -> StdResult { + fn deserialize_key(&self, key_data: &[u8]) -> StdResult { Ser::deserialize(key_data) } @@ -341,6 +342,11 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: self.load_impl(storage, &key_vec) } + fn get_from_deserialized_key(&self, storage: &dyn Storage, key: &K::Output) -> StdResult> { + let key_vec = Ser::serialize(key)?; + self.load_impl(storage, &key_vec) + } + /// user facing remove function pub fn remove(&self, storage: &mut dyn Storage, key: &K) -> StdResult<()> { let key_vec = self.serialize_key(key)?; @@ -444,7 +450,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: storage: &dyn Storage, start_page: u32, size: u32, - ) -> StdResult> { + ) -> StdResult> { let start_pos = start_page * size; let mut end_pos = start_pos + size - 1; @@ -470,7 +476,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: storage: &dyn Storage, start_page: u32, size: u32, - ) -> StdResult> { + ) -> StdResult> { let start_pos = start_page * size; let mut end_pos = start_pos + size - 1; @@ -496,7 +502,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: storage: &dyn Storage, start: u32, end: u32, - ) -> StdResult> { + ) -> StdResult> { let start_page = self.page_from_position(start); let end_page = self.page_from_position(end); @@ -529,7 +535,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: storage: &dyn Storage, start: u32, end: u32, - ) -> StdResult> { + ) -> StdResult> { let start_page = self.page_from_position(start); let end_page = self.page_from_position(end); @@ -572,7 +578,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: } } -impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: Serde> +impl<'a, K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde> PrefixedTypedStorage, Bincode2> for Keymap<'a, K, T, Ser, WithIter> { fn as_slice(&self) -> &[u8] { @@ -584,7 +590,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: } } -impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: Serde> +impl<'a, K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde> PrefixedTypedStorage for Keymap<'a, K, T, Ser, WithoutIter> { fn as_slice(&self) -> &[u8] { @@ -599,7 +605,7 @@ impl<'a, K: Serialize + DeserializeOwned, T: Serialize + DeserializeOwned, Ser: /// An iterator over the keys of the Keymap. pub struct KeyIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -612,7 +618,7 @@ where impl<'a, K, T, Ser> KeyIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -635,11 +641,11 @@ where impl<'a, K, T, Ser> Iterator for KeyIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { - type Item = StdResult; + type Item = StdResult; fn next(&mut self) -> Option { if self.start >= self.end { @@ -688,7 +694,7 @@ where impl<'a, K, T, Ser> DoubleEndedIterator for KeyIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -734,7 +740,7 @@ where // This enables writing `.iter().skip(n).rev()` impl<'a, K, T, Ser> ExactSizeIterator for KeyIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -745,7 +751,7 @@ where /// An iterator over the (key, item) pairs of the Keymap. Less efficient than just iterating over keys. pub struct KeyItemIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -758,7 +764,7 @@ where impl<'a, K, T, Ser> KeyItemIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -781,11 +787,11 @@ where impl<'a, K, T, Ser> Iterator for KeyItemIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { - type Item = StdResult<(K, T)>; + type Item = StdResult<(K::Output, T)>; fn next(&mut self) -> Option { if self.start >= self.end { @@ -813,7 +819,7 @@ where self.start += 1; // turn key into pair let pair = match key { - Ok(k) => match self.keymap.get_from_key(self.storage, &k) { + Ok(k) => match self.keymap.get_from_deserialized_key(self.storage, &k) { Ok(internal_item) => match internal_item.get_item() { Ok(item) => Ok((k, item)), Err(e) => Err(e), @@ -845,7 +851,7 @@ where impl<'a, K, T, Ser> DoubleEndedIterator for KeyItemIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { @@ -875,7 +881,7 @@ where } // turn key into pair let pair = match key { - Ok(k) => match self.keymap.get_from_key(self.storage, &k) { + Ok(k) => match self.keymap.get_from_deserialized_key(self.storage, &k) { Ok(internal_item) => match internal_item.get_item() { Ok(item) => Ok((k, item)), Err(e) => Err(e), @@ -902,7 +908,7 @@ where // This enables writing `.iter().skip(n).rev()` impl<'a, K, T, Ser> ExactSizeIterator for KeyItemIter<'a, K, T, Ser> where - K: Serialize + DeserializeOwned, + K: Serialize + KeyDeserialize, T: Serialize + DeserializeOwned, Ser: Serde, { diff --git a/packages/storage/src/lib.rs b/packages/storage/src/lib.rs index c6b0f14..b5d5b91 100644 --- a/packages/storage/src/lib.rs +++ b/packages/storage/src/lib.rs @@ -5,6 +5,7 @@ pub mod deque_store; pub mod item; pub mod keymap; pub mod keyset; +pub mod de; pub use append_store::AppendStore; pub use deque_store::DequeStore;