From 324804f09de09693384c2fd71a3c3f7b371beec1 Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Thu, 30 Jan 2025 21:28:54 -0600 Subject: [PATCH 01/12] change to trait --- sci-rs/src/special/combinatorics.rs | 138 +++++++++++++++++----------- 1 file changed, 82 insertions(+), 56 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 9702a79e..f31945a6 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -1,24 +1,70 @@ use nalgebra::min; use num_traits::{FromPrimitive, PrimInt}; -/// The number of combinations of `n` taken `k` at a time. -/// -/// This is also known as $n$ choose $k$ and is generally given by the formula -/// $$ -/// \begin{pmatrix} -/// n \\\\ k -/// \end{pmatrix} = \frac{n!}{k!(n-k)!} -/// $$ -/// -/// # Examples -/// ``` -/// use sci_rs::special::comb; -/// assert_eq!(comb(5, 2), 10); -/// ``` -/// -/// # Notes -/// When `n < 0` or `k < 0` or `n < k`, then `0` is returned. -pub fn comb(n: Int, k: Int) -> Int +/// Various combinatorics functions for integer types. +pub trait Comb { + /// The number of combinations of `n` taken `k` at a time. + /// + /// This is also known as $n$ choose $k$ and is generally given by the formula + /// $$ + /// \begin{pmatrix} + /// n \\\\ k + /// \end{pmatrix} = \frac{n!}{k!(n-k)!} + /// $$ + /// + /// # Examples + /// ``` + /// use sci_rs::special::Comb; + /// assert_eq!(5_i32.comb(2), 10); + /// ``` + /// + /// # Notes + /// When `n < 0` or `k < 0` or `n < k`, then `0` is returned. + fn comb(self, k: Self) -> Self; + + /// Number of combinations with repetition. + /// + /// The number of combinations of `n` taken `k` at a time with repetition. This is also known as a + /// `k`-combination with repetition or `k`-multicombinations. For a more detailed explanation, see + /// the [wiki] page. + /// + /// # Examples + /// ``` + /// use sci_rs::special::Comb; + /// assert_eq!(5_i32.comb_rep(2), 15); + /// assert_eq!(10_i32.comb_rep(3), 220); + /// ``` + /// + /// # Notes + /// When `n < 0` or `k < 0` or `n < k`, then `0` is returned. + /// + /// # References + /// - [Wikipedia][wiki] + /// + /// [wiki]: https://en.wikipedia.org/wiki/Combination#Number_of_combinations_with_repetition + fn comb_rep(self, k: Self) -> Self; +} + +macro_rules! comb_primint_impl { + ($($T: ty)*) => ($( + impl Comb for $T { + #[inline(always)] + fn comb(self, k: Self) -> Self { + primint_comb(self, k) + } + + #[inline(always)] + fn comb_rep(self, k: Self) -> Self { + primint_comb_rep(self, k) + } + + } + )*) +} + +comb_primint_impl! {u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize} + +fn primint_comb(n: Int, k: Int) -> Int where Int: PrimInt + FromPrimitive, { @@ -32,31 +78,11 @@ where }) } -/// Number of combinations with repetition. -/// -/// The number of combinations of `n` taken `k` at a time with repetition. This is also known as a -/// `k`-combination with repetition or `k`-multicombinations. For a more detailed explanation, see -/// the [wiki] page. -/// -/// # Examples -/// ``` -/// use sci_rs::special::comb_rep; -/// assert_eq!(comb_rep(5, 2), 15); -/// assert_eq!(comb_rep(10, 3), 220); -/// ``` -/// -/// # Notes -/// When `n < 0` or `k < 0` or `n < k`, then `0` is returned. -/// -/// # References -/// - [Wikipedia][wiki] -/// -/// [wiki]: https://en.wikipedia.org/wiki/Combination#Number_of_combinations_with_repetition -pub fn comb_rep(n: Int, k: Int) -> Int +fn primint_comb_rep(n: Int, k: Int) -> Int where Int: PrimInt + FromPrimitive, { - comb(n + k - Int::one(), k) + primint_comb(n + k - Int::one(), k) } #[cfg(test)] @@ -76,26 +102,26 @@ mod tests { #[test] fn choose() { - assert_eq!(comb(3_u8, 1), 3); - assert_eq!(comb(3_u8, 2), 3); - assert_eq!(comb(3_u8, 3), 1); + assert_eq!(3_u8.comb(1), 3); + assert_eq!(3_u8.comb(2), 3); + assert_eq!(3_u8.comb(3), 1); const REF_VALUES_5: [i32; 7] = [1, 5, 10, 10, 5, 1, 0]; const REF_VALUES_10: [i32; 12] = [1, 10, 45, 120, 210, 252, 210, 120, 45, 10, 1, 0]; const REF_VALUES_15: [i32; 16] = [ 1, 15, 105, 455, 1365, 3003, 5005, 6435, 6435, 5005, 3003, 1365, 455, 105, 15, 1, ]; - check_values(5, &REF_VALUES_5, comb); - check_values(10, &REF_VALUES_10, comb); - check_values(15, &REF_VALUES_15, comb); + check_values(5, &REF_VALUES_5, i32::comb); + check_values(10, &REF_VALUES_10, i32::comb); + check_values(15, &REF_VALUES_15, i32::comb); } #[test] fn choose_negatives() { for n in -10..-1 { for m in -5..5 { - assert_eq!(comb(n, m), 0); - assert_eq!(comb(m, n), 0); + assert_eq!(n.comb(m), 0); + assert_eq!(m.comb(n), 0); } } } @@ -104,14 +130,14 @@ mod tests { fn choose_greater_than() { for i in 0..10 { for j in 0..i { - assert_eq!(comb(j, i), 0); + assert_eq!(j.comb(i), 0); } } } #[test] fn zero_choose_zero() { - assert_eq!(comb(0, 0), 1); + assert_eq!(0.comb(0), 1); } #[test] @@ -122,17 +148,17 @@ mod tests { 1, 10, 55, 220, 715, 2002, 5005, 11440, 24310, 48620, 92378, 167960, 293930, 497420, 817190, ]; - check_values(5, &REF_VALUES_5, comb_rep); - check_values(7, &REF_VALUES_7, comb_rep); - check_values(10, &REF_VALUES_10, comb_rep); + check_values(5, &REF_VALUES_5, i32::comb_rep); + check_values(7, &REF_VALUES_7, i32::comb_rep); + check_values(10, &REF_VALUES_10, i32::comb_rep); } #[test] fn choose_replacement_negatives() { for n in -10..-1 { for m in -5..5 { - assert_eq!(comb_rep(n, m), 0); - assert_eq!(comb_rep(m, n), 0); + assert_eq!(n.comb_rep(m), 0); + assert_eq!(m.comb_rep(n), 0); } } } @@ -141,7 +167,7 @@ mod tests { fn choose_zero_replacement() { for i in 0..1 { for j in 0..1 { - assert_eq!(comb_rep(i, j), i); + assert_eq!(i.comb_rep(j), i); } } } From 267c984388879885493485b2314c307608d5660f Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Tue, 11 Feb 2025 21:25:45 -0600 Subject: [PATCH 02/12] update comb to combinatoric name --- sci-rs/src/special/combinatorics.rs | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index f31945a6..5f05b185 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -2,7 +2,7 @@ use nalgebra::min; use num_traits::{FromPrimitive, PrimInt}; /// Various combinatorics functions for integer types. -pub trait Comb { +pub trait Combinatoric { /// The number of combinations of `n` taken `k` at a time. /// /// This is also known as $n$ choose $k$ and is generally given by the formula @@ -14,7 +14,7 @@ pub trait Comb { /// /// # Examples /// ``` - /// use sci_rs::special::Comb; + /// use sci_rs::special::Combinatoric; /// assert_eq!(5_i32.comb(2), 10); /// ``` /// @@ -30,7 +30,7 @@ pub trait Comb { /// /// # Examples /// ``` - /// use sci_rs::special::Comb; + /// use sci_rs::special::Combinatoric; /// assert_eq!(5_i32.comb_rep(2), 15); /// assert_eq!(10_i32.comb_rep(3), 220); /// ``` @@ -45,9 +45,9 @@ pub trait Comb { fn comb_rep(self, k: Self) -> Self; } -macro_rules! comb_primint_impl { +macro_rules! combinatoric_primint_impl { ($($T: ty)*) => ($( - impl Comb for $T { + impl Combinatoric for $T { #[inline(always)] fn comb(self, k: Self) -> Self { primint_comb(self, k) @@ -62,7 +62,7 @@ macro_rules! comb_primint_impl { )*) } -comb_primint_impl! {u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize} +combinatoric_primint_impl! {u8 u16 u32 u64 u128 usize i8 i16 i32 i64 i128 isize} fn primint_comb(n: Int, k: Int) -> Int where From cbbfb699f1f4874c6a4a7c7dc19d20a60ed04576 Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Tue, 11 Feb 2025 21:40:14 -0600 Subject: [PATCH 03/12] add perm --- sci-rs/src/special/combinatorics.rs | 71 +++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 5f05b185..625939b0 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -43,6 +43,24 @@ pub trait Combinatoric { /// /// [wiki]: https://en.wikipedia.org/wiki/Combination#Number_of_combinations_with_repetition fn comb_rep(self, k: Self) -> Self; + + /// Number of permutations of `n` things taken `k` at a time. + /// + /// Also known as the `k`-permutation of `n`. + /// $$ + /// \text{Perm}(n, k) = \frac{n!}{(n-k)!} + /// $$ + /// # Examples + /// ``` + /// use sci_rs::special::Combinatoric; + /// assert_eq!(5.perm(5), 120); // should be 5! + /// assert_eq!(5.perm(0), 1); + /// assert_eq!(6.perm(3), 6*5*4); + /// ``` + /// + /// # Notes + /// When `n<0` or `k<0`, the `0` is returned. + fn perm(self, k: Self) -> Self; } macro_rules! combinatoric_primint_impl { @@ -58,6 +76,11 @@ macro_rules! combinatoric_primint_impl { primint_comb_rep(self, k) } + #[inline(always)] + fn perm(self, k:Self) -> Self { + primint_perm(self, k) + } + } )*) } @@ -85,6 +108,16 @@ where primint_comb(n + k - Int::one(), k) } +fn primint_perm(n: Int, k: Int) -> Int where Int: PrimInt + FromPrimitive { + if k > n ||n < Int::zero() || k < Int::zero() { + return Int::zero(); + } + + let start = (n - k + Int::one()).to_usize().unwrap(); + let end = (n + Int::one()).to_usize().unwrap(); + (start..end).fold(Int::one(), |result, val| result * Int::from_usize(val).unwrap()) +} + #[cfg(test)] mod tests { use super::*; @@ -171,4 +204,42 @@ mod tests { } } } + + #[test] + fn perm() { + let ref_values_4 = [1, 4, 12, 24, 24, 0]; + let ref_values_7 = [1, 7, 42, 210, 840, 2520, 5040, 5040, 0]; + let ref_values_13 = [ + 1, 13, 156, 1716, 17160, 154440, 1235520, 8648640, 51891840, 259459200, 1037836800, + ]; + check_values(4, &ref_values_4, i32::perm); + check_values(7, &ref_values_7, i32::perm); + check_values(13, &ref_values_13, i32::perm); + } + + #[test] + fn perm_edge() { + assert_eq!(0.perm(0), 1); + assert_eq!(1.perm(0), 1); + assert_eq!(0.perm(1), 0); + } + + #[test] + fn perm_negative() { + for i in 0..4 { + assert_eq!((-4).perm(i), 0); + assert_eq!((-3).perm(i), 0); + assert_eq!((-3241).perm(i), 0); + } + + for i in -4..0 { + assert_eq!(4.perm(i), 0); + assert_eq!(2.perm(i), 0); + assert_eq!(2341.perm(i), 0); + assert_eq!((-2).perm(i), 0); + assert_eq!((-4).perm(i), 0); + assert_eq!((-5).perm(i), 0); + assert_eq!((-3241).perm(i), 0); + } + } } From a57fa68f2fab8afd9bbc315ced9fcda5cf2030a4 Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Tue, 11 Feb 2025 22:59:58 -0600 Subject: [PATCH 04/12] update tests --- sci-rs/src/special/combinatorics.rs | 232 ++++++++++++++++++---------- 1 file changed, 149 insertions(+), 83 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 625939b0..63ce523e 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -59,8 +59,37 @@ pub trait Combinatoric { /// ``` /// /// # Notes - /// When `n<0` or `k<0`, the `0` is returned. + /// When `n<0` or `k<0`, the `0` is returned. fn perm(self, k: Self) -> Self; + + /// Stirling number of the second kind. + /// + /// These count the number of ways to partition a set of `n` elements into `k` non-empty + /// subsets. These are often called `n` subset `k` and denoted as either + /// $$ + /// S(n,k) + /// $$ + /// or + /// $$ + /// \begin{Bmatrix} + /// n \\ k + /// \end{Bmatrix} + /// $$ + /// See the [wiki] page for more details. + /// + /// # Examples + /// ``` + /// use sci_rs::special::Combinatoric; + /// assert_eq!(3.stirling2(2), 3); + /// assert_eq!(0.stirling2(0), 1); + /// assert_eq!(4.stirling2(3), 6); + /// ``` + /// + /// # References + /// - [Wikipedia][wiki] + /// + /// [wiki]: https://en.wikipedia.org/wiki/Stirling_numbers_of_the_second_kind + fn stirling2(self, k: Self) -> Self; } macro_rules! combinatoric_primint_impl { @@ -77,8 +106,13 @@ macro_rules! combinatoric_primint_impl { } #[inline(always)] - fn perm(self, k:Self) -> Self { - primint_perm(self, k) + fn perm(self, k: Self) -> Self { + primint_perm(self, k) + } + + #[inline(always)] + fn stirling2(self, k: Self) -> Self { + primint_stirling2(self, k) } } @@ -108,14 +142,41 @@ where primint_comb(n + k - Int::one(), k) } -fn primint_perm(n: Int, k: Int) -> Int where Int: PrimInt + FromPrimitive { - if k > n ||n < Int::zero() || k < Int::zero() { +fn primint_perm(n: Int, k: Int) -> Int +where + Int: PrimInt + FromPrimitive, +{ + if k > n || n < Int::zero() || k < Int::zero() { return Int::zero(); } let start = (n - k + Int::one()).to_usize().unwrap(); let end = (n + Int::one()).to_usize().unwrap(); - (start..end).fold(Int::one(), |result, val| result * Int::from_usize(val).unwrap()) + (start..end).fold(Int::one(), |result, val| { + result * Int::from_usize(val).unwrap() + }) +} + +fn primint_stirling2(n: Int, k: Int) -> Int +where + Int: PrimInt + FromPrimitive, +{ + if n < Int::zero() || k < Int::zero() { + return Int::zero(); + } + if k > n { + return Int::zero(); + } + + if n == k { + return Int::one(); + } + + if k == Int::zero() || n == Int::zero() { + return Int::zero(); + } + + k * primint_stirling2(n - Int::one(), k) + primint_stirling2(n - Int::one(), k - Int::one()) } #[cfg(test)] @@ -123,35 +184,40 @@ mod tests { use super::*; use core::fmt; - fn check_values(x: T, ref_values: &[T], func: fn(T, T) -> T) + fn check_values(ref_values: &[[T; 10]], func: fn(T, T) -> T) where T: PrimInt + FromPrimitive + fmt::Debug, { - for (i, &val) in ref_values.iter().enumerate() { - let i = T::from_usize(i).unwrap(); - assert_eq!(func(x, i), val); + for (n, &elements) in ref_values.iter().enumerate() { + for (k, &val) in elements.iter().enumerate() { + let n = T::from_usize(n).unwrap(); + let k = T::from_usize(k).unwrap(); + assert_eq!(func(n, k), val); + } } } #[test] - fn choose() { - assert_eq!(3_u8.comb(1), 3); - assert_eq!(3_u8.comb(2), 3); - assert_eq!(3_u8.comb(3), 1); - - const REF_VALUES_5: [i32; 7] = [1, 5, 10, 10, 5, 1, 0]; - const REF_VALUES_10: [i32; 12] = [1, 10, 45, 120, 210, 252, 210, 120, 45, 10, 1, 0]; - const REF_VALUES_15: [i32; 16] = [ - 1, 15, 105, 455, 1365, 3003, 5005, 6435, 6435, 5005, 3003, 1365, 455, 105, 15, 1, + fn comb() { + // Generated from scipy + let ref_values = [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 3, 3, 1, 0, 0, 0, 0, 0, 0], + [1, 4, 6, 4, 1, 0, 0, 0, 0, 0], + [1, 5, 10, 10, 5, 1, 0, 0, 0, 0], + [1, 6, 15, 20, 15, 6, 1, 0, 0, 0], + [1, 7, 21, 35, 35, 21, 7, 1, 0, 0], + [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], + [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], ]; - check_values(5, &REF_VALUES_5, i32::comb); - check_values(10, &REF_VALUES_10, i32::comb); - check_values(15, &REF_VALUES_15, i32::comb); + check_values(&ref_values, i32::comb); } #[test] - fn choose_negatives() { - for n in -10..-1 { + fn comb_negatives() { + for n in -10..0 { for m in -5..5 { assert_eq!(n.comb(m), 0); assert_eq!(m.comb(n), 0); @@ -160,35 +226,26 @@ mod tests { } #[test] - fn choose_greater_than() { - for i in 0..10 { - for j in 0..i { - assert_eq!(j.comb(i), 0); - } - } - } - - #[test] - fn zero_choose_zero() { - assert_eq!(0.comb(0), 1); - } - - #[test] - fn choose_replacement() { - const REF_VALUES_5: [i32; 10] = [1, 5, 15, 35, 70, 126, 210, 330, 495, 715]; - const REF_VALUES_7: [i32; 10] = [1, 7, 28, 84, 210, 462, 924, 1716, 3003, 5005]; - const REF_VALUES_10: [i32; 15] = [ - 1, 10, 55, 220, 715, 2002, 5005, 11440, 24310, 48620, 92378, 167960, 293930, 497420, - 817190, + fn comb_rep() { + let ref_values = [ + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [1, 3, 6, 10, 15, 21, 28, 36, 45, 55], + [1, 4, 10, 20, 35, 56, 84, 120, 165, 220], + [1, 5, 15, 35, 70, 126, 210, 330, 495, 715], + [1, 6, 21, 56, 126, 252, 462, 792, 1287, 2002], + [1, 7, 28, 84, 210, 462, 924, 1716, 3003, 5005], + [1, 8, 36, 120, 330, 792, 1716, 3432, 6435, 11440], + [1, 9, 45, 165, 495, 1287, 3003, 6435, 12870, 24310], ]; - check_values(5, &REF_VALUES_5, i32::comb_rep); - check_values(7, &REF_VALUES_7, i32::comb_rep); - check_values(10, &REF_VALUES_10, i32::comb_rep); + + check_values(&ref_values, i32::comb_rep); } #[test] - fn choose_replacement_negatives() { - for n in -10..-1 { + fn comb_rep_negatives() { + for n in -4..0 { for m in -5..5 { assert_eq!(n.comb_rep(m), 0); assert_eq!(m.comb_rep(n), 0); @@ -196,50 +253,59 @@ mod tests { } } - #[test] - fn choose_zero_replacement() { - for i in 0..1 { - for j in 0..1 { - assert_eq!(i.comb_rep(j), i); - } - } - } - #[test] fn perm() { - let ref_values_4 = [1, 4, 12, 24, 24, 0]; - let ref_values_7 = [1, 7, 42, 210, 840, 2520, 5040, 5040, 0]; - let ref_values_13 = [ - 1, 13, 156, 1716, 17160, 154440, 1235520, 8648640, 51891840, 259459200, 1037836800, + // Generated from scipy + let ref_values = [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 2, 2, 0, 0, 0, 0, 0, 0, 0], + [1, 3, 6, 6, 0, 0, 0, 0, 0, 0], + [1, 4, 12, 24, 24, 0, 0, 0, 0, 0], + [1, 5, 20, 60, 120, 120, 0, 0, 0, 0], + [1, 6, 30, 120, 360, 720, 720, 0, 0, 0], + [1, 7, 42, 210, 840, 2520, 5040, 5040, 0, 0], + [1, 8, 56, 336, 1680, 6720, 20160, 40320, 40320, 0], + [1, 9, 72, 504, 3024, 15120, 60480, 181440, 362880, 362880], ]; - check_values(4, &ref_values_4, i32::perm); - check_values(7, &ref_values_7, i32::perm); - check_values(13, &ref_values_13, i32::perm); + check_values(&ref_values, i32::perm); } - #[test] - fn perm_edge() { - assert_eq!(0.perm(0), 1); - assert_eq!(1.perm(0), 1); - assert_eq!(0.perm(1), 0); - } - #[test] fn perm_negative() { - for i in 0..4 { - assert_eq!((-4).perm(i), 0); - assert_eq!((-3).perm(i), 0); - assert_eq!((-3241).perm(i), 0); + for i in -4..0 { + for j in -5..5 { + assert_eq!(i.perm(j), 0); + assert_eq!(j.perm(i), 0); + } } + } + #[test] + fn stirling2() { + // Generated from scipy + let ref_values = [ + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 3, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 7, 6, 1, 0, 0, 0, 0, 0], + [0, 1, 15, 25, 10, 1, 0, 0, 0, 0], + [0, 1, 31, 90, 65, 15, 1, 0, 0, 0], + [0, 1, 63, 301, 350, 140, 21, 1, 0, 0], + [0, 1, 127, 966, 1701, 1050, 266, 28, 1, 0], + [0, 1, 255, 3025, 7770, 6951, 2646, 462, 36, 1], + ]; + check_values(&ref_values, i32::stirling2); + } + + #[test] + fn stirling2_negative() { for i in -4..0 { - assert_eq!(4.perm(i), 0); - assert_eq!(2.perm(i), 0); - assert_eq!(2341.perm(i), 0); - assert_eq!((-2).perm(i), 0); - assert_eq!((-4).perm(i), 0); - assert_eq!((-5).perm(i), 0); - assert_eq!((-3241).perm(i), 0); + for j in -5..5 { + assert_eq!(i.stirling2(j), 0); + assert_eq!(j.stirling2(i), 0); + } } } } From d0e7085e9ab7c3e9bb63cae7d889e65621978ecb Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Tue, 11 Feb 2025 23:12:18 -0600 Subject: [PATCH 05/12] update tests and fix primint_comb_rep for u32 --- sci-rs/src/special/combinatorics.rs | 33 ++++++++++++++++++----------- 1 file changed, 21 insertions(+), 12 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 63ce523e..9949e789 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -139,6 +139,9 @@ fn primint_comb_rep(n: Int, k: Int) -> Int where Int: PrimInt + FromPrimitive, { + if n + k == Int::zero() { + return Int::zero(); + } primint_comb(n + k - Int::one(), k) } @@ -184,14 +187,16 @@ mod tests { use super::*; use core::fmt; - fn check_values(ref_values: &[[T; 10]], func: fn(T, T) -> T) + fn check_values(ref_values: &[[K;10]], func: fn(T, T) -> T) where + K: PrimInt + FromPrimitive, T: PrimInt + FromPrimitive + fmt::Debug, { for (n, &elements) in ref_values.iter().enumerate() { for (k, &val) in elements.iter().enumerate() { let n = T::from_usize(n).unwrap(); let k = T::from_usize(k).unwrap(); + let val = T::from(val).unwrap(); assert_eq!(func(n, k), val); } } @@ -201,17 +206,18 @@ mod tests { fn comb() { // Generated from scipy let ref_values = [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 2, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 3, 1, 0, 0, 0, 0, 0, 0], - [1, 4, 6, 4, 1, 0, 0, 0, 0, 0], - [1, 5, 10, 10, 5, 1, 0, 0, 0, 0], - [1, 6, 15, 20, 15, 6, 1, 0, 0, 0], - [1, 7, 21, 35, 35, 21, 7, 1, 0, 0], - [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], - [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 3, 3, 1, 0, 0, 0, 0, 0, 0], + [1, 4, 6, 4, 1, 0, 0, 0, 0, 0], + [1, 5, 10, 10, 5, 1, 0, 0, 0, 0], + [1, 6, 15, 20, 15, 6, 1, 0, 0, 0], + [1, 7, 21, 35, 35, 21, 7, 1, 0, 0], + [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], + [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], ]; + check_values(&ref_values, u32::comb); check_values(&ref_values, i32::comb); } @@ -227,6 +233,7 @@ mod tests { #[test] fn comb_rep() { + // Generated from scipy let ref_values = [ [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], @@ -239,7 +246,7 @@ mod tests { [1, 8, 36, 120, 330, 792, 1716, 3432, 6435, 11440], [1, 9, 45, 165, 495, 1287, 3003, 6435, 12870, 24310], ]; - + check_values(&ref_values, u32::comb_rep); check_values(&ref_values, i32::comb_rep); } @@ -268,6 +275,7 @@ mod tests { [1, 8, 56, 336, 1680, 6720, 20160, 40320, 40320, 0], [1, 9, 72, 504, 3024, 15120, 60480, 181440, 362880, 362880], ]; + check_values(&ref_values, u32::perm); check_values(&ref_values, i32::perm); } @@ -296,6 +304,7 @@ mod tests { [0, 1, 127, 966, 1701, 1050, 266, 28, 1, 0], [0, 1, 255, 3025, 7770, 6951, 2646, 462, 36, 1], ]; + check_values(&ref_values, u32::stirling2); check_values(&ref_values, i32::stirling2); } From 4d4e32019804ec09e29ca747ae4e46f582674924 Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Tue, 11 Feb 2025 23:18:38 -0600 Subject: [PATCH 06/12] fix cargo fmt --- sci-rs/src/special/combinatorics.rs | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 9949e789..471df6df 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -187,9 +187,9 @@ mod tests { use super::*; use core::fmt; - fn check_values(ref_values: &[[K;10]], func: fn(T, T) -> T) + fn check_values(ref_values: &[[K; 10]], func: fn(T, T) -> T) where - K: PrimInt + FromPrimitive, + K: PrimInt + FromPrimitive, T: PrimInt + FromPrimitive + fmt::Debug, { for (n, &elements) in ref_values.iter().enumerate() { @@ -206,16 +206,16 @@ mod tests { fn comb() { // Generated from scipy let ref_values = [ - [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], - [1, 2, 1, 0, 0, 0, 0, 0, 0, 0], - [1, 3, 3, 1, 0, 0, 0, 0, 0, 0], - [1, 4, 6, 4, 1, 0, 0, 0, 0, 0], - [1, 5, 10, 10, 5, 1, 0, 0, 0, 0], - [1, 6, 15, 20, 15, 6, 1, 0, 0, 0], - [1, 7, 21, 35, 35, 21, 7, 1, 0, 0], - [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], - [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], + [1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [1, 2, 1, 0, 0, 0, 0, 0, 0, 0], + [1, 3, 3, 1, 0, 0, 0, 0, 0, 0], + [1, 4, 6, 4, 1, 0, 0, 0, 0, 0], + [1, 5, 10, 10, 5, 1, 0, 0, 0, 0], + [1, 6, 15, 20, 15, 6, 1, 0, 0, 0], + [1, 7, 21, 35, 35, 21, 7, 1, 0, 0], + [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], + [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], ]; check_values(&ref_values, u32::comb); check_values(&ref_values, i32::comb); From 7514051bc02e93b84175f2fd76e0b65ea9565407 Mon Sep 17 00:00:00 2001 From: Matthew Hennefarth Date: Wed, 12 Feb 2025 09:51:03 -0600 Subject: [PATCH 07/12] add debug info if test fails --- sci-rs/src/special/combinatorics.rs | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/sci-rs/src/special/combinatorics.rs b/sci-rs/src/special/combinatorics.rs index 471df6df..68cac5c9 100644 --- a/sci-rs/src/special/combinatorics.rs +++ b/sci-rs/src/special/combinatorics.rs @@ -187,17 +187,17 @@ mod tests { use super::*; use core::fmt; - fn check_values(ref_values: &[[K; 10]], func: fn(T, T) -> T) + fn check_values(ref_values: &[[K; 10]], func: fn(T, T) -> T, func_name: &str) where K: PrimInt + FromPrimitive, - T: PrimInt + FromPrimitive + fmt::Debug, + T: PrimInt + FromPrimitive + fmt::Display + fmt::Debug, { for (n, &elements) in ref_values.iter().enumerate() { for (k, &val) in elements.iter().enumerate() { let n = T::from_usize(n).unwrap(); let k = T::from_usize(k).unwrap(); let val = T::from(val).unwrap(); - assert_eq!(func(n, k), val); + assert_eq!(func(n, k), val, "{}({}, {}) != {}", func_name, n, k, val); } } } @@ -217,8 +217,8 @@ mod tests { [1, 8, 28, 56, 70, 56, 28, 8, 1, 0], [1, 9, 36, 84, 126, 126, 84, 36, 9, 1], ]; - check_values(&ref_values, u32::comb); - check_values(&ref_values, i32::comb); + check_values(&ref_values, u32::comb, "u32::comb"); + check_values(&ref_values, i32::comb, "i32::comb"); } #[test] @@ -246,8 +246,8 @@ mod tests { [1, 8, 36, 120, 330, 792, 1716, 3432, 6435, 11440], [1, 9, 45, 165, 495, 1287, 3003, 6435, 12870, 24310], ]; - check_values(&ref_values, u32::comb_rep); - check_values(&ref_values, i32::comb_rep); + check_values(&ref_values, u32::comb_rep, "u32::comb_rep"); + check_values(&ref_values, i32::comb_rep, "i32::comb_rep"); } #[test] @@ -275,8 +275,8 @@ mod tests { [1, 8, 56, 336, 1680, 6720, 20160, 40320, 40320, 0], [1, 9, 72, 504, 3024, 15120, 60480, 181440, 362880, 362880], ]; - check_values(&ref_values, u32::perm); - check_values(&ref_values, i32::perm); + check_values(&ref_values, u32::perm, "u32::perm"); + check_values(&ref_values, i32::perm, "i32::perm"); } #[test] @@ -304,8 +304,8 @@ mod tests { [0, 1, 127, 966, 1701, 1050, 266, 28, 1, 0], [0, 1, 255, 3025, 7770, 6951, 2646, 462, 36, 1], ]; - check_values(&ref_values, u32::stirling2); - check_values(&ref_values, i32::stirling2); + check_values(&ref_values, u32::stirling2, "u32::stirling2"); + check_values(&ref_values, i32::stirling2, "i32::stirling2"); } #[test] From b2d98d987b5d785f46ae7ca2367039d3ba6beb8c Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 9 Mar 2025 19:33:31 +0800 Subject: [PATCH 08/12] Add special::xsf::chbevl Introduces the xsf namespace without tainting the special:: namespace. Introduces the chbevl function, which will be needed for functions such as i0, i1, k0, k1, rgamma and shichi. --- sci-rs/src/special/mod.rs | 4 ++++ sci-rs/src/special/xsf/chbevl.rs | 28 ++++++++++++++++++++++++++++ sci-rs/src/special/xsf/mod.rs | 2 ++ 3 files changed, 34 insertions(+) create mode 100644 sci-rs/src/special/xsf/chbevl.rs create mode 100644 sci-rs/src/special/xsf/mod.rs diff --git a/sci-rs/src/special/mod.rs b/sci-rs/src/special/mod.rs index 5cfce8c9..0f8cc627 100644 --- a/sci-rs/src/special/mod.rs +++ b/sci-rs/src/special/mod.rs @@ -2,3 +2,7 @@ mod combinatorics; pub use combinatorics::*; + +// Name is from special/xsf folder, which is Scipy has designated as X special functions (written +// in C++) that are not exposed to Python. We keep these set of functions as being crate internal. +pub(crate) mod xsf; diff --git a/sci-rs/src/special/xsf/chbevl.rs b/sci-rs/src/special/xsf/chbevl.rs new file mode 100644 index 00000000..9dc7b4f4 --- /dev/null +++ b/sci-rs/src/special/xsf/chbevl.rs @@ -0,0 +1,28 @@ +use num_traits::real::Real; + +/// This function evaluates the series y = Sum{i=0..N} coef[i] T_i(x/2) of Chebyshev polynomials Ti +/// at argument x/2. Returns 0 for empty coef. +/// +/// Coefficients are stored in the reverse order, i.e.: the zero order term is the last in the +/// array. +pub(crate) fn chbevl(x: T, coef: &[T]) -> T +where + T: Real + core::ops::Mul, +{ + // Summation over the empty set is defined to be zero. + if coef.is_empty() { + return T::zero(); + } + + let (b0, _, b2): (T, T, T) = coef.iter().fold( + ( + // Safety:: 0 len is checked above. + *unsafe { coef.first().unwrap_unchecked() }, + T::zero(), + T::zero(), + ), + |acc, &e| (x * acc.0 - acc.1 + e, acc.0, acc.1), + ); + + (b0 - b2) / unsafe { T::from(2.).unwrap_unchecked() } +} diff --git a/sci-rs/src/special/xsf/mod.rs b/sci-rs/src/special/xsf/mod.rs new file mode 100644 index 00000000..8f5747c3 --- /dev/null +++ b/sci-rs/src/special/xsf/mod.rs @@ -0,0 +1,2 @@ +mod chbevl; +pub(crate) use chbevl::*; From c8a676986e4140731a494078a4b21a80de499bae Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 9 Mar 2025 21:34:59 +0800 Subject: [PATCH 09/12] Add Bessel as a trait: i0 for f64 Implement bessel as a trait for f64 with the modified i0 bessel function first. Can easily be extended for Vec and Array as necessary later. --- sci-rs/src/special/bessel.rs | 14 +++++ sci-rs/src/special/mod.rs | 4 ++ sci-rs/src/special/xsf/i0.rs | 108 ++++++++++++++++++++++++++++++++++ sci-rs/src/special/xsf/mod.rs | 2 + 4 files changed, 128 insertions(+) create mode 100644 sci-rs/src/special/bessel.rs create mode 100644 sci-rs/src/special/xsf/i0.rs diff --git a/sci-rs/src/special/bessel.rs b/sci-rs/src/special/bessel.rs new file mode 100644 index 00000000..52ba4efe --- /dev/null +++ b/sci-rs/src/special/bessel.rs @@ -0,0 +1,14 @@ +/// All [functions located in the `Faster versions of common Bessel +/// functions.`]() +pub trait Bessel { + /// Modified Bessel function of order 0. + /// + /// ## Notes + /// * The range is partitioned into the two intervals [0, 8] and (8, infinity). + /// * [Scipy has this as a + /// ufunc](), + /// as a supposed wrapper over the Cephes routine. We try to define it over reasonable types in + /// the impl. + /// + fn i0(&self) -> Self; +} diff --git a/sci-rs/src/special/mod.rs b/sci-rs/src/special/mod.rs index 0f8cc627..ef410ae8 100644 --- a/sci-rs/src/special/mod.rs +++ b/sci-rs/src/special/mod.rs @@ -3,6 +3,10 @@ mod combinatorics; pub use combinatorics::*; +/// Adds the [Bessel] trait. +mod bessel; +pub use bessel::Bessel; + // Name is from special/xsf folder, which is Scipy has designated as X special functions (written // in C++) that are not exposed to Python. We keep these set of functions as being crate internal. pub(crate) mod xsf; diff --git a/sci-rs/src/special/xsf/i0.rs b/sci-rs/src/special/xsf/i0.rs new file mode 100644 index 00000000..cbb378df --- /dev/null +++ b/sci-rs/src/special/xsf/i0.rs @@ -0,0 +1,108 @@ +use super::super::xsf; +use super::super::Bessel; +use num_traits::Pow; + +// Chebyshev coefficients for exp(-x) I0(x) +// in the interval [0,8]. +// +// lim(x->0){ exp(-x) I0(x) } = 1. +const I0_A_F64: [f64; 30] = [ + -4.41534164647933937950E-18, + 3.33079451882223809783E-17, + -2.43127984654795469359E-16, + 1.71539128555513303061E-15, + -1.16853328779934516808E-14, + 7.67618549860493561688E-14, + -4.85644678311192946090E-13, + 2.95505266312963983461E-12, + -1.72682629144155570723E-11, + 9.67580903537323691224E-11, + -5.18979560163526290666E-10, + 2.65982372468238665035E-9, + -1.30002500998624804212E-8, + 6.04699502254191894932E-8, + -2.67079385394061173391E-7, + 1.11738753912010371815E-6, + -4.41673835845875056359E-6, + 1.64484480707288970893E-5, + -5.75419501008210370398E-5, + 1.88502885095841655729E-4, + -5.76375574538582365885E-4, + 1.63947561694133579842E-3, + -4.32430999505057594430E-3, + 1.05464603945949983183E-2, + -2.37374148058994688156E-2, + 4.93052842396707084878E-2, + -9.49010970480476444210E-2, + 1.71620901522208775349E-1, + -3.04682672343198398683E-1, + 6.76795274409476084995E-1, +]; + +// Chebyshev coefficients for exp(-x) sqrt(x) I0(x) +// in the inverted interval [8,infinity]. +// +// lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). +const I0_B_F64: [f64; 25] = [ + -7.23318048787475395456E-18, + -4.83050448594418207126E-18, + 4.46562142029675999901E-17, + 3.46122286769746109310E-17, + -2.82762398051658348494E-16, + -3.42548561967721913462E-16, + 1.77256013305652638360E-15, + 3.81168066935262242075E-15, + -9.55484669882830764870E-15, + -4.15056934728722208663E-14, + 1.54008621752140982691E-14, + 3.85277838274214270114E-13, + 7.18012445138366623367E-13, + -1.79417853150680611778E-12, + -1.32158118404477131188E-11, + -3.14991652796324136454E-11, + 1.18891471078464383424E-11, + 4.94060238822496958910E-10, + 3.39623202570838634515E-9, + 2.26666899049817806459E-8, + 2.04891858946906374183E-7, + 2.89137052083475648297E-6, + 6.88975834691682398426E-5, + 3.36911647825569408990E-3, + 8.04490411014108831608E-1, +]; + +impl Bessel for f64 { + fn i0(&self) -> Self { + let x = self.abs(); + if x <= 8. { + let y = (x / 2.) - 2.; + return x.exp() * xsf::chbevl(y, &I0_A_F64); + } + x.exp() * xsf::chbevl(32. / x - 2., &I0_B_F64) / x.sqrt() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_relative_eq; + + #[test] + fn i0_f64() { + let result: f64 = (0.).i0(); + let exp = 1.; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = (1.).i0(); + let exp = 1.2660658777520082; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 0.213.i0(); + let exp = 1.0113744522192416; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 5.0.i0(); + let exp = 27.239871823604442; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 30.546.i0(); + let exp = 1337209608661.4026; + assert_relative_eq!(result, exp, epsilon = 1e-6); + } +} diff --git a/sci-rs/src/special/xsf/mod.rs b/sci-rs/src/special/xsf/mod.rs index 8f5747c3..e259c768 100644 --- a/sci-rs/src/special/xsf/mod.rs +++ b/sci-rs/src/special/xsf/mod.rs @@ -1,2 +1,4 @@ mod chbevl; pub(crate) use chbevl::*; + +mod i0; From eb917dd7bcc84ad8a0a938d03d3db0b6c61108bc Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 9 Mar 2025 21:43:22 +0800 Subject: [PATCH 10/12] Add modified bessel i0 for f32 --- sci-rs/src/special/xsf/i0.rs | 95 +++++++++++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/sci-rs/src/special/xsf/i0.rs b/sci-rs/src/special/xsf/i0.rs index cbb378df..e3556d3c 100644 --- a/sci-rs/src/special/xsf/i0.rs +++ b/sci-rs/src/special/xsf/i0.rs @@ -1,6 +1,6 @@ use super::super::xsf; use super::super::Bessel; -use num_traits::Pow; +use num_traits::{real::Real, Pow}; // Chebyshev coefficients for exp(-x) I0(x) // in the interval [0,8]. @@ -82,6 +82,86 @@ impl Bessel for f64 { } } +// Chebyshev coefficients for exp(-x) I0(x) +// in the interval [0,8]. +// +// lim(x->0){ exp(-x) I0(x) } = 1. +const I0_A_F32: [f32; 30] = [ + -4.41534164647933937950E-18, + 3.33079451882223809783E-17, + -2.43127984654795469359E-16, + 1.71539128555513303061E-15, + -1.16853328779934516808E-14, + 7.67618549860493561688E-14, + -4.85644678311192946090E-13, + 2.95505266312963983461E-12, + -1.72682629144155570723E-11, + 9.67580903537323691224E-11, + -5.18979560163526290666E-10, + 2.65982372468238665035E-9, + -1.30002500998624804212E-8, + 6.04699502254191894932E-8, + -2.67079385394061173391E-7, + 1.11738753912010371815E-6, + -4.41673835845875056359E-6, + 1.64484480707288970893E-5, + -5.75419501008210370398E-5, + 1.88502885095841655729E-4, + -5.76375574538582365885E-4, + 1.63947561694133579842E-3, + -4.32430999505057594430E-3, + 1.05464603945949983183E-2, + -2.37374148058994688156E-2, + 4.93052842396707084878E-2, + -9.49010970480476444210E-2, + 1.71620901522208775349E-1, + -3.04682672343198398683E-1, + 6.76795274409476084995E-1, +]; + +// Chebyshev coefficients for exp(-x) sqrt(x) I0(x) +// in the inverted interval [8,infinity]. +// +// lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). +const I0_B_F32: [f32; 25] = [ + -7.23318048787475395456E-18, + -4.83050448594418207126E-18, + 4.46562142029675999901E-17, + 3.46122286769746109310E-17, + -2.82762398051658348494E-16, + -3.42548561967721913462E-16, + 1.77256013305652638360E-15, + 3.81168066935262242075E-15, + -9.55484669882830764870E-15, + -4.15056934728722208663E-14, + 1.54008621752140982691E-14, + 3.85277838274214270114E-13, + 7.18012445138366623367E-13, + -1.79417853150680611778E-12, + -1.32158118404477131188E-11, + -3.14991652796324136454E-11, + 1.18891471078464383424E-11, + 4.94060238822496958910E-10, + 3.39623202570838634515E-9, + 2.26666899049817806459E-8, + 2.04891858946906374183E-7, + 2.89137052083475648297E-6, + 6.88975834691682398426E-5, + 3.36911647825569408990E-3, + 8.04490411014108831608E-1, +]; + +impl Bessel for f32 { + fn i0(&self) -> Self { + let x = self.abs(); + if x <= 8. { + let y = (x / 2.) - 2.; + return x.exp() * xsf::chbevl(y, &I0_A_F32); + } + x.exp() * xsf::chbevl(32. / x - 2., &I0_B_F32) / x.sqrt() + } +} + #[cfg(test)] mod tests { use super::*; @@ -105,4 +185,17 @@ mod tests { let exp = 1337209608661.4026; assert_relative_eq!(result, exp, epsilon = 1e-6); } + + #[test] + fn i0_f32() { + let result: f32 = (0.).i0(); + let exp = 1.; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = (1.).i0(); + let exp = 1.2660658777520082; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = 0.213.i0(); + let exp = 1.0113744522192416; + assert_relative_eq!(result, exp, epsilon = 1e-6); + } } From 216dafab91707127c1f49e2bcc2c90536f3952f8 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 9 Mar 2025 21:58:43 +0800 Subject: [PATCH 11/12] Add exponentially scaled modified bessel i0e to Bessel trait --- sci-rs/src/special/bessel.rs | 10 +++++++ sci-rs/src/special/xsf/i0.rs | 56 ++++++++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/sci-rs/src/special/bessel.rs b/sci-rs/src/special/bessel.rs index 52ba4efe..ab53269a 100644 --- a/sci-rs/src/special/bessel.rs +++ b/sci-rs/src/special/bessel.rs @@ -11,4 +11,14 @@ pub trait Bessel { /// the impl. /// fn i0(&self) -> Self; + + /// Exponentially scaled modified Bessel function of order 0. + /// + /// ## Notes + /// * The range is partitioned into the two intervals [0, 8] and (8, infinity). + /// * [Scipy has this as a + /// ufunc](), + /// as a supposed wrapper over the Cephes routine. We try to define it over reasonable types in + /// the impl. + fn i0e(&self) -> Self; } diff --git a/sci-rs/src/special/xsf/i0.rs b/sci-rs/src/special/xsf/i0.rs index e3556d3c..eae8ef48 100644 --- a/sci-rs/src/special/xsf/i0.rs +++ b/sci-rs/src/special/xsf/i0.rs @@ -80,6 +80,15 @@ impl Bessel for f64 { } x.exp() * xsf::chbevl(32. / x - 2., &I0_B_F64) / x.sqrt() } + + fn i0e(&self) -> Self { + let x = self.abs(); + if x <= 8. { + let y = (x / 2.) - 2.; + return xsf::chbevl(y, &I0_A_F64); + } + xsf::chbevl(32. / x - 2., &I0_B_F64) / x.sqrt() + } } // Chebyshev coefficients for exp(-x) I0(x) @@ -160,6 +169,15 @@ impl Bessel for f32 { } x.exp() * xsf::chbevl(32. / x - 2., &I0_B_F32) / x.sqrt() } + + fn i0e(&self) -> Self { + let x = self.abs(); + if x <= 8. { + let y = (x / 2.) - 2.; + return xsf::chbevl(y, &I0_A_F32); + } + xsf::chbevl(32. / x - 2., &I0_B_F32) / x.sqrt() + } } #[cfg(test)] @@ -198,4 +216,42 @@ mod tests { let exp = 1.0113744522192416; assert_relative_eq!(result, exp, epsilon = 1e-6); } + + #[test] + fn i0e_f64() { + let result: f64 = (0.).i0e(); + let exp = 1.; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = (1.).i0e(); + let exp = 0.46575960759364043; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 0.213.i0e(); + let exp = 0.8173484705849442; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 5.0.i0e(); + let exp = 0.18354081260932834; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f64 = 30.546.i0e(); + let exp = 0.0724836816695565; + assert_relative_eq!(result, exp, epsilon = 1e-6); + } + + #[test] + fn i0e_f32() { + let result: f32 = (0.).i0e(); + let exp = 1.; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = (1.).i0e(); + let exp = 0.46575960759364043; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = 0.213.i0e(); + let exp = 0.8173484705849442; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = 5.0.i0e(); + let exp = 0.18354081260932834; + assert_relative_eq!(result, exp, epsilon = 1e-6); + let result: f32 = 30.546.i0e(); + let exp = 0.0724836816695565; + assert_relative_eq!(result, exp, epsilon = 1e-6); + } } From f080447484413fc148bd7f70c23919bc84346ce9 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 9 Mar 2025 22:04:14 +0800 Subject: [PATCH 12/12] Silence clippy excessive precisions in Bessel implementation We keep the extra precision to keep it easy to maintain the implementations (currently just mirrored) between the f32 and f64 versions. --- sci-rs/src/special/xsf/i0.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sci-rs/src/special/xsf/i0.rs b/sci-rs/src/special/xsf/i0.rs index eae8ef48..56ea50e7 100644 --- a/sci-rs/src/special/xsf/i0.rs +++ b/sci-rs/src/special/xsf/i0.rs @@ -6,6 +6,7 @@ use num_traits::{real::Real, Pow}; // in the interval [0,8]. // // lim(x->0){ exp(-x) I0(x) } = 1. +#[allow(clippy::excessive_precision)] const I0_A_F64: [f64; 30] = [ -4.41534164647933937950E-18, 3.33079451882223809783E-17, @@ -43,6 +44,7 @@ const I0_A_F64: [f64; 30] = [ // in the inverted interval [8,infinity]. // // lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). +#[allow(clippy::excessive_precision)] const I0_B_F64: [f64; 25] = [ -7.23318048787475395456E-18, -4.83050448594418207126E-18, @@ -95,6 +97,7 @@ impl Bessel for f64 { // in the interval [0,8]. // // lim(x->0){ exp(-x) I0(x) } = 1. +#[allow(clippy::excessive_precision)] const I0_A_F32: [f32; 30] = [ -4.41534164647933937950E-18, 3.33079451882223809783E-17, @@ -132,6 +135,7 @@ const I0_A_F32: [f32; 30] = [ // in the inverted interval [8,infinity]. // // lim(x->inf){ exp(-x) sqrt(x) I0(x) } = 1/sqrt(2pi). +#[allow(clippy::excessive_precision)] const I0_B_F32: [f32; 25] = [ -7.23318048787475395456E-18, -4.83050448594418207126E-18,