diff --git a/src/named_type.rs b/src/named_type.rs index e12a9e2..a8e4649 100644 --- a/src/named_type.rs +++ b/src/named_type.rs @@ -19,12 +19,7 @@ pub(super) fn expand_named_type(input: DeriveInput, impl_arithmetic: bool) -> To }; match &group { - UnderlyingType::Int => { - ast.extend(implement_basic_primitive(name, value_type)); - ast.extend(implement_min_max(name, value_type)); - ast.extend(implement_hash(name)); - } - UnderlyingType::UInt => { + UnderlyingType::Int | UnderlyingType::UInt => { ast.extend(implement_basic_primitive(name, value_type)); ast.extend(implement_min_max(name, value_type)); ast.extend(implement_hash(name)); @@ -51,17 +46,13 @@ pub(super) fn expand_named_type(input: DeriveInput, impl_arithmetic: bool) -> To if impl_arithmetic { match &group { - UnderlyingType::Int => { + UnderlyingType::Int | UnderlyingType::Float => { ast.extend(implement_arithmetic(name)); ast.extend(implement_negate(name)); } UnderlyingType::UInt => { ast.extend(implement_arithmetic(name)); } - UnderlyingType::Float => { - ast.extend(implement_arithmetic(name)); - ast.extend(implement_negate(name)); - } _ => panic!("Non-arithmetic type {value_type}"), } } diff --git a/tests/named_type.rs b/tests/named_type.rs index ad74a3a..8deaa71 100644 --- a/tests/named_type.rs +++ b/tests/named_type.rs @@ -77,7 +77,7 @@ mod tests { } #[test] - fn test_arithmetic() { + fn test_int_arithmetic() { #[derive(NamedNumeric)] struct Second(i32); @@ -98,14 +98,44 @@ mod tests { assert_eq!(y, Second(6)); y /= x; assert_eq!(y, Second(3)); + + let z = Second(2); + + assert_eq!(-z, Second(-2)); + + assert_eq!(Second::max().value(), i32::MAX); + assert_eq!(Second::min().value(), i32::MIN); + } + + #[test] + fn test_float_arithmetic() { #[derive(NamedNumeric)] - struct Minute(i32); - let z = Minute(2); + struct Second(f64); + + let x = Second(2.5); + let mut y = Second(3.5); + assert_eq!(y + x, Second(6.0)); + assert_eq!(y - x, Second(1.0)); + assert_eq!(y * x, Second(8.75)); + assert_eq!(y / x, Second(1.4)); + assert!(x < y); + assert!(y >= x); + + y += x; + assert_eq!(y, Second(6.0)); + y -= x; + assert_eq!(y, Second(3.5)); + y *= x; + assert_eq!(y, Second(8.75)); + y /= x; + assert_eq!(y, Second(3.5)); + + let z = Second(2.0); - assert_eq!(-z, Minute(-2)); + assert_eq!(-z, Second(-2.0)); - assert_eq!(Minute::max().value(), i32::MAX); - assert_eq!(Minute::min().value(), i32::MIN); + assert_eq!(Second::max().value(), f64::MAX); + assert_eq!(Second::min().value(), f64::MIN); } #[test] @@ -125,7 +155,7 @@ mod tests { } #[test] - fn test_float() { + fn test_float_nan() { #[derive(NamedType)] struct Meter(f64); @@ -139,7 +169,7 @@ mod tests { } #[test] - fn test_bool() { + fn test_bool_negate() { #[derive(NamedType)] struct IsTrue(bool); let is_true = IsTrue(true);