diff --git a/libsql/src/de.rs b/libsql/src/de.rs index 44f231c134..7fde70d6de 100644 --- a/libsql/src/de.rs +++ b/libsql/src/de.rs @@ -1,7 +1,7 @@ //! Deserialization utilities. use crate::{Row, Value}; -use serde::de::{value::Error as DeError, Error, IntoDeserializer, MapAccess, Visitor}; +use serde::de::{value::Error as DeError, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor}; use serde::{Deserialize, Deserializer}; struct RowDeserializer<'de> { @@ -15,15 +15,12 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> { where V: Visitor<'de>, { - Err(DeError::custom("Expects a struct")) + Err(DeError::custom( + "Expects a map, newtype, sequence, struct, or tuple", + )) } - fn deserialize_struct( - self, - _name: &'static str, - _fields: &'static [&'static str], - visitor: V, - ) -> Result + fn deserialize_map(self, visitor: V) -> Result where V: Visitor<'de>, { @@ -73,10 +70,83 @@ impl<'de> Deserializer<'de> for RowDeserializer<'de> { }) } + fn deserialize_struct( + self, + _name: &'static str, + _fields: &'static [&'static str], + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_map(visitor) + } + + fn deserialize_seq(self, visitor: V) -> Result + where + V: Visitor<'de>, + { + struct RowSeqAccess<'a> { + row: &'a Row, + idx: std::ops::Range, + } + + impl<'de> SeqAccess<'de> for RowSeqAccess<'de> { + type Error = DeError; + + fn next_element_seed(&mut self, seed: T) -> Result, Self::Error> + where + T: serde::de::DeserializeSeed<'de>, + { + match self.idx.next() { + None => Ok(None), + Some(i) => { + let value = self.row.get_value(i as i32).map_err(DeError::custom)?; + seed.deserialize(value.into_deserializer()).map(Some) + } + } + } + } + + visitor.visit_seq(RowSeqAccess { + row: self.row, + idx: 0..(self.row.column_count() as usize), + }) + } + + fn deserialize_tuple(self, _len: usize, visitor: V) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + fn deserialize_tuple_struct( + self, + _name: &'static str, + _len: usize, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + self.deserialize_seq(visitor) + } + + fn deserialize_newtype_struct( + self, + _name: &'static str, + visitor: V, + ) -> Result + where + V: Visitor<'de>, + { + visitor.visit_newtype_struct(self) + } + serde::forward_to_deserialize_any! { bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string - bytes byte_buf option unit unit_struct newtype_struct seq tuple - tuple_struct map enum identifier ignored_any + bytes byte_buf option unit unit_struct enum identifier ignored_any } } diff --git a/libsql/tests/integration_tests.rs b/libsql/tests/integration_tests.rs index 57addab948..71c70bee18 100644 --- a/libsql/tests/integration_tests.rs +++ b/libsql/tests/integration_tests.rs @@ -643,9 +643,11 @@ async fn deserialize_row() { .await .unwrap(); + use std::collections::HashMap; + use serde::Deserialize; - #[derive(Deserialize, Debug)] + #[derive(Deserialize, Debug, PartialEq)] struct Data { id: i64, name: String, @@ -684,6 +686,51 @@ async fn deserialize_row() { assert_eq!(data.none, None); assert_eq!(data.status, Status::Draft); assert_eq!(data.wrapper, Wrapper(Status::Published)); + + #[derive(Deserialize, Debug)] + struct Newtype(Data); + let newtype: Newtype = libsql::de::from_row(&row).unwrap(); + assert_eq!(newtype.0, data); + + let tuple: (i64, String, f64, Vec, Option, Status, Wrapper) = + libsql::de::from_row(&row).unwrap(); + assert_eq!( + tuple, + ( + 123, + "potato".to_string(), + 42.0, + vec![0xde, 0xad, 0xbe, 0xef], + None, + Status::Draft, + Wrapper(Status::Published) + ) + ); + + let row2 = conn + .query("SELECT name, status, wrapper FROM users", ()) + .await + .unwrap() + .next() + .await + .unwrap() + .unwrap(); + let arr: Vec = libsql::de::from_row(&row2).unwrap(); + assert_eq!(arr, vec!["potato", "Draft", "Published"]); + + let map: HashMap = libsql::de::from_row(&row2).unwrap(); + assert_eq!( + map, + HashMap::from_iter( + [ + ("name", "potato"), + ("status", "Draft"), + ("wrapper", "Published"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + ) + ); } #[tokio::test]