Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix roundtrip deserialization of durations #233

Merged
merged 1 commit into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 5 additions & 10 deletions lib/src/types/duration.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use neo4rs_macros::BoltStruct;
#[derive(Debug, PartialEq, Eq, Clone, BoltStruct)]
#[signature(0xB4, 0x45)]
pub struct BoltDuration {
months: BoltInteger,
days: BoltInteger,
seconds: BoltInteger,
nanoseconds: BoltInteger,
pub(crate) months: BoltInteger,
pub(crate) days: BoltInteger,
pub(crate) seconds: BoltInteger,
pub(crate) nanoseconds: BoltInteger,
}

impl BoltDuration {
Expand All @@ -31,10 +31,6 @@ impl BoltDuration {
.saturating_add(self.days.value.saturating_mul(24 * 3600))
.saturating_add(self.months.value.saturating_mul(2_629_800))
}

pub(crate) fn nanoseconds(&self) -> i64 {
self.nanoseconds.value
}
}

impl From<std::time::Duration> for BoltDuration {
Expand All @@ -53,8 +49,7 @@ impl From<std::time::Duration> for BoltDuration {
impl From<BoltDuration> for std::time::Duration {
fn from(value: BoltDuration) -> Self {
//TODO: clarify month issue
let seconds =
value.seconds.value + (value.days.value * 24 * 3600) + (value.months.value * 2_629_800);
let seconds = value.seconds();
std::time::Duration::new(seconds as u64, value.nanoseconds.value as u32)
}
}
Expand Down
14 changes: 5 additions & 9 deletions lib/src/types/serde/date_time.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ use core::fmt;
use std::{iter::Peekable, marker::PhantomData};

use serde::de::{
value::{BorrowedStrDeserializer, MapDeserializer, SeqDeserializer},
value::{BorrowedStrDeserializer, MapDeserializer},
DeserializeSeed, Error, IntoDeserializer, MapAccess, SeqAccess, Visitor,
};

use crate::types::{serde::builder::SetOnce, BoltLocalDateTime, BoltString};
use crate::{
types::{BoltDateTime, BoltDateTimeZoneId, BoltDuration, BoltInteger},
types::{
serde::builder::SetOnce, BoltDateTime, BoltDateTimeZoneId, BoltInteger, BoltLocalDateTime,
BoltString,
},
DeError,
};

Expand Down Expand Up @@ -57,12 +59,6 @@ impl BoltDateTimeZoneId {
}
}

impl BoltDuration {
pub(crate) fn seq_access(&self) -> impl SeqAccess<'_, Error = DeError> {
SeqDeserializer::new([self.seconds(), self.nanoseconds()].into_iter())
}
}

struct BoltDateTimeZoneIdAccess<'a, const N: usize>(
&'a BoltDateTimeZoneId,
Peekable<<[Fields; N] as IntoIterator>::IntoIter>,
Expand Down
152 changes: 152 additions & 0 deletions lib/src/types/serde/duration.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
use core::fmt;

use serde::de::{value::SeqDeserializer, Error, MapAccess, SeqAccess, Visitor};

use crate::{
types::{serde::builder::SetOnce, BoltDuration, BoltInteger},
DeError,
};

crate::cenum!(Fields {
Months,
Days,
Seconds,
NanoSeconds,
});

impl BoltDuration {
pub(crate) fn seq_access_bolt(&self) -> impl SeqAccess<'_, Error = DeError> {
SeqDeserializer::new(
[
self.months.value,
self.days.value,
self.seconds.value,
self.nanoseconds.value,
]
.into_iter(),
)
}
pub(crate) fn seq_access_external(&self) -> impl SeqAccess<'_, Error = DeError> {
SeqDeserializer::new([self.seconds(), self.nanoseconds.value].into_iter())
}
}

pub struct BoltDurationVisitor;

impl<'de> Visitor<'de> for BoltDurationVisitor {
type Value = BoltDuration;

fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("BoltDuration struct")
}

fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
let mut builder = DurationBuilder::default();

while let Some(key) = map.next_key::<Fields>()? {
match key {
Fields::Months => builder.months(|| map.next_value())?,
Fields::Days => builder.days(|| map.next_value())?,
Fields::Seconds => builder.seconds(|| map.next_value())?,
Fields::NanoSeconds => builder.nanoseconds(|| map.next_value())?,
}
}

builder.build()
}

fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
const FIELDS: [Fields; 4] = [
Fields::Months,
Fields::Days,
Fields::Seconds,
Fields::NanoSeconds,
];

let mut require_next = |field| {
seq.next_element()
.and_then(|value| value.ok_or_else(|| Error::missing_field(field)))
};

let mut builder = DurationBuilder::default();

for field in FIELDS {
match field {
Fields::Months => builder.months(|| require_next("months"))?,
Fields::Days => builder.days(|| require_next("days"))?,
Fields::Seconds => builder.seconds(|| require_next("seconds"))?,
Fields::NanoSeconds => builder.nanoseconds(|| require_next("nanoseconds"))?,
}
}

if seq.next_element::<serde::de::IgnoredAny>()?.is_some() {
return Err(Error::invalid_length(0, &"4"));
}

builder.build()
}
}

#[derive(Default)]
pub(crate) struct DurationBuilder {
pub(crate) months: SetOnce<BoltInteger>,
pub(crate) days: SetOnce<BoltInteger>,
pub(crate) seconds: SetOnce<BoltInteger>,
pub(crate) nanoseconds: SetOnce<BoltInteger>,
}

impl DurationBuilder {
fn months<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
self.months
.try_insert_with(f)
.map_or_else(|_| Err(Error::duplicate_field("months")), |_| Ok(()))
}

fn days<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
self.days
.try_insert_with(f)
.map_or_else(|_| Err(Error::duplicate_field("days")), |_| Ok(()))
}

fn seconds<E: Error>(&mut self, f: impl FnOnce() -> Result<BoltInteger, E>) -> Result<(), E> {
self.seconds
.try_insert_with(f)
.map_or_else(|_| Err(Error::duplicate_field("seconds")), |_| Ok(()))
}

fn nanoseconds<E: Error>(
&mut self,
f: impl FnOnce() -> Result<BoltInteger, E>,
) -> Result<(), E> {
self.nanoseconds
.try_insert_with(f)
.map_or_else(|_| Err(Error::duplicate_field("nanoseconds")), |_| Ok(()))
}

fn build<E: Error>(mut self: DurationBuilder) -> Result<BoltDuration, E> {
Ok(BoltDuration {
months: self
.months
.take()
.ok_or_else(|| Error::missing_field("months"))?,
days: self
.days
.take()
.ok_or_else(|| Error::missing_field("days"))?,
seconds: self
.seconds
.take()
.ok_or_else(|| Error::missing_field("seconds"))?,
nanoseconds: self
.nanoseconds
.take()
.ok_or_else(|| Error::missing_field("nanoseconds"))?,
})
}
}
1 change: 1 addition & 0 deletions lib/src/types/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ mod builder;
mod cenum;
mod date_time;
mod de;
mod duration;
mod element;
mod error;
mod kind;
Expand Down
29 changes: 24 additions & 5 deletions lib/src/types/serde/typ.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::{
types::{
serde::{
date_time::BoltDateTimeVisitor,
duration::BoltDurationVisitor,
element::ElementDataDeserializer,
node::BoltNodeVisitor,
path::BoltPathVisitor,
Expand Down Expand Up @@ -240,7 +241,9 @@ impl<'de> Visitor<'de> for BoltTypeVisitor {
BoltKind::Path => variant
.tuple_variant(1, BoltPathVisitor)
.map(BoltType::Path),
BoltKind::Duration => variant.tuple_variant(1, self),
BoltKind::Duration => variant
.tuple_variant(1, BoltDurationVisitor)
.map(BoltType::Duration),
BoltKind::Date => variant
.tuple_variant(1, BoltDateTimeVisitor::<BoltDate>::new())
.map(BoltType::Date),
Expand Down Expand Up @@ -328,7 +331,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
BoltType::Point3D(p) => p
.into_deserializer()
.deserialize_struct(name, fields, visitor),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
_ => self.unexpected(visitor),
}
}
Expand Down Expand Up @@ -360,7 +363,7 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
BoltType::Point3D(p) => p
.into_deserializer()
.deserialize_newtype_struct(name, visitor),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
BoltType::DateTimeZoneId(dtz) if name == "Timezone" => {
visitor.visit_newtype_struct(BorrowedStrDeserializer::new(dtz.tz_id()))
}
Expand All @@ -378,7 +381,8 @@ impl<'de> Deserializer<'de> for BoltTypeDeserializer<'de> {
}
BoltType::Point2D(p) => p.into_deserializer().deserialize_tuple(len, visitor),
BoltType::Point3D(p) => p.into_deserializer().deserialize_tuple(len, visitor),
BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access()),
BoltType::Duration(d) if len == 2 => visitor.visit_seq(d.seq_access_external()),
BoltType::Duration(d) if len == 4 => visitor.visit_seq(d.seq_access_bolt()),
BoltType::DateTimeZoneId(dtz) => visitor.visit_seq(
dtz.seq_access(
std::any::type_name::<V>()
Expand Down Expand Up @@ -879,7 +883,8 @@ impl<'de> VariantAccess<'de> for BoltEnum<'de> {
BoltType::Point3D(p) => BoltPointDeserializer::new(p).deserialize_tuple(len, visitor),
BoltType::Bytes(b) => visitor.visit_borrowed_bytes(&b.value),
BoltType::Path(p) => ElementDataDeserializer::new(p).tuple_variant(len, visitor),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access()),
BoltType::Duration(d) if len == 1 => visitor.visit_seq(d.seq_access_bolt()),
BoltType::Duration(d) => visitor.visit_seq(d.seq_access_external()),
BoltType::Date(d) => visitor.visit_map(d.map_access()),
BoltType::Time(t) => visitor.visit_map(t.map_access()),
BoltType::LocalTime(t) => visitor.visit_map(t.map_access()),
Expand Down Expand Up @@ -2007,6 +2012,20 @@ mod tests {
assert_eq!(actual, duration);
}

#[test]
fn duration_roundtrip() {
let duration = BoltDuration::from(Duration::new(42, 1337));

let bolt = BoltType::Duration(duration.clone());

let actual = bolt.to::<BoltType>().unwrap();
let BoltType::Duration(actual) = actual else {
panic!()
};

assert_eq!(actual, duration);
}

fn test_date() -> NaiveDate {
NaiveDate::from_ymd_opt(1999, 7, 14).unwrap()
}
Expand Down
34 changes: 34 additions & 0 deletions lib/tests/duration_deserialization.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use neo4rs::*;

mod container;

#[tokio::test]
async fn duration_deserialization() {
let neo4j = container::Neo4jContainer::new().await;
let graph = neo4j.graph();

let duration = std::time::Duration::new(5259600, 7);
let mut result = graph
.execute(query("RETURN $d as output").param("d", duration))
.await
.unwrap();
let row = result.next().await.unwrap().unwrap();
let d: std::time::Duration = row.get("output").unwrap();
assert_eq!(d, duration);

let mut result = graph
.execute(query("RETURN $d as output").param("d", duration))
.await
.unwrap();
let row = result.next().await.unwrap().unwrap();
let d = row.get::<BoltType>("output").unwrap();
assert_eq!(
d,
BoltType::Duration(BoltDuration::new(
0.into(),
0.into(),
5259600.into(),
7.into(),
))
);
}
Loading