From dbf32f76acd54bf56a934e59498ab7cc697af90a Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 11 Mar 2025 13:19:14 +0800 Subject: [PATCH 01/68] Add error enums The error variants here should guide the user as to the wrong use of functions, instead of taking down the entire program with a panic, which is especially necessary in an embedded environment. --- sci-rs-core/src/lib.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) create mode 100644 sci-rs-core/src/lib.rs diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs new file mode 100644 index 00000000..e9c4dc0f --- /dev/null +++ b/sci-rs-core/src/lib.rs @@ -0,0 +1,35 @@ +//! Core library for sci-rs. + +#![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(feature = "alloc")] +extern crate alloc; + +use core::{error, fmt}; + +/// Errors raised whilst running sci-rs. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Argument parsed into function were invalid. + #[cfg(feature = "alloc")] + InvalidArg { + /// The invalid arg + arg: alloc::string::String, + /// Explaining why arg is invalid. + reason: alloc::string::String, + }, + /// Two or more optional arguments passed into functions conflict. + #[cfg(feature = "alloc")] + ConfictArg { + /// Explaining what arg is invalid. + reason: alloc::string::String, + }, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + todo!() + } +} + +impl error::Error for Error {} From 66d829a73ee7c0277cd7eb1bee7d061eaecd8870 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 11 Mar 2025 13:24:10 +0800 Subject: [PATCH 02/68] Add non-allocating variants of Error enum variants This however does not specify to the end user which arguments are raising the error. Fix "Conflict" typo in Error enum variant --- sci-rs-core/src/lib.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index e9c4dc0f..14c5071d 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -18,12 +18,18 @@ pub enum Error { /// Explaining why arg is invalid. reason: alloc::string::String, }, + /// Argument parsed into function were invalid. + #[cfg(not(feature = "alloc"))] + InvalidArg, /// Two or more optional arguments passed into functions conflict. #[cfg(feature = "alloc")] - ConfictArg { + ConflictArg { /// Explaining what arg is invalid. reason: alloc::string::String, }, + /// Two or more optional arguments passed into functions conflict. + #[cfg(not(feature = "alloc"))] + ConflictArg, } impl fmt::Display for Error { From 42c29033deb334f0aadb2dc000289884ab7c3da5 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 25 Apr 2025 13:39:02 +0800 Subject: [PATCH 03/68] Add ndarray_conv to core ndarray_conv provides linear and FFT convolution for N-dimensional tensors. This commit thus introduces into core under a "num_rs" namespace, since there are many scipy functions that use numpy functions. The ConvolveMode enum is thus moved, and the relevant enum variants of ndarray_conv is also provisioned by means of `.into()`. --- sci-rs-core/Cargo.toml | 29 +++++++++++++++++++ sci-rs-core/src/lib.rs | 2 ++ sci-rs-core/src/num_rs/convolve/mod.rs | 11 +++++++ .../src/num_rs/convolve/ndarray_conv_binds.rs | 12 ++++++++ sci-rs-core/src/num_rs/mod.rs | 2 ++ sci-rs/Cargo.toml | 1 + sci-rs/src/signal/convolve.rs | 10 +------ 7 files changed, 58 insertions(+), 9 deletions(-) create mode 100644 sci-rs-core/Cargo.toml create mode 100644 sci-rs-core/src/num_rs/convolve/mod.rs create mode 100644 sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs create mode 100644 sci-rs-core/src/num_rs/mod.rs diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml new file mode 100644 index 00000000..8ff58964 --- /dev/null +++ b/sci-rs-core/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "sci-rs-core" +version = "0.0.0" +edition = "2021" +authors = ["Jacob Trueb "] +description = "Core library for sci-rs internals." +license = "MIT OR Apache-2.0" +repository = "https://github.com/qsib-cbie/sci-rs.git" +homepage = "https://github.com/qsib-cbie/sci-rs.git" +readme = "../README.md" +keywords = ["scipy", "dsp", "signal", "filter", "design"] +categories = ["science", "mathematics", "no-std", "embedded"] + + +[package.metadata.docs.rs] +all-features = true + +[features] +default = ['alloc'] + +# Allow allocating vecs, matrices, etc. +alloc = [] + +# Enable FFT and standard library features +std = ['alloc'] + +[dependencies] +ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.4.1" } diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index 0cf6a8d8..8de2ed82 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -60,3 +60,5 @@ impl fmt::Display for Error { } impl error::Error for Error {} + +pub mod num_rs; diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs new file mode 100644 index 00000000..e927424e --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -0,0 +1,11 @@ +mod ndarray_conv_binds; + +/// Convolution mode determines behavior near edges and output size +pub enum ConvolveMode { + /// Full convolution, output size is `in1.len() + in2.len() - 1` + Full, + /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` + Valid, + /// Same convolution, output size is `in1.len()` + Same, +} diff --git a/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs new file mode 100644 index 00000000..b0f38c79 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs @@ -0,0 +1,12 @@ +use super::ConvolveMode; +use ndarray_conv::ConvMode; + +impl From for ConvMode { + fn from(value: ConvolveMode) -> Self { + match value { + ConvolveMode::Full => ConvMode::Full, + ConvolveMode::Same => ConvMode::Same, + ConvolveMode::Valid => ConvMode::Valid, + } + } +} diff --git a/sci-rs-core/src/num_rs/mod.rs b/sci-rs-core/src/num_rs/mod.rs new file mode 100644 index 00000000..424dc967 --- /dev/null +++ b/sci-rs-core/src/num_rs/mod.rs @@ -0,0 +1,2 @@ +mod convolve; +pub use convolve::*; diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 697e7933..1269f678 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -36,6 +36,7 @@ lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } +sci-rs-core = { path = "../sci-rs-core" } [dev-dependencies] approx = "0.5.1" diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs index b527805e..7e0486b5 100644 --- a/sci-rs/src/signal/convolve.rs +++ b/sci-rs/src/signal/convolve.rs @@ -2,15 +2,7 @@ use nalgebra::Complex; use num_traits::{Float, FromPrimitive, Signed, Zero}; use rustfft::{FftNum, FftPlanner}; -/// Convolution mode determines behavior near edges and output size -pub enum ConvolveMode { - /// Full convolution, output size is `in1.len() + in2.len() - 1` - Full, - /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` - Valid, - /// Same convolution, output size is `in1.len()` - Same, -} +pub use sci_rs_core::num_rs::ConvolveMode; /// Performs FFT-based convolution on two slices of floating point values. /// From 54e75e48a7a696c2c30df4784d28d9e4a4f9aace Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 25 Apr 2025 17:46:03 +0800 Subject: [PATCH 04/68] Add Result "alias" to core --- sci-rs-core/src/lib.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index 14c5071d..202b7be6 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -7,6 +7,8 @@ extern crate alloc; use core::{error, fmt}; +pub type Result = core::result::Result; + /// Errors raised whilst running sci-rs. #[derive(Debug, PartialEq, Eq)] pub enum Error { From 2d1f079d65258f7ca8996ee2c52c4b5797ac5d26 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 25 Apr 2025 17:55:54 +0800 Subject: [PATCH 05/68] Add error messages Display trait is for when `main` ends and an error message is shown to the end user. --- sci-rs-core/src/lib.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index 202b7be6..0cf6a8d8 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -4,6 +4,8 @@ #[cfg(feature = "alloc")] extern crate alloc; +#[cfg(feature = "alloc")] +use alloc::format; use core::{error, fmt}; @@ -36,7 +38,24 @@ pub enum Error { impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - todo!() + write!( + f, + "{}", + match self { + #[cfg(feature = "alloc")] + Error::InvalidArg { arg, reason } => + format!("Invalid Argument on arg = {} with reason = {}", arg, reason), + #[cfg(not(feature = "alloc"))] + Error::InvalidArg => + "There were invalid arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::ConflictArg { reason } => + format!("Conflicting Arguments with reason = {}", reason), + #[cfg(not(feature = "alloc"))] + Error::ConflictArg => + "There were conflicting arguments. Reasons not shown without `alloc` feature.", + } + ) } } From 333c9924e681a890594113e3aff25f841bcd505f Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 26 Apr 2025 17:39:53 +0800 Subject: [PATCH 06/68] Add error messages arising from ndarray_conv --- sci-rs-core/src/lib.rs | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index 8de2ed82..83fb565f 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -34,6 +34,12 @@ pub enum Error { /// Two or more optional arguments passed into functions conflict. #[cfg(not(feature = "alloc"))] ConflictArg, + /// Errors raised by [ndarray_conv::Error] + #[cfg(feature = "alloc")] + Conv { reason: alloc::string::String }, + /// Errors raised by [ndarray_conv::Error] + #[cfg(not(feature = "alloc"))] + Conv, } impl fmt::Display for Error { @@ -54,6 +60,13 @@ impl fmt::Display for Error { #[cfg(not(feature = "alloc"))] Error::ConflictArg => "There were conflicting arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::Conv { reason } => format!( + "An error occurred during the convolution from ndarray_conv with reason {}.", + reason + ), + #[cfg(not(feature = "alloc"))] + Error::Conv => "An error occurred during the convolution from ndarray_conv. Reasons not shown without `alloc` feature.", } ) } From bfcbc37e92ba3961896481ff2308705311e15ad4 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 27 Apr 2025 19:10:39 +0800 Subject: [PATCH 07/68] Add linear convolution through ndarray_conv in num-rs space There will be other functions in sci-rs::signal space that uses the linear convolution. --- sci-rs-core/Cargo.toml | 1 + sci-rs-core/src/lib.rs | 2 + sci-rs-core/src/num_rs/convolve/mod.rs | 128 +++++++++++++++++++++++++ sci-rs-core/src/num_rs/mod.rs | 2 + 4 files changed, 133 insertions(+) diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml index 8ff58964..2a1e814e 100644 --- a/sci-rs-core/Cargo.toml +++ b/sci-rs-core/Cargo.toml @@ -27,3 +27,4 @@ std = ['alloc'] [dependencies] ndarray = { version = "0.16.1", default-features = false } ndarray-conv = { version = "0.4.1" } +num-traits = { version = "0.2.15", default-features = false } diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs index 83fb565f..f52e09a4 100644 --- a/sci-rs-core/src/lib.rs +++ b/sci-rs-core/src/lib.rs @@ -74,4 +74,6 @@ impl fmt::Display for Error { impl error::Error for Error {} +/// Collection of numpy-like functions for use by sci-rs. +/// Provide behaviour parity against Numpy, even if the types are not identical. pub mod num_rs; diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs index e927424e..78da2f06 100644 --- a/sci-rs-core/src/num_rs/convolve/mod.rs +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -1,5 +1,10 @@ mod ndarray_conv_binds; +use crate::{Error, Result}; +use alloc::string::ToString; +use ndarray::{Array1, ArrayView1}; +use ndarray_conv::{ConvExt, PaddingMode}; + /// Convolution mode determines behavior near edges and output size pub enum ConvolveMode { /// Full convolution, output size is `in1.len() + in2.len() - 1` @@ -9,3 +14,126 @@ pub enum ConvolveMode { /// Same convolution, output size is `in1.len()` Same, } + +/// Best effort parallel behaviour with numpy's convolve method. We take `v` as the convolution +/// kernel. +/// +/// Returns the discrete, linear convolution of two one-dimensional sequences. +/// +/// # Parameters +/// * `a` : (N,) [[array_like]]([ndarray::Array1]) +/// Signal to be (linearly) convolved. +/// * `v` : (M,) [[array_like]]([ndarray::Array1]) +/// Second one-dimensional input array. Convolution kernel by reference. +/// * `mode` : [ConvolveMode] +/// [ConvolveMode::Full]: +/// By default, mode is 'full'. This returns the convolution at each point of overlap, with an +/// output shape of (N+M-1,). At the end-points of the convolution, the signals do not overlap +/// completely, and boundary effects may be seen. +/// +/// [ConvolveMode::Same]: +/// Mode 'same' returns output of length ``max(M, N)``. Boundary effects are still visible. +/// +/// [ConvolveMode::Valid]: +/// Mode 'valid' returns output of length ``max(M, N) - min(M, N) + 1``. The convolution +/// product is only given for points where the signals overlap completely. Values outside the +/// signal boundary have no effect. +/// +/// # Panics +/// We assume that `v` is shorter than `a`. +/// +/// # Examples +/// With [ConvolveMode::Full]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![0., 1., 2.5, 4., 1.5]; +/// let result = convolve(a, (&v).into(), ConvolveMode::Full).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![1., 2.5, 4.]; +/// let result = convolve(a, (&v).into(), ConvolveMode::Same).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![2.5]; +/// let result = convolve(a, (&v).into(), ConvolveMode::Valid).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn convolve(a: Array1, v: ArrayView1, mode: ConvolveMode) -> Result> +where + T: num_traits::NumAssign + core::marker::Copy + core::fmt::Debug, +{ + // Treat v as the convolution kernel. + debug_assert!(v.len() <= a.len()); + + // Flip the convolution kernel (see [ndarray_conv#6](https://github.com/TYPEmber/ndarray-conv/issues/6)) + // waiting for ndarray_conv v0.4.2 to not require for us to flip + let v: Array1<_> = { + let mut v = v.to_vec(); + v.reverse(); + v.into() + }; + + // Convolve + a.conv(&v, mode.into(), PaddingMode::Zeros) + .map_err(|e| Error::Conv { + reason: e.to_string(), + }) +} + +#[cfg(test)] +mod linear_convolve { + use super::*; + use alloc::vec; + use ndarray::array; + + #[test] + fn full() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![0., 1., 2.5, 4., 1.5]; + let result = convolve(a, (&v).into(), ConvolveMode::Full).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn same() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![1., 2.5, 4.]; + let result = convolve(a, (&v).into(), ConvolveMode::Same).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn valid() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![2.5]; + let result = convolve(a, (&v).into(), ConvolveMode::Valid).unwrap(); + assert_eq!(result, expected); + } +} diff --git a/sci-rs-core/src/num_rs/mod.rs b/sci-rs-core/src/num_rs/mod.rs index 424dc967..e6e9679c 100644 --- a/sci-rs-core/src/num_rs/mod.rs +++ b/sci-rs-core/src/num_rs/mod.rs @@ -1,2 +1,4 @@ +#[cfg(feature = "alloc")] mod convolve; +#[cfg(feature = "alloc")] pub use convolve::*; From ea43d53495ee1c0ae692749f8b738895b7919096 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:42:05 +0800 Subject: [PATCH 08/68] Change np-convolve to use ArrayView instead of ArrayBase in the 1st arg One can use into the Make and ArrayView from Array but not the other way round, this makes the function signature abit more friendly without needing to use `.to_owned()` needlessly. --- sci-rs-core/src/num_rs/convolve/mod.rs | 28 +++++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs index 78da2f06..a4f7a79b 100644 --- a/sci-rs-core/src/num_rs/convolve/mod.rs +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -24,7 +24,7 @@ pub enum ConvolveMode { /// * `a` : (N,) [[array_like]]([ndarray::Array1]) /// Signal to be (linearly) convolved. /// * `v` : (M,) [[array_like]]([ndarray::Array1]) -/// Second one-dimensional input array. Convolution kernel by reference. +/// Second one-dimensional input array. /// * `mode` : [ConvolveMode] /// [ConvolveMode::Full]: /// By default, mode is 'full'. This returns the convolution at each point of overlap, with an @@ -52,7 +52,7 @@ pub enum ConvolveMode { /// let v = array![0., 1., 0.5]; /// /// let expected = array![0., 1., 2.5, 4., 1.5]; -/// let result = convolve(a, (&v).into(), ConvolveMode::Full).unwrap(); +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); /// assert_eq!(result, expected); /// ``` /// With [ConvolveMode::Same]: @@ -64,7 +64,7 @@ pub enum ConvolveMode { /// let v = array![0., 1., 0.5]; /// /// let expected = array![1., 2.5, 4.]; -/// let result = convolve(a, (&v).into(), ConvolveMode::Same).unwrap(); +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); /// assert_eq!(result, expected); /// ``` /// With [ConvolveMode::Same]: @@ -76,11 +76,12 @@ pub enum ConvolveMode { /// let v = array![0., 1., 0.5]; /// /// let expected = array![2.5]; -/// let result = convolve(a, (&v).into(), ConvolveMode::Valid).unwrap(); +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); /// assert_eq!(result, expected); /// ``` -pub fn convolve(a: Array1, v: ArrayView1, mode: ConvolveMode) -> Result> +pub fn convolve(a: ArrayView1, v: ArrayView1, mode: ConvolveMode) -> Result> where + // ? Debug for ndarray_conv::ConvExt::conv T: num_traits::NumAssign + core::marker::Copy + core::fmt::Debug, { // Treat v as the convolution kernel. @@ -95,10 +96,17 @@ where }; // Convolve - a.conv(&v, mode.into(), PaddingMode::Zeros) - .map_err(|e| Error::Conv { + let result = a.conv(&v, mode.into(), PaddingMode::Zeros); + #[cfg(feature = "alloc")] + { + result.map_err(|e| Error::Conv { reason: e.to_string(), }) + } + #[cfg(not(feature = "alloc"))] + { + result.map_err({ Error::Conv }) + } } #[cfg(test)] @@ -113,7 +121,7 @@ mod linear_convolve { let v = array![0., 1., 0.5]; let expected = array![0., 1., 2.5, 4., 1.5]; - let result = convolve(a, (&v).into(), ConvolveMode::Full).unwrap(); + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); assert_eq!(result, expected); } @@ -123,7 +131,7 @@ mod linear_convolve { let v = array![0., 1., 0.5]; let expected = array![1., 2.5, 4.]; - let result = convolve(a, (&v).into(), ConvolveMode::Same).unwrap(); + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); assert_eq!(result, expected); } @@ -133,7 +141,7 @@ mod linear_convolve { let v = array![0., 1., 0.5]; let expected = array![2.5]; - let result = convolve(a, (&v).into(), ConvolveMode::Valid).unwrap(); + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); assert_eq!(result, expected); } } From aee705df96c3a3b21504bfdc30aa18094ed9a38e Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 28 Apr 2025 15:53:09 +0800 Subject: [PATCH 09/68] Remove kernel vs signal debug assertion There is nothing distinguishing between kernel and signal with both arguments now being identical in type. --- sci-rs-core/src/num_rs/convolve/mod.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs index a4f7a79b..9a9dc213 100644 --- a/sci-rs-core/src/num_rs/convolve/mod.rs +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -84,9 +84,6 @@ where // ? Debug for ndarray_conv::ConvExt::conv T: num_traits::NumAssign + core::marker::Copy + core::fmt::Debug, { - // Treat v as the convolution kernel. - debug_assert!(v.len() <= a.len()); - // Flip the convolution kernel (see [ndarray_conv#6](https://github.com/TYPEmber/ndarray-conv/issues/6)) // waiting for ndarray_conv v0.4.2 to not require for us to flip let v: Array1<_> = { From 1b484a060ad397f578d135841d678a0c0aeaff6b Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 29 Apr 2025 12:30:44 +0800 Subject: [PATCH 10/68] Add sci-rs-core Cargo.toml --- sci-rs-core/Cargo.toml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 sci-rs-core/Cargo.toml diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml new file mode 100644 index 00000000..2a1e814e --- /dev/null +++ b/sci-rs-core/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "sci-rs-core" +version = "0.0.0" +edition = "2021" +authors = ["Jacob Trueb "] +description = "Core library for sci-rs internals." +license = "MIT OR Apache-2.0" +repository = "https://github.com/qsib-cbie/sci-rs.git" +homepage = "https://github.com/qsib-cbie/sci-rs.git" +readme = "../README.md" +keywords = ["scipy", "dsp", "signal", "filter", "design"] +categories = ["science", "mathematics", "no-std", "embedded"] + + +[package.metadata.docs.rs] +all-features = true + +[features] +default = ['alloc'] + +# Allow allocating vecs, matrices, etc. +alloc = [] + +# Enable FFT and standard library features +std = ['alloc'] + +[dependencies] +ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.4.1" } +num-traits = { version = "0.2.15", default-features = false } From fc251347dc02c3913a4388d211c1b135c8af0730 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 29 Apr 2025 12:38:34 +0800 Subject: [PATCH 11/68] Build sci-rs-core with features used to build sci-rs --- sci-rs/Cargo.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 1269f678..99647819 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -19,10 +19,10 @@ all-features = true default = ['alloc'] # Allow allocating vecs, matrices, etc. -alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc'] +alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc', 'sci-rs-core/alloc'] # Enable FFT and standard library features -std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc'] +std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc','sci-rs-core/std'] # Enable debug plotting through python system calls plot = ['std'] @@ -36,7 +36,7 @@ lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } -sci-rs-core = { path = "../sci-rs-core" } +sci-rs-core = { path = "../sci-rs-core", default-features = false } [dev-dependencies] approx = "0.5.1" From f21a6bcdfe7f45304083821e52dffe9a1789e5c2 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 6 Jul 2025 21:05:59 +0800 Subject: [PATCH 12/68] Update ndarray-conv to 0.5.0 in core This version of ndarray-conv accounts for both cross-correlation and convolution. --- sci-rs-core/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml index 2a1e814e..26fe7017 100644 --- a/sci-rs-core/Cargo.toml +++ b/sci-rs-core/Cargo.toml @@ -26,5 +26,5 @@ std = ['alloc'] [dependencies] ndarray = { version = "0.16.1", default-features = false } -ndarray-conv = { version = "0.4.1" } +ndarray-conv = { version = "0.5.0" } num-traits = { version = "0.2.15", default-features = false } From 182658d3fe37d436a4f9ebb64213e9a97464bd46 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 14 Jul 2025 22:17:17 +0800 Subject: [PATCH 13/68] Fix linear convolve behaviour Behavior of convolve in ndarray-conv was changed with 0.5.0. --- sci-rs-core/src/num_rs/convolve/mod.rs | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs index 9a9dc213..2d80c7f7 100644 --- a/sci-rs-core/src/num_rs/convolve/mod.rs +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -81,17 +81,8 @@ pub enum ConvolveMode { /// ``` pub fn convolve(a: ArrayView1, v: ArrayView1, mode: ConvolveMode) -> Result> where - // ? Debug for ndarray_conv::ConvExt::conv - T: num_traits::NumAssign + core::marker::Copy + core::fmt::Debug, + T: num_traits::NumAssign + core::marker::Copy, { - // Flip the convolution kernel (see [ndarray_conv#6](https://github.com/TYPEmber/ndarray-conv/issues/6)) - // waiting for ndarray_conv v0.4.2 to not require for us to flip - let v: Array1<_> = { - let mut v = v.to_vec(); - v.reverse(); - v.into() - }; - // Convolve let result = a.conv(&v, mode.into(), PaddingMode::Zeros); #[cfg(feature = "alloc")] From 54fc05fa253e05437df0da0fffae75d977947077 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 30 Apr 2025 11:28:42 +0800 Subject: [PATCH 14/68] Add initial lfilter implementation for 1d This is a squash of 3 commits: 1. Failing unit test with a panic atm, more work needs to be done. 2. Further lfilter fix 3. Fix lfilter for 1D signal Had to step into python to actually see what the slicing was doing... --- sci-rs/src/signal/filter/lfilter.rs | 178 ++++++++++++++++++++++++++++ sci-rs/src/signal/filter/mod.rs | 8 +- 2 files changed, 185 insertions(+), 1 deletion(-) create mode 100644 sci-rs/src/signal/filter/lfilter.rs diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs new file mode 100644 index 00000000..440b8d49 --- /dev/null +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -0,0 +1,178 @@ +use alloc::vec::Vec; +use core::marker::Copy; +use ndarray::{ + Array, Array1, ArrayBase, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, IntoDimension, + Ix, IxDyn, RemoveAxis, SliceInfo, SliceInfoElem, +}; +use num_traits::{FromPrimitive, Num, NumAssign}; + +/// /// Internal function for obtaining length of all axis as array from input from input. +/// +/// This is almost the same as `a.shape()`, but is a array [T; N] instead of a Vec. +/// +/// # Parameters +/// `a`: Array whose shape is needed as a slice. +fn ndarray_ndim_as_array<'a, S, T, const N: usize>(a: &ArrayBase>) -> [Ix; N] +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + T: FromPrimitive, + S: Data + 'a, +{ + let mut tmp = [0; N]; + (0..N).for_each(|axis| tmp[axis] = a.len_of(Axis(axis))); + tmp +} + +/// Filter data along one-dimension with an IIR or FIR filter. +/// +/// Filter a data sequence, `x`, using a digital filter. This works for many +/// fundamental data types (including Object type). The filter is a direct +/// form II transposed implementation of the standard difference equation +/// (see Notes). +/// +/// The function [super::sosfilt] (and filter design using ``output='sos'``) should be +/// preferred over `lfilter` for most filtering tasks, as second-order sections +/// have fewer numerical problems. +/// +/// ## Parameters +/// * `b` : array_like +/// The numerator coefficient vector in a 1-D sequence. +/// * `a` : array_like +/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` +/// is not 1, then both `a` and `b` are normalized by ``a[0]``. +/// * `x` : array_like +/// An N-dimensional input array. +/// * `axis`: Option +/// Default to `-1` if `None`. +/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. +/// * `zi`: array_like +/// Currently not implemented. +/// Initial conditions for filter delays. It is a vector +/// (or array of vectors for an N-dimensional input) of length +/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then +/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi] for more information. +/// +/// ## Returns +/// * `y` : array +/// The output of the digital filter. +/// * `zf` : array, optional +/// If `zi` is None, this is not returned, otherwise, `zf` holds the +/// final filter delay values. +/// +/// # See Also +/// * [super::lfilter_zi] +/// +/// # Notes +/// +/// # Examples +/// On a 1-dimensional signal: +/// ``` +/// use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; +/// use sci_rs::signal::filter::lfilter; +/// +/// let b = array![5., 4., 1., 2.]; +/// let a = array![1.]; +/// let x = array![1., 2., 3., 4., 3., 5., 6.]; +/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None); +/// +/// assert_eq!(result.len(), expected.len()); +/// result.into_iter().zip(expected).for_each(|(r, e)| { +/// assert_eq!(r, e); +/// }) +/// ``` +/// +/// # Panics +/// Currently yet to implement for `zi = Some(...)`, nor for `a.len() > 1`. +/// Panics if axis is out or range. +// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of +// documentation, both lfilter_zi and this should eventually support NDArray. +#[cfg(feature = "alloc")] +pub fn lfilter<'a, T, S, const N: usize>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase>, + axis: Option, + zi: Option>, +) -> (Array>, Option>) +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a, +{ + if a.len() > 1 { + unimplemented!() + }; + if zi.is_some() { + unimplemented!() + }; + + // We make a best effort to convert into appropriate axis object. + let (axis, axis_inner): (Axis, usize) = { + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + (Axis(axis_inner as usize), axis_inner as usize) + } else { + let axis_inner = (x.ndim() as isize + axis_inner) as usize; + (Axis(axis_inner), axis_inner) + } + }; + + let b: Array1 = b.mapv(|bi| bi / a[0]); + + let (out_dim, out_dim_inner): (Dim<_>, [Ix; N]) = { + let mut tmp: [Ix; N] = ndarray_ndim_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slicer + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full).unwrap(); + let out_full_slice: ArrayView1 = out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .reborrow(); + out_full_slice.assign_to(&mut out_slice); + }); + + (out, None) +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + + #[test] + fn one_dim_no_zi() { + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![5., 14., 24., 36., 38., 47., 61.]; + + let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None); + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_eq!(r, e); + }) + } +} diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index 805c65b7..5e5b4d29 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -13,7 +13,9 @@ pub use kalmanfilt::kalman::kalman_filter; /// pub use gaussfilt as gaussian_filter; -/// Digital IIR/FIR filter design +/// Digital IIR/FIR filter design +/// Functions located in the [`Filter design` section of +/// `scipy.signal`](https://docs.scipy.org/doc/scipy/reference/signal.html#filter-design). pub mod design; mod ext; @@ -22,6 +24,8 @@ mod sosfilt; pub use ext::*; pub use sosfilt::*; +#[cfg(feature = "alloc")] +mod lfilter; #[cfg(feature = "alloc")] mod lfilter_zi; #[cfg(feature = "alloc")] @@ -31,6 +35,8 @@ mod sosfilt_zi; #[cfg(feature = "alloc")] mod sosfiltfilt; +#[cfg(feature = "alloc")] +pub use lfilter::*; #[cfg(feature = "alloc")] pub use lfilter_zi::*; #[cfg(feature = "alloc")] From 6ff02821790e91ab7ad2994557672258158c1874 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 11 Jun 2025 15:25:31 +0800 Subject: [PATCH 15/68] Change lfilter to return an error if axis argument is invalid Unfortunately, it's not quite possible to specify at compile time that `axis: isize` sastisifes the bounds wrt `N`. Also check that a[0] exists and is non-zero. --- sci-rs/src/signal/filter/lfilter.rs | 66 +++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 440b8d49..5eb3bee6 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -5,6 +5,7 @@ use ndarray::{ Ix, IxDyn, RemoveAxis, SliceInfo, SliceInfoElem, }; use num_traits::{FromPrimitive, Num, NumAssign}; +use sci_rs_core::{Error, Result}; /// /// Internal function for obtaining length of all axis as array from input from input. /// @@ -75,7 +76,7 @@ where /// let a = array![1.]; /// let x = array![1., 2., 3., 4., 3., 5., 6.]; /// let expected = array![5., 14., 24., 36., 38., 47., 61.]; -/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None); +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); /// /// assert_eq!(result.len(), expected.len()); /// result.into_iter().zip(expected).for_each(|(r, e)| { @@ -95,7 +96,7 @@ pub fn lfilter<'a, T, S, const N: usize>( x: ArrayBase>, axis: Option, zi: Option>, -) -> (Array>, Option>) +) -> Result<(Array>, Option>)> where [Ix; N]: IntoDimension>, Dim<[Ix; N]>: RemoveAxis, @@ -105,9 +106,21 @@ where if a.len() > 1 { unimplemented!() }; - if zi.is_some() { - unimplemented!() - }; + + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= N + } else { + axis.unsigned_abs() < N + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } // We make a best effort to convert into appropriate axis object. let (axis, axis_inner): (Axis, usize) = { @@ -120,6 +133,19 @@ where } }; + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } let b: Array1 = b.mapv(|bi| bi / a[0]); let (out_dim, out_dim_inner): (Dim<_>, [Ix; N]) = { @@ -152,27 +178,51 @@ where out_full_slice.assign_to(&mut out_slice); }); - (out, None) + Ok((out, None)) } #[cfg(test)] mod test { use super::*; use alloc::vec; + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; #[test] fn one_dim_no_zi() { - use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; let b = array![5., 4., 1., 2.]; let a = array![1.]; let x = array![1., 2., 3., 4., 3., 5., 6.]; let expected = array![5., 14., 24., 36., 38., 47., 61.]; - let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None); + let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { + panic!("Should not have errored") + }; assert_eq!(result.len(), expected.len()); result.into_iter().zip(expected).for_each(|(r, e)| { assert_eq!(r, e); }) } + + #[test] + fn invalid_axis() { + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + + let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(2), None); + assert!(result.is_err()); + + let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); + assert!(result.is_err()); + + let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); + assert!(result.is_ok()); + + let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); + assert!(result.is_ok()); + + let result = lfilter((&b).into(), (&a).into(), x, Some(-2), None); + assert!(result.is_err()); + } } From 391cdc77472acb06b641dcd3039a77090fa1f057 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 12 Jun 2025 16:28:21 +0800 Subject: [PATCH 16/68] Remove unnecessary reborrow Variable lifetime capture through reborrow only for subsequent `assign_to` is unnecessary. --- sci-rs/src/signal/filter/lfilter.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 5eb3bee6..366c16b5 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -165,7 +165,7 @@ where // ``` use sci_rs_core::num_rs::{convolve, ConvolveMode}; let out_full = convolve(y, (&b).into(), ConvolveMode::Full).unwrap(); - let out_full_slice: ArrayView1 = out_full + out_full .slice( SliceInfo::try_from([SliceInfoElem::Slice { start: 0, @@ -174,8 +174,7 @@ where }]) .unwrap(), ) - .reborrow(); - out_full_slice.assign_to(&mut out_slice); + .assign_to(&mut out_slice); }); Ok((out, None)) From 90b0bd4dc2d7fa80050c36904da077dbaea2e9b0 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 13 Jun 2025 16:19:30 +0800 Subject: [PATCH 17/68] Assert that lfilter arguments are not 0-dimensional arrays This is later checked in `_validate_x()` function. We instead frontload it into the body. --- sci-rs/src/signal/filter/lfilter.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 366c16b5..4fcc1bb1 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -103,6 +103,14 @@ where T: NumAssign + FromPrimitive + Copy + 'a, S: Data + 'a, { + if N == 0 { + // `_validate_x` condition - ndarray allows for 0-dimensional arrays + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "Linear filter requires at least 1-dimensional `x`.".into(), + }); + } + if a.len() > 1 { unimplemented!() }; From c6734dea13523c907010fa05117c4830ab9ac73b Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 15 Jul 2025 18:25:39 +0800 Subject: [PATCH 18/68] Remove unnecessary cfg parameter `alloc` feature flag for lfilter is already specified in the parent module. --- sci-rs/src/signal/filter/lfilter.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 4fcc1bb1..48bf4045 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -89,7 +89,6 @@ where /// Panics if axis is out or range. // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of // documentation, both lfilter_zi and this should eventually support NDArray. -#[cfg(feature = "alloc")] pub fn lfilter<'a, T, S, const N: usize>( b: ArrayView1<'a, T>, a: ArrayView1<'a, T>, From fc31076b94b59f7673a607bbee2d51c94a1a6a39 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 15 Jul 2025 21:56:13 +0800 Subject: [PATCH 19/68] Add more lfilter tests in the context of 1-dimensional input without zi Progressively make the test cases more complex so that its easier to identify. We also clarify the filter as being FIR as the `a` polynomial has been set to 1. --- sci-rs/src/signal/filter/lfilter.rs | 49 ++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 14 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 48bf4045..7fd4e0c6 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -191,23 +191,44 @@ where mod test { use super::*; use alloc::vec; + use approx::assert_relative_eq; use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; + // Tests that have a = [1.] with zi = None on input x with dim = 1. #[test] - fn one_dim_no_zi() { - let b = array![5., 4., 1., 2.]; - let a = array![1.]; - let x = array![1., 2., 3., 4., 3., 5., 6.]; - let expected = array![5., 14., 24., 36., 38., 47., 61.]; - - let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { - panic!("Should not have errored") - }; - - assert_eq!(result.len(), expected.len()); - result.into_iter().zip(expected).for_each(|(r, e)| { - assert_eq!(r, e); - }) + fn one_dim_fir_no_zi() { + { + // Tests for b.sum() > 1. + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![5., 14., 24., 36., 38., 47., 61.]; + + let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_eq!(r, e); + }) + } + { + // Tests for b[i] < 0 for some i, such that b.sum() = 1. + let b = array![0.7, -0.3, 0.6]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![0.7, 1.1, 2.1, 3.1, 2.7, 5., 4.5]; + + let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } } #[test] From f5e6579ec4d8600dd1be80cb1e551597a48f6a69 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 16 Jul 2025 19:53:16 +0800 Subject: [PATCH 20/68] Refactor Axis yielding logic This helps reduce from the flow of lfilter code. --- sci-rs/src/signal/filter/lfilter.rs | 69 ++++++++++++++++++----------- 1 file changed, 44 insertions(+), 25 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 7fd4e0c6..ddaa53d2 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -25,6 +25,49 @@ where tmp } +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +fn check_and_get_axis<'a, T, S, const N: usize>( + axis: Option, + x: &ArrayBase>, +) -> Result<(Axis, usize)> +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= N + } else { + axis.unsigned_abs() < N + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) + } else { + let axis_inner = x + .ndim() + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok((Axis(axis_inner), axis_inner)) + } +} + /// Filter data along one-dimension with an IIR or FIR filter. /// /// Filter a data sequence, `x`, using a digital filter. This works for many @@ -114,31 +157,7 @@ where unimplemented!() }; - // Before we convert into the appropriate axis object, we have to check at runtime that the - // axis value specified is within -N <= axis < N. - if axis.is_some_and(|axis| { - !(if axis < 0 { - axis.unsigned_abs() <= N - } else { - axis.unsigned_abs() < N - }) - }) { - return Err(Error::InvalidArg { - arg: "axis".into(), - reason: "index out of range.".into(), - }); - } - - // We make a best effort to convert into appropriate axis object. - let (axis, axis_inner): (Axis, usize) = { - let axis_inner: isize = axis.unwrap_or(-1); - if axis_inner >= 0 { - (Axis(axis_inner as usize), axis_inner as usize) - } else { - let axis_inner = (x.ndim() as isize + axis_inner) as usize; - (Axis(axis_inner), axis_inner) - } - }; + let (axis, axis_inner) = check_and_get_axis(axis, &x)?; if a.is_empty() { return Err(Error::InvalidArg { From a6c22f75562a33d3adb3399e314ddd2f1ebaa7dc Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 17 Jul 2025 18:43:54 +0800 Subject: [PATCH 21/68] Include linear_filter function Found within .cc file. Only provide the function signature first. This function is called when `a.len() > 1`. Function signature of lfilter is also changed to reflect what zi is expected to be from both lfilter and linear_filter. --- sci-rs/src/signal/filter/lfilter.rs | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index ddaa53d2..b8694eb1 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -137,8 +137,8 @@ pub fn lfilter<'a, T, S, const N: usize>( a: ArrayView1<'a, T>, x: ArrayBase>, axis: Option, - zi: Option>, -) -> Result<(Array>, Option>)> + zi: Option>>, +) -> Result<(Array>, Option>>)> where [Ix; N]: IntoDimension>, Dim<[Ix; N]>: RemoveAxis, @@ -154,7 +154,7 @@ where } if a.len() > 1 { - unimplemented!() + return linear_filter(b, a, x, axis, zi); }; let (axis, axis_inner) = check_and_get_axis(axis, &x)?; @@ -206,6 +206,23 @@ where Ok((out, None)) } +/// Internal function called by [lfilter] for situation a.len() > 1. +fn linear_filter<'a, T, S, const N: usize>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase>, + axis: Option, + zi: Option>>, +) -> Result<(Array>, Option>>)> +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a, +{ + todo!() +} + #[cfg(test)] mod test { use super::*; From c6d4504f8bbbbb61d47629f0926c67d2145e83cf Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 18 Jul 2025 21:33:00 +0800 Subject: [PATCH 22/68] Improve internal function name and documentation Made it clear that it was taking the shape instead of the ndim of a NDArray and making into a array `[T; N]`. --- sci-rs/src/signal/filter/lfilter.rs | 45 ++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 13 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index b8694eb1..410de1b4 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -7,13 +7,13 @@ use ndarray::{ use num_traits::{FromPrimitive, Num, NumAssign}; use sci_rs_core::{Error, Result}; -/// /// Internal function for obtaining length of all axis as array from input from input. +/// Internal function for obtaining length of all axis as array from input from input. /// /// This is almost the same as `a.shape()`, but is a array [T; N] instead of a Vec. /// /// # Parameters /// `a`: Array whose shape is needed as a slice. -fn ndarray_ndim_as_array<'a, S, T, const N: usize>(a: &ArrayBase>) -> [Ix; N] +fn ndarray_shape_as_array<'a, S, T, const N: usize>(a: &ArrayBase>) -> [Ix; N] where [Ix; N]: IntoDimension>, Dim<[Ix; N]>: RemoveAxis, @@ -172,26 +172,45 @@ where reason: "First element of a found to be zero.".into(), }); } - let b: Array1 = b.mapv(|bi| bi / a[0]); + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] - let (out_dim, out_dim_inner): (Dim<_>, [Ix; N]) = { - let mut tmp: [Ix; N] = ndarray_ndim_as_array(&x); + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; N]) = { + let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') (IntoDimension::into_dimension(tmp), tmp) }; - let mut out = ArrayBase::zeros(out_dim); - out.lanes_mut(axis) + let mut out_full: Array> = ArrayBase::zeros(out_full_dim); + out_full + .lanes_mut(axis) .into_iter() .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .for_each(|(mut out_slice, y)| { - // np.convolve uses full mode, but is eventually slices out with + .for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode // ```py - // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slicer - // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[..] - len(b) + 1] + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) // ``` use sci_rs_core::num_rs::{convolve, ConvolveMode}; - let out_full = convolve(y, (&b).into(), ConvolveMode::Full).unwrap(); - out_full + convolve(y, (&b).into(), ConvolveMode::Full) + .unwrap() + .assign_to(&mut out_full_slice); + }); + + let (out_dim, out_dim_inner) = { + let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice .slice( SliceInfo::try_from([SliceInfoElem::Slice { start: 0, From a33bf926e74d0f0af6c370cd189aeba54ee60881 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 23 Jul 2025 17:33:48 +0800 Subject: [PATCH 23/68] Partially undo refactor to optimize `zi.is_none()` path Instead of refactoring the non-zi code to follow the structure of zi code, which results in an additional heap allocation, we instead branch the code so that the scenario where zi is None does not need to undergo the additional heap allocation step. --- sci-rs/src/signal/filter/lfilter.rs | 138 +++++++++++++++++++--------- 1 file changed, 94 insertions(+), 44 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 410de1b4..453072be 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -174,55 +174,105 @@ where } let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] - let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; N]) = { - let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); - tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') - (IntoDimension::into_dimension(tmp), tmp) - }; + if let Some(mut zi) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. - let mut out_full: Array> = ArrayBase::zeros(out_full_dim); - out_full - .lanes_mut(axis) - .into_iter() - .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .for_each(|(mut out_full_slice, y)| { - // np.convolve uses full mode - // ```py - // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) - // ``` - use sci_rs_core::num_rs::{convolve, ConvolveMode}; - convolve(y, (&b).into(), ConvolveMode::Full) - .unwrap() - .assign_to(&mut out_full_slice); - }); + let zi = { + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. - let (out_dim, out_dim_inner) = { - let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); - (IntoDimension::into_dimension(tmp), tmp) - }; - let mut out = ArrayBase::zeros(out_dim); - out.lanes_mut(axis) - .into_iter() - .zip(out_full.lanes(axis)) - .for_each(|(mut out_slice, out_full_slice)| { + todo!(); + zi + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; N]) = { + let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp), tmp) + }; + + let mut out_full: Array> = ArrayBase::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full) + .unwrap() + .assign_to(&mut out_full_slice); + }); + { // ```py - // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis - // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) - // out = out_full[tuple(ind)] + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi // ``` - out_full_slice - .slice( - SliceInfo::try_from([SliceInfoElem::Slice { - start: 0, - end: Some(out_dim_inner[axis_inner] as isize), - step: 1, - }]) - .unwrap(), - ) - .assign_to(&mut out_slice); - }); + todo!() + }; - Ok((out, None)) + let (out_dim, out_dim_inner) = { + let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + Ok((out, todo!())) + } else { + let (out_dim, out_dim_inner): (Dim<_>, [Ix; N]) = { + let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full).unwrap(); + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + Ok((out, None)) + } } /// Internal function called by [lfilter] for situation a.len() > 1. From b8954e5e66663f306aeaf0ffa5d6cb938c27c4a7 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 24 Jul 2025 16:57:00 +0800 Subject: [PATCH 24/68] Propagate error from convolve instead of unwrap within lfilter If any error occurs in numpy-esque convolve, we can propagate it out instead of unwrapping it and panicking. --- sci-rs/src/signal/filter/lfilter.rs | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 453072be..9c10f38c 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -196,16 +196,15 @@ where .lanes_mut(axis) .into_iter() .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .for_each(|(mut out_full_slice, y)| { + .try_for_each(|(mut out_full_slice, y)| { // np.convolve uses full mode by default // ```py // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) // ``` use sci_rs_core::num_rs::{convolve, ConvolveMode}; - convolve(y, (&b).into(), ConvolveMode::Full) - .unwrap() - .assign_to(&mut out_full_slice); - }); + convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); + Ok(()) + })?; { // ```py // ind[axis] = slice(zi.shape[axis]) @@ -251,14 +250,14 @@ where out.lanes_mut(axis) .into_iter() .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .for_each(|(mut out_slice, y)| { + .try_for_each(|(mut out_slice, y)| { // np.convolve uses full mode, but is eventually slices out with // ```py // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] // ``` use sci_rs_core::num_rs::{convolve, ConvolveMode}; - let out_full = convolve(y, (&b).into(), ConvolveMode::Full).unwrap(); + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; out_full .slice( SliceInfo::try_from([SliceInfoElem::Slice { @@ -269,7 +268,8 @@ where .unwrap(), ) .assign_to(&mut out_slice); - }); + Ok(()) + })?; Ok((out, None)) } From f0c971a3ed482c606fc9bca9cc11fb0529fbf6ab Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 25 Jul 2025 17:04:06 +0800 Subject: [PATCH 25/68] Introduce LFilter as trait instead ndarray::SliceInfo::try_from only works for [SliceInfoElem; N], where N is up to and including 8. As such, whiltout resorting to Nightly compiler for `generic_const_exprs` feature, we will eventually implement lfilter for Array<_, Dim<[Ix; N]>> instead. Due to some other requirements, lfilter will instead be implemented up to 6 dimensions. See rust-lang:rust#76560. --- sci-rs/src/signal/filter/lfilter.rs | 88 +++++++++++++++++++++++++++++ 1 file changed, 88 insertions(+) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 9c10f38c..d1089fd9 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -68,6 +68,94 @@ where } } +/// Implement lfilter for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +pub trait LFilter +where + S: ndarray::RawData, +{ + /// Filter data `x` along one-dimension with an IIR or FIR filter. + /// + /// Filter a data sequence, `x`, using a digital filter. This works for many + /// fundamental data types (including Object type). The filter is a direct + /// form II transposed implementation of the standard difference equation + /// (see Notes). + /// + /// The function [super::sosfilt] (and filter design using ``output='sos'``) should be + /// preferred over `lfilter` for most filtering tasks, as second-order sections + /// have fewer numerical problems. + /// + /// ## Parameters + /// * `b` : array_like + /// The numerator coefficient vector in a 1-D sequence. + /// * `a` : array_like + /// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized by ``a[0]``. + /// * `x` : array_like + /// An N-dimensional input array. + /// * `axis`: Option + /// Default to `-1` if `None`. + /// Panics in accordance with [ndarray::ArrayBase::axis_iter]. + /// * `zi`: array_like + /// Currently not implemented. + /// Initial conditions for filter delays. It is a vector + /// (or array of vectors for an N-dimensional input) of length + /// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then + /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi] for more information. + /// + /// ## Returns + /// * `y` : array + /// The output of the digital filter. + /// * `zf` : array, optional + /// If `zi` is None, this is not returned, otherwise, `zf` holds the + /// final filter delay values. + /// + /// # See Also + /// * [super::lfilter_zi_dyn] + /// + /// # Notes + /// For compile time reasons, lfilter is implemented per ArrayN at the moment. + /// + /// # Examples + /// On a 1-dimensional signal: + //// ``` + //// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; + //// use sci_rs::signal::filter::LFilter; + //// + //// let b = array![5., 4., 1., 2.]; + //// let a = array![1.]; + //// let x = array![1., 2., 3., 4., 3., 5., 6.]; + //// let expected = array![5., 14., 24., 36., 38., 47., 61.]; + //// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref + //// + //// assert_eq!(result.len(), expected.len()); + //// result.into_iter().zip(expected).for_each(|(r, e)| { + //// assert_eq!(r, e); + //// }); + //// + //// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value + //// ``` + /// + /// # Panics + /// Currently yet to implement for `zi = Some(...)`, nor for `a.len() > 1`. + /// Panics if axis is out or range. + // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of + // documentation, both lfilter_zi and this should eventually support NDArray. + fn lfilter<'a, T>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result<(Array>, Option>>)> + where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a; +} + /// Filter data along one-dimension with an IIR or FIR filter. /// /// Filter a data sequence, `x`, using a digital filter. This works for many From 301b779bab2b45858acbeddad64008ef0876d280 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 28 Jul 2025 15:25:36 +0800 Subject: [PATCH 26/68] Move lfilter implementation into macro Correspondingly update all test cases. See previous commit for rationale. --- sci-rs/src/signal/filter/lfilter.rs | 384 ++++++++++++---------------- 1 file changed, 167 insertions(+), 217 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index d1089fd9..952340b6 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -119,23 +119,23 @@ where /// /// # Examples /// On a 1-dimensional signal: - //// ``` - //// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; - //// use sci_rs::signal::filter::LFilter; - //// - //// let b = array![5., 4., 1., 2.]; - //// let a = array![1.]; - //// let x = array![1., 2., 3., 4., 3., 5., 6.]; - //// let expected = array![5., 14., 24., 36., 38., 47., 61.]; - //// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref - //// - //// assert_eq!(result.len(), expected.len()); - //// result.into_iter().zip(expected).for_each(|(r, e)| { - //// assert_eq!(r, e); - //// }); - //// - //// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value - //// ``` + /// ``` + /// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; + /// use sci_rs::signal::filter::LFilter; + /// + /// let b = array![5., 4., 1., 2.]; + /// let a = array![1.]; + /// let x = array![1., 2., 3., 4., 3., 5., 6.]; + /// let expected = array![5., 14., 24., 36., 38., 47., 61.]; + /// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref + /// + /// assert_eq!(result.len(), expected.len()); + /// result.into_iter().zip(expected).for_each(|(r, e)| { + /// assert_eq!(r, e); + /// }); + /// + /// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value + /// ``` /// /// # Panics /// Currently yet to implement for `zi = Some(...)`, nor for `a.len() > 1`. @@ -156,211 +156,152 @@ where S: Data + 'a; } -/// Filter data along one-dimension with an IIR or FIR filter. -/// -/// Filter a data sequence, `x`, using a digital filter. This works for many -/// fundamental data types (including Object type). The filter is a direct -/// form II transposed implementation of the standard difference equation -/// (see Notes). -/// -/// The function [super::sosfilt] (and filter design using ``output='sos'``) should be -/// preferred over `lfilter` for most filtering tasks, as second-order sections -/// have fewer numerical problems. -/// -/// ## Parameters -/// * `b` : array_like -/// The numerator coefficient vector in a 1-D sequence. -/// * `a` : array_like -/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` -/// is not 1, then both `a` and `b` are normalized by ``a[0]``. -/// * `x` : array_like -/// An N-dimensional input array. -/// * `axis`: Option -/// Default to `-1` if `None`. -/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. -/// * `zi`: array_like -/// Currently not implemented. -/// Initial conditions for filter delays. It is a vector -/// (or array of vectors for an N-dimensional input) of length -/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then -/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi] for more information. -/// -/// ## Returns -/// * `y` : array -/// The output of the digital filter. -/// * `zf` : array, optional -/// If `zi` is None, this is not returned, otherwise, `zf` holds the -/// final filter delay values. -/// -/// # See Also -/// * [super::lfilter_zi] -/// -/// # Notes -/// -/// # Examples -/// On a 1-dimensional signal: -/// ``` -/// use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; -/// use sci_rs::signal::filter::lfilter; -/// -/// let b = array![5., 4., 1., 2.]; -/// let a = array![1.]; -/// let x = array![1., 2., 3., 4., 3., 5., 6.]; -/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; -/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); -/// -/// assert_eq!(result.len(), expected.len()); -/// result.into_iter().zip(expected).for_each(|(r, e)| { -/// assert_eq!(r, e); -/// }) -/// ``` -/// -/// # Panics -/// Currently yet to implement for `zi = Some(...)`, nor for `a.len() > 1`. -/// Panics if axis is out or range. -// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of -// documentation, both lfilter_zi and this should eventually support NDArray. -pub fn lfilter<'a, T, S, const N: usize>( - b: ArrayView1<'a, T>, - a: ArrayView1<'a, T>, - x: ArrayBase>, - axis: Option, - zi: Option>>, -) -> Result<(Array>, Option>>)> -where - [Ix; N]: IntoDimension>, - Dim<[Ix; N]>: RemoveAxis, - T: NumAssign + FromPrimitive + Copy + 'a, - S: Data + 'a, -{ - if N == 0 { - // `_validate_x` condition - ndarray allows for 0-dimensional arrays - return Err(Error::InvalidArg { - arg: "x".into(), - reason: "Linear filter requires at least 1-dimensional `x`.".into(), - }); - } +macro_rules! lfilter_for_dim { + ($N:literal) => { + impl LFilter for ArrayBase> + where + S: ndarray::RawData, + { + fn lfilter<'a, T>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result<(Array>, Option>>)> + where + [Ix; $N]: IntoDimension>, + Dim<[Ix; $N]>: RemoveAxis, + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a, + { + if a.len() > 1 { + return linear_filter(b, a, x, axis, zi); + }; - if a.len() > 1 { - return linear_filter(b, a, x, axis, zi); - }; + let (axis, axis_inner) = check_and_get_axis(axis, &x)?; - let (axis, axis_inner) = check_and_get_axis(axis, &x)?; + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] - if a.is_empty() { - return Err(Error::InvalidArg { - arg: "a".into(), - reason: - "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." - .into(), - }); - } else if a.first().unwrap().is_zero() { - return Err(Error::InvalidArg { - arg: "a".into(), - reason: "First element of a found to be zero.".into(), - }); - } - let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + if let Some(zi) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. - if let Some(mut zi) = zi { - // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None - // case. + let zi = { + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. - let zi = { - // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + todo!(); + zi + }; - todo!(); - zi - }; + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp), tmp) + }; - let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; N]) = { - let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); - tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') - (IntoDimension::into_dimension(tmp), tmp) - }; + let mut out_full: Array> = ArrayBase::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)? + .assign_to(&mut out_full_slice); + Ok(()) + })?; + { + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + todo!() + }; - let mut out_full: Array> = ArrayBase::zeros(out_full_dim); - out_full - .lanes_mut(axis) - .into_iter() - .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .try_for_each(|(mut out_full_slice, y)| { - // np.convolve uses full mode by default - // ```py - // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) - // ``` - use sci_rs_core::num_rs::{convolve, ConvolveMode}; - convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); - Ok(()) - })?; - { - // ```py - // ind[axis] = slice(zi.shape[axis]) - // out_full[tuple(ind)] += zi - // ``` - todo!() - }; + let (out_dim, out_dim_inner) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); - let (out_dim, out_dim_inner) = { - let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); - (IntoDimension::into_dimension(tmp), tmp) - }; - let mut out = ArrayBase::zeros(out_dim); - out.lanes_mut(axis) - .into_iter() - .zip(out_full.lanes(axis)) - .for_each(|(mut out_slice, out_full_slice)| { - // ```py - // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis - // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) - // out = out_full[tuple(ind)] - // ``` - out_full_slice - .slice( - SliceInfo::try_from([SliceInfoElem::Slice { - start: 0, - end: Some(out_dim_inner[axis_inner] as isize), - step: 1, - }]) - .unwrap(), - ) - .assign_to(&mut out_slice); - }); + Ok((out, todo!())) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. - Ok((out, todo!())) - } else { - let (out_dim, out_dim_inner): (Dim<_>, [Ix; N]) = { - let mut tmp: [Ix; N] = ndarray_shape_as_array(&x); - (IntoDimension::into_dimension(tmp), tmp) - }; - let mut out = ArrayBase::zeros(out_dim); + let (out_dim, out_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); - out.lanes_mut(axis) - .into_iter() - .zip(x.lanes(axis)) // Almost basically np.apply_along_axis - .try_for_each(|(mut out_slice, y)| { - // np.convolve uses full mode, but is eventually slices out with - // ```py - // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r - // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] - // ``` - use sci_rs_core::num_rs::{convolve, ConvolveMode}; - let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; - out_full - .slice( - SliceInfo::try_from([SliceInfoElem::Slice { - start: 0, - end: Some(out_dim_inner[axis_inner] as isize), - step: 1, - }]) - .unwrap(), - ) - .assign_to(&mut out_slice); - Ok(()) - })?; + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; - Ok((out, None)) - } + Ok((out, None)) + } + } + } + }; } /// Internal function called by [lfilter] for situation a.len() > 1. @@ -380,6 +321,13 @@ where todo!() } +lfilter_for_dim!(1); +lfilter_for_dim!(2); +lfilter_for_dim!(3); +lfilter_for_dim!(4); +lfilter_for_dim!(5); +lfilter_for_dim!(6); + #[cfg(test)] mod test { use super::*; @@ -397,7 +345,8 @@ mod test { let x = array![1., 2., 3., 4., 3., 5., 6.]; let expected = array![5., 14., 24., 36., 38., 47., 61.]; - let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { panic!("Should not have errored") }; @@ -413,7 +362,8 @@ mod test { let x = array![1., 2., 3., 4., 3., 5., 6.]; let expected = array![0.7, 1.1, 2.1, 3.1, 2.7, 5., 4.5]; - let Ok((result, None)) = lfilter((&b).into(), (&a).into(), x, None, None) else { + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { panic!("Should not have errored") }; @@ -430,19 +380,19 @@ mod test { let a = array![1.]; let x = array![1., 2., 3., 4., 3., 5., 6.]; - let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(2), None); + let result = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), Some(2), None); assert!(result.is_err()); - let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); assert!(result.is_err()); - let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); assert!(result.is_ok()); - let result = lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); assert!(result.is_ok()); - let result = lfilter((&b).into(), (&a).into(), x, Some(-2), None); + let result = Array1::lfilter((&b).into(), (&a).into(), x, Some(-2), None); assert!(result.is_err()); } } From 8d99c3f8590d09c7859353881031528a421398d1 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 30 Jul 2025 11:31:21 +0800 Subject: [PATCH 27/68] Add happy zi.is_some() case Currently, the as_strided trick is still not quite working. This commit also adds a corresponding unit test. --- sci-rs/src/signal/filter/lfilter.rs | 135 +++++++++++++++++++++++++--- 1 file changed, 123 insertions(+), 12 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 952340b6..4bfe6836 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -199,12 +199,49 @@ macro_rules! lfilter_for_dim { if let Some(zi) = zi { // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None // case. + let mut zi = zi.to_owned(); - let zi = { - // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. - todo!(); - zi + let mut expected_shape: [usize; $N] = x.shape().try_into().unwrap(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: [Ix; $N] = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..$N) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap?.try_into().unwrap() + }; + + zi = todo!(); }; let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { @@ -228,16 +265,29 @@ macro_rules! lfilter_for_dim { .assign_to(&mut out_full_slice); Ok(()) })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` { - // ```py - // ind[axis] = slice(zi.shape[axis]) - // out_full[tuple(ind)] += zi - // ``` - todo!() - }; + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = zi.shape()[axis_inner]; + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } let (out_dim, out_dim_inner) = { - let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + let tmp: [Ix; $N] = ndarray_shape_as_array(&x); (IntoDimension::into_dimension(tmp), tmp) }; let mut out = ArrayBase::zeros(out_dim); @@ -262,7 +312,30 @@ macro_rules! lfilter_for_dim { .assign_to(&mut out_slice); }); - Ok((out, todo!())) + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) } else { // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce // one extra heap allocation. @@ -374,6 +447,44 @@ mod test { } } + #[test] + fn one_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + None, + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + #[test] fn invalid_axis() { let b = array![5., 4., 1., 2.]; From 47be4c473259f007fff988dfc558c2d0d812a0a2 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 11 Aug 2025 17:52:23 +0800 Subject: [PATCH 28/68] Implement path that handles the unexpected shape of zi ShapeBuilder trait along with Array::as_slice allows as to replicate the stride_tricks method to create a different array view. It is not yet presently clear if as_slice_memory_order should be used instead, but the tests that are included are passing. It might be possible to remove the reallocation of zi at the start of the scope now that this is settled. --- sci-rs/src/signal/filter/lfilter.rs | 92 ++++++++++++++++++++++++++++- 1 file changed, 90 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 4bfe6836..184b6b4c 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -2,7 +2,7 @@ use alloc::vec::Vec; use core::marker::Copy; use ndarray::{ Array, Array1, ArrayBase, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, IntoDimension, - Ix, IxDyn, RemoveAxis, SliceInfo, SliceInfoElem, + Ix, IxDyn, RemoveAxis, ShapeBuilder, SliceInfo, SliceInfoElem, }; use num_traits::{FromPrimitive, Num, NumAssign}; use sci_rs_core::{Error, Result}; @@ -241,7 +241,9 @@ macro_rules! lfilter_for_dim { tmp_heap?.try_into().unwrap() }; - zi = todo!(); + zi = ArrayView::from_shape(expected_shape.strides(strides), zi.as_slice().unwrap()) + .unwrap() + .to_owned(); }; let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { @@ -474,6 +476,92 @@ mod test { panic!("Should not have errored") }; + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(0), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 3]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(1), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + assert_eq!(result.len(), expected.len()); result.into_iter().zip(expected).for_each(|(r, e)| { assert_relative_eq!(r, e, max_relative = 1e-6); From f7d2c68d3db0ff0719da203ac1d411325b0649f4 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 12 Aug 2025 16:10:37 +0800 Subject: [PATCH 29/68] Touch up lfilter documentation Update to show when lfilter is not yet implemented for. --- sci-rs/src/signal/filter/lfilter.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 184b6b4c..d748ddc7 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -82,7 +82,7 @@ where /// form II transposed implementation of the standard difference equation /// (see Notes). /// - /// The function [super::sosfilt] (and filter design using ``output='sos'``) should be + /// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be /// preferred over `lfilter` for most filtering tasks, as second-order sections /// have fewer numerical problems. /// @@ -94,7 +94,7 @@ where /// is not 1, then both `a` and `b` are normalized by ``a[0]``. /// * `x` : array_like /// An N-dimensional input array. - /// * `axis`: Option + /// * `axis`: `Option` /// Default to `-1` if `None`. /// Panics in accordance with [ndarray::ArrayBase::axis_iter]. /// * `zi`: array_like @@ -102,7 +102,7 @@ where /// Initial conditions for filter delays. It is a vector /// (or array of vectors for an N-dimensional input) of length /// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then - /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi] for more information. + /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. /// /// ## Returns /// * `y` : array @@ -138,8 +138,7 @@ where /// ``` /// /// # Panics - /// Currently yet to implement for `zi = Some(...)`, nor for `a.len() > 1`. - /// Panics if axis is out or range. + /// Currently yet to implement for `a.len() > 1`. // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of // documentation, both lfilter_zi and this should eventually support NDArray. fn lfilter<'a, T>( From c2cd592133ebdfa7270153362ad5e404e76f529a Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 13 Aug 2025 14:32:12 +0800 Subject: [PATCH 30/68] Change lfilter_zi documentation to point to new lfilter implementation We point to the trait. --- sci-rs/src/signal/filter/lfilter_zi.rs | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 6c92df5a..48296030 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -9,20 +9,17 @@ use alloc::vec; #[cfg(feature = "alloc")] use alloc::vec::Vec; +/// Construct initial conditions for [lfilter][super::lfilter::LFilter] for step response +/// steady-state. /// -/// Construct initial conditions for lfilter for step response steady-state. -/// -/// Compute an initial state `zi` for the `lfilter` function that corresponds -/// to the steady state of the step response. +/// Compute an initial state `zi` for the `lfilter` function that corresponds to the steady state +/// of the step response. /// /// A typical use of this function is to set the initial state so that the /// output of the filter starts at the same value as the first element of /// the signal to be filtered. /// -/// /// -/// -/// #[inline] pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Vec where @@ -38,7 +35,7 @@ where .expect("There must be at least one nonzero `a` coefficient.") .0; - // Mormalize to a[0] == 1 + // Normalize to a[0] == 1 let mut a = a.iter().skip(ai0).cloned().collect::>(); let mut b = b.to_vec(); let a0 = a[0]; From 396887a9da5bf33e41af69533b161a13acdd7287 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 15 Aug 2025 16:59:19 +0800 Subject: [PATCH 31/68] Remove unnecessary heap allocation of zi Reborrow is all you need. --- sci-rs/src/signal/filter/lfilter.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index d748ddc7..b32b96b1 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -195,10 +195,10 @@ macro_rules! lfilter_for_dim { } let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] - if let Some(zi) = zi { + if let Some(zii) = zi { // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None // case. - let mut zi = zi.to_owned(); + let mut zi = zii.reborrow(); // if zi.ndim != x.ndim { return Err(...) } is signature asserted. @@ -240,9 +240,8 @@ macro_rules! lfilter_for_dim { tmp_heap?.try_into().unwrap() }; - zi = ArrayView::from_shape(expected_shape.strides(strides), zi.as_slice().unwrap()) - .unwrap() - .to_owned(); + zi = ArrayView::from_shape(expected_shape.strides(strides), zii.as_slice().unwrap()) + .unwrap(); }; let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { From 4aaa73ae9286dee0024a2de0ecfda27d2873ab05 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 23 Sep 2025 13:53:22 +0800 Subject: [PATCH 32/68] Fix up docstring in lfilter Some warning that came up in `cargo doc`. --- sci-rs/src/signal/filter/lfilter.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index b32b96b1..404024d9 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -9,7 +9,7 @@ use sci_rs_core::{Error, Result}; /// Internal function for obtaining length of all axis as array from input from input. /// -/// This is almost the same as `a.shape()`, but is a array [T; N] instead of a Vec. +/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a `Vec`. /// /// # Parameters /// `a`: Array whose shape is needed as a slice. @@ -377,7 +377,7 @@ macro_rules! lfilter_for_dim { }; } -/// Internal function called by [lfilter] for situation a.len() > 1. +/// Internal function called by [LFilter::lfilter] for situation a.len() > 1. fn linear_filter<'a, T, S, const N: usize>( b: ArrayView1<'a, T>, a: ArrayView1<'a, T>, From 279cc09feb5efd92147283f1f626d3073de7c482 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 2 Oct 2025 16:40:57 +0800 Subject: [PATCH 33/68] Prepare for IxDyn compatible lfilter function Change the function signature to linear_filter first to be more permissive. --- sci-rs/src/signal/filter/lfilter.rs | 34 ++++++++++++++--------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 404024d9..a92e5703 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -1,8 +1,9 @@ use alloc::vec::Vec; use core::marker::Copy; use ndarray::{ - Array, Array1, ArrayBase, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, IntoDimension, - Ix, IxDyn, RemoveAxis, ShapeBuilder, SliceInfo, SliceInfoElem, + Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, + Dimension, IntoDimension, Ix, IxDyn, RemoveAxis, ShapeBuilder, SliceArg, SliceInfo, + SliceInfoElem, }; use num_traits::{FromPrimitive, Num, NumAssign}; use sci_rs_core::{Error, Result}; @@ -30,7 +31,7 @@ where /// # Parameters /// axis: The user-specificed axis which filter is to be applied on. /// x: The input-data whose axis object that will be manipulated against. -fn check_and_get_axis<'a, T, S, const N: usize>( +fn check_and_get_axis_st<'a, T, S, const N: usize>( axis: Option, x: &ArrayBase>, ) -> Result<(Axis, usize)> @@ -178,7 +179,7 @@ macro_rules! lfilter_for_dim { return linear_filter(b, a, x, axis, zi); }; - let (axis, axis_inner) = check_and_get_axis(axis, &x)?; + let (axis, axis_inner) = check_and_get_axis_st(axis, &x)?; if a.is_empty() { return Err(Error::InvalidArg { @@ -377,30 +378,29 @@ macro_rules! lfilter_for_dim { }; } +lfilter_for_dim!(1); +lfilter_for_dim!(2); +lfilter_for_dim!(3); +lfilter_for_dim!(4); +lfilter_for_dim!(5); +lfilter_for_dim!(6); + /// Internal function called by [LFilter::lfilter] for situation a.len() > 1. -fn linear_filter<'a, T, S, const N: usize>( +fn linear_filter<'a, T, S, D>( b: ArrayView1<'a, T>, a: ArrayView1<'a, T>, - x: ArrayBase>, + x: ArrayBase, axis: Option, - zi: Option>>, -) -> Result<(Array>, Option>>)> + zi: Option>, +) -> Result<(Array, Option>)> where - [Ix; N]: IntoDimension>, - Dim<[Ix; N]>: RemoveAxis, + D: Dimension + RemoveAxis, T: NumAssign + FromPrimitive + Copy + 'a, S: Data + 'a, { todo!() } -lfilter_for_dim!(1); -lfilter_for_dim!(2); -lfilter_for_dim!(3); -lfilter_for_dim!(4); -lfilter_for_dim!(5); -lfilter_for_dim!(6); - #[cfg(test)] mod test { use super::*; From 207c625c5ac022144d3f5ea69abf9ff87308a57d Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 3 Oct 2025 17:24:48 +0800 Subject: [PATCH 34/68] Add IxDyn compatible lfilter Also add tests for lfilter. Note that `lanes_mut` etc are parrallelizeable with Rayon, and may need more work in the future. --- sci-rs/src/signal/filter/lfilter.rs | 476 +++++++++++++++++++++++++++- 1 file changed, 474 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index a92e5703..2537e013 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -1,4 +1,4 @@ -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use core::marker::Copy; use ndarray::{ Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, @@ -385,6 +385,337 @@ lfilter_for_dim!(4); lfilter_for_dim!(5); lfilter_for_dim!(6); +/// Filter data `x` along one-dimension with an IIR or FIR filter. +/// +/// Filter a data sequence, `x`, using a digital filter. This works for many +/// fundamental data types (including Object type). The filter is a direct +/// form II transposed implementation of the standard difference equation +/// (see Notes). +/// +/// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be +/// preferred over `lfilter` for most filtering tasks, as second-order sections +/// have fewer numerical problems. +/// +/// ## Parameters +/// * `b` : array_like +/// The numerator coefficient vector in a 1-D sequence. +/// * `a` : array_like +/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` +/// is not 1, then both `a` and `b` are normalized by ``a[0]``. +/// * `x` : array_like +/// An N-dimensional input array. +/// * `axis`: `Option` +/// Default to `-1` if `None`. +/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. +/// * `zi`: array_like +/// Currently not implemented. +/// Initial conditions for filter delays. It is a vector +/// (or array of vectors for an N-dimensional input) of length +/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then +/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. +/// +/// ## Returns +/// * `y` : array +/// The output of the digital filter. +/// * `zf` : array, optional +/// If `zi` is None, this is not returned, otherwise, `zf` holds the +/// final filter delay values. +/// +/// # See Also +/// * [super::lfilter_zi_dyn] +/// +/// # Notes +/// If Array<_, IxDyn as provided by this function is not desired, consider using [LFilter]. +/// +/// # Examples +/// On a 1-dimensional signal: +/// ``` +/// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; +/// use sci_rs::signal::filter::lfilter; +/// +/// let b = array![5., 4., 1., 2.]; +/// let a = array![1.]; +/// let x = array![1., 2., 3., 4., 3., 5., 6.]; +/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.view(), None, None).unwrap(); // By ref +/// +/// assert_eq!(result.len(), expected.len()); +/// result.into_iter().zip(expected).for_each(|(r, e)| { +/// assert_eq!(r, e); +/// }); +/// +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.clone().into_dyn(), None, None).unwrap(); // Dynamic arrays +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value +/// ``` +/// +/// # Panics +/// Currently yet to implement for `a.len() > 1`. +// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of +// documentation, both lfilter_zi and this should eventually support NDArray. +pub fn lfilter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result<(Array, Option>)> +where + S: Data + 'a, + T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension + RemoveAxis, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + + if ndim == 0 { + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "Linear filter requires at least 1-dimensional `x`.".into(), + }); + } + + if a.len() > 1 { + todo!(); + }; + + let (axis, axis_inner) = { + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) + } else { + let axis_inner = x + .ndim() + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok((Axis(axis_inner), axis_inner)) + } + }?; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.clone().reborrow().into_dyn(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: Vec = x.shape().to_vec(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: Vec = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..ndim) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap?.try_into().unwrap() + }; + + // ArrayView::from_shape(strides, + // zi.as_slice_memory_order().unwrap()).unwrap().to_owned() + zi = ArrayView::from_shape((expected_shape).strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, Vec) = { + let mut tmp = x.shape().to_vec(); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp.as_ref()), tmp) + }; + + let mut out_full = ArrayD::::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, D, D> = { + let t = zi.shape()[axis_inner]; + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + // let mut out_dim_inner = out_full_dim_inner; + // if let Some(inner) = out_dim_inner.get_mut(axis_inner) { + // *inner = inner + // .checked_sub({ + // // Safety: b is Array1 + // *b.shape().first().unwrap() + // }) + // // Safety: inner is defined by having added b.len() + // .unwrap() + // + 1; + // } else { + // unsafe { unreachable_unchecked() }; + // }; + // (IntoDimension::into_dimension(out_dim_inner), out_dim_inner) + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, D, IxDyn> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner) = { + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = ArrayBase::zeros(out_dim); + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } +} + /// Internal function called by [LFilter::lfilter] for situation a.len() > 1. fn linear_filter<'a, T, S, D>( b: ArrayView1<'a, T>, @@ -406,7 +737,7 @@ mod test { use super::*; use alloc::vec; use approx::assert_relative_eq; - use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr}; + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr, ViewRepr}; // Tests that have a = [1.] with zi = None on input x with dim = 1. #[test] @@ -592,4 +923,145 @@ mod test { let result = Array1::lfilter((&b).into(), (&a).into(), x, Some(-2), None); assert!(result.is_err()); } + + #[test] + fn dyn_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + // Test static dim input + let Ok((result, Some(r_zi))) = + lfilter((&b).into(), (&a).into(), x.view(), None, Some((&zi).into())) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(&expected).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(&expected_zi).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + + // Test dyn input + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + None, + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(0), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(1), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } } From 8fb1e520c732a315c3a0a30d517914156da19b41 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 4 Oct 2025 15:12:19 +0800 Subject: [PATCH 35/68] Remove useless conversion of Vec in lfilter This was caught by clippy. --- sci-rs/src/signal/filter/lfilter.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 2537e013..1c67815f 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -564,7 +564,7 @@ where .collect(); let tmp_heap: Result> = tmp_heap.into_iter().collect(); - tmp_heap?.try_into().unwrap() + tmp_heap? }; // ArrayView::from_shape(strides, From ae13a03dc5baae8bd4ccea3771c0589b2abb627d Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 5 Oct 2025 15:49:55 +0800 Subject: [PATCH 36/68] Provision types for lfilter return types Suggested by clippy on the website. It is yet to be apparent if this provisioned type necessarily eases the reasoning of the return of the functions. --- sci-rs/src/signal/filter/lfilter.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 1c67815f..99656b95 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -8,6 +8,9 @@ use ndarray::{ use num_traits::{FromPrimitive, Num, NumAssign}; use sci_rs_core::{Error, Result}; +type LFilterResult = (Array>, Option>>); +type LFilterDynResult = (Array, Option>); + /// Internal function for obtaining length of all axis as array from input from input. /// /// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a `Vec`. @@ -148,7 +151,7 @@ where x: Self, axis: Option, zi: Option>>, - ) -> Result<(Array>, Option>>)> + ) -> Result> where [Ix; N]: IntoDimension>, Dim<[Ix; N]>: RemoveAxis, @@ -458,7 +461,7 @@ pub fn lfilter<'a, T, S, D>( x: ArrayBase, axis: Option, zi: Option>, -) -> Result<(Array, Option>)> +) -> Result> where S: Data + 'a, T: NumAssign + FromPrimitive + Copy + 'a, @@ -723,7 +726,7 @@ fn linear_filter<'a, T, S, D>( x: ArrayBase, axis: Option, zi: Option>, -) -> Result<(Array, Option>)> +) -> Result> where D: Dimension + RemoveAxis, T: NumAssign + FromPrimitive + Copy + 'a, From a3b6b54f22097fe54dcb8e5f8f75f9b519cb68c0 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 18 Oct 2025 13:20:27 +0800 Subject: [PATCH 37/68] Change LFilter trait bounds This avoids the use of RawData and makes return type T more explicit with container S. --- sci-rs/src/signal/filter/lfilter.rs | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 99656b95..61432cbb 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -75,9 +75,9 @@ where /// Implement lfilter for fixed dimension of input array `x`. /// /// Valid only from 1 to 6 dimensional arrays. -pub trait LFilter +pub trait LFilter where - S: ndarray::RawData, + S: Data, { /// Filter data `x` along one-dimension with an IIR or FIR filter. /// @@ -145,7 +145,7 @@ where /// Currently yet to implement for `a.len() > 1`. // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of // documentation, both lfilter_zi and this should eventually support NDArray. - fn lfilter<'a, T>( + fn lfilter<'a>( b: ArrayView1<'a, T>, a: ArrayView1<'a, T>, x: Self, @@ -161,11 +161,11 @@ where macro_rules! lfilter_for_dim { ($N:literal) => { - impl LFilter for ArrayBase> + impl LFilter for ArrayBase> where - S: ndarray::RawData, + S: Data, { - fn lfilter<'a, T>( + fn lfilter<'a>( b: ArrayView1<'a, T>, a: ArrayView1<'a, T>, x: Self, @@ -176,7 +176,7 @@ macro_rules! lfilter_for_dim { [Ix; $N]: IntoDimension>, Dim<[Ix; $N]>: RemoveAxis, T: NumAssign + FromPrimitive + Copy + 'a, - S: Data + 'a, + S: 'a, { if a.len() > 1 { return linear_filter(b, a, x, axis, zi); From 4cd1042091632e7d85cbf298bf5ced5a4ee4853b Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 18 Oct 2025 15:35:47 +0800 Subject: [PATCH 38/68] Move `lfilter::ndarray_shape_...` to `arraytools::..._st` The shape as `[usize; N]` function is not used by just lfilter. Also changed the name and documentation to reflect the `const N` requirement. --- sci-rs/src/signal/filter/arraytools.rs | 23 +++++++++++++++++++++++ sci-rs/src/signal/filter/lfilter.rs | 25 ++++--------------------- sci-rs/src/signal/filter/mod.rs | 5 +++++ 3 files changed, 32 insertions(+), 21 deletions(-) create mode 100644 sci-rs/src/signal/filter/arraytools.rs diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs new file mode 100644 index 00000000..77fe666f --- /dev/null +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -0,0 +1,23 @@ +//! Functions for acting on a axis of an array. +//! +//! Designed for ndarrays; with scipy's internal nomenclature. + +use ndarray::{ArrayBase, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis}; +use sci_rs_core::{Error, Result}; + +/// Internal function for obtaining length of all axis as array from input from input. +/// +/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a slice `&[T]`. +/// +/// # Parameters +/// `a`: Array whose shape is needed as a slice. +pub(crate) fn ndarray_shape_as_array_st<'a, S, T, const N: usize>( + a: &ArrayBase>, +) -> [Ix; N] +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + S: Data + 'a, +{ + a.shape().try_into().expect("Could not cast shape to array") +} diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 61432cbb..f2d28b6d 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -1,3 +1,4 @@ +use super::arraytools::ndarray_shape_as_array_st; use alloc::{vec, vec::Vec}; use core::marker::Copy; use ndarray::{ @@ -11,24 +12,6 @@ use sci_rs_core::{Error, Result}; type LFilterResult = (Array>, Option>>); type LFilterDynResult = (Array, Option>); -/// Internal function for obtaining length of all axis as array from input from input. -/// -/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a `Vec`. -/// -/// # Parameters -/// `a`: Array whose shape is needed as a slice. -fn ndarray_shape_as_array<'a, S, T, const N: usize>(a: &ArrayBase>) -> [Ix; N] -where - [Ix; N]: IntoDimension>, - Dim<[Ix; N]>: RemoveAxis, - T: FromPrimitive, - S: Data + 'a, -{ - let mut tmp = [0; N]; - (0..N).for_each(|axis| tmp[axis] = a.len_of(Axis(axis))); - tmp -} - /// Internal function for casting into [Axis] and appropriate usize from isize. /// /// # Parameters @@ -249,7 +232,7 @@ macro_rules! lfilter_for_dim { }; let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { - let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') (IntoDimension::into_dimension(tmp), tmp) }; @@ -291,7 +274,7 @@ macro_rules! lfilter_for_dim { } let (out_dim, out_dim_inner) = { - let tmp: [Ix; $N] = ndarray_shape_as_array(&x); + let tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); (IntoDimension::into_dimension(tmp), tmp) }; let mut out = ArrayBase::zeros(out_dim); @@ -345,7 +328,7 @@ macro_rules! lfilter_for_dim { // one extra heap allocation. let (out_dim, out_dim_inner): (Dim<_>, [Ix; $N]) = { - let mut tmp: [Ix; $N] = ndarray_shape_as_array(&x); + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); (IntoDimension::into_dimension(tmp), tmp) }; let mut out = ArrayBase::zeros(out_dim); diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index 5e5b4d29..ec0c5f2c 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -24,6 +24,11 @@ mod sosfilt; pub use ext::*; pub use sosfilt::*; +#[cfg(feature = "alloc")] +mod arraytools; +#[cfg(feature = "alloc")] +use arraytools::*; + #[cfg(feature = "alloc")] mod lfilter; #[cfg(feature = "alloc")] From ab84226046b02ac9b50b108755451ea8e3c2bffb Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 24 Oct 2025 15:05:02 +0800 Subject: [PATCH 39/68] Change zeroing to direct heap-reservation for lfilter intermediates Assume init for uninit arrays instead of specifying zeros with required shape gives an improved speed in lfilter. --- sci-rs/src/signal/filter/lfilter.rs | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index f2d28b6d..2a1879f5 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -237,7 +237,8 @@ macro_rules! lfilter_for_dim { (IntoDimension::into_dimension(tmp), tmp) }; - let mut out_full: Array> = ArrayBase::zeros(out_full_dim); + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out_full = unsafe { Array::uninit(out_full_dim).assume_init() }; out_full .lanes_mut(axis) .into_iter() @@ -277,7 +278,8 @@ macro_rules! lfilter_for_dim { let tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); (IntoDimension::into_dimension(tmp), tmp) }; - let mut out = ArrayBase::zeros(out_dim); + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; out.lanes_mut(axis) .into_iter() .zip(out_full.lanes(axis)) @@ -331,7 +333,8 @@ macro_rules! lfilter_for_dim { let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); (IntoDimension::into_dimension(tmp), tmp) }; - let mut out = ArrayBase::zeros(out_dim); + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; out.lanes_mut(axis) .into_iter() @@ -618,7 +621,8 @@ where let tmp = x.shape(); (IntoDimension::into_dimension(tmp), tmp) }; - let mut out = ArrayBase::zeros(out_dim); + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; out.lanes_mut(axis) .into_iter() .zip(out_full.lanes(axis)) @@ -672,7 +676,7 @@ where let tmp = x.shape(); (IntoDimension::into_dimension(tmp), tmp) }; - let mut out = ArrayBase::zeros(out_dim); + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; // Safety: All elements are overwritten by convolve in subsequent step. out.lanes_mut(axis) .into_iter() From ef26b6c8c1db8f648ab89dab6e97f82d0a329fac Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 24 Oct 2025 16:54:37 +0800 Subject: [PATCH 40/68] Add const inlining of user's `axis: Option` to `usize` Tried to see in Cargo-asm if this is inlined but `cargo-show-asm` isn't picking up `lfilter` as a function. This will have to be tested in the benching function. --- sci-rs/src/signal/filter/lfilter.rs | 69 +++++++++++++++++++++++++---- 1 file changed, 60 insertions(+), 9 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 2a1879f5..9c0f39b6 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -17,23 +17,71 @@ type LFilterDynResult = (Array, Option>); /// # Parameters /// axis: The user-specificed axis which filter is to be applied on. /// x: The input-data whose axis object that will be manipulated against. -fn check_and_get_axis_st<'a, T, S, const N: usize>( +/// +/// # Notes +/// Const nature of this function means error has to be manually created. +#[inline] +const fn check_and_get_axis_st<'a, T, S, const N: usize>( axis: Option, x: &ArrayBase>, +) -> core::result::Result<(Axis, usize), ()> +where + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + match axis { + None => (), + Some(axis) if axis.is_negative() => { + if axis.unsigned_abs() > N { + return Err(()); + } + } + Some(axis) => { + if axis.unsigned_abs() >= N { + return Err(()); + } + } + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = match axis { + Some(axis) => axis, + None => -1, + }; + if axis_inner >= 0 { + Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) + } else { + let axis_inner = N + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok((Axis(axis_inner), axis_inner)) + } +} + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +#[inline] +fn check_and_get_axis_dyn<'a, T, S, D>( + axis: Option, + x: &ArrayBase, ) -> Result<(Axis, usize)> where - [Ix; N]: IntoDimension>, - Dim<[Ix; N]>: RemoveAxis, - T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension, S: Data + 'a, { + let ndim = D::NDIM.unwrap_or(x.ndim()); // Before we convert into the appropriate axis object, we have to check at runtime that the // axis value specified is within -N <= axis < N. if axis.is_some_and(|axis| { !(if axis < 0 { - axis.unsigned_abs() <= N + axis.unsigned_abs() <= ndim } else { - axis.unsigned_abs() < N + axis.unsigned_abs() < ndim }) }) { return Err(Error::InvalidArg { @@ -47,8 +95,7 @@ where if axis_inner >= 0 { Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) } else { - let axis_inner = x - .ndim() + let axis_inner = ndim .checked_add_signed(axis_inner) .expect("Invalid add to `axis` option"); Ok((Axis(axis_inner), axis_inner)) @@ -165,7 +212,11 @@ macro_rules! lfilter_for_dim { return linear_filter(b, a, x, axis, zi); }; - let (axis, axis_inner) = check_and_get_axis_st(axis, &x)?; + let (axis, axis_inner) = check_and_get_axis_st(axis, &x) + .map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; if a.is_empty() { return Err(Error::InvalidArg { From 952223d824cc01d984b89af3a019a50b7ec7b1d5 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 25 Oct 2025 13:23:23 +0800 Subject: [PATCH 41/68] Change lfilter(for IxDyn arrays) to use get_axis_dyn This centralizes areas of concern. --- sci-rs/src/signal/filter/lfilter.rs | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 9c0f39b6..c93c9912 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -518,34 +518,7 @@ where todo!(); }; - let (axis, axis_inner) = { - // Before we convert into the appropriate axis object, we have to check at runtime that the - // axis value specified is within -N <= axis < N. - if axis.is_some_and(|axis| { - !(if axis < 0 { - axis.unsigned_abs() <= ndim - } else { - axis.unsigned_abs() < ndim - }) - }) { - return Err(Error::InvalidArg { - arg: "axis".into(), - reason: "index out of range.".into(), - }); - } - - // We make a best effort to convert into appropriate axis object. - let axis_inner: isize = axis.unwrap_or(-1); - if axis_inner >= 0 { - Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) - } else { - let axis_inner = x - .ndim() - .checked_add_signed(axis_inner) - .expect("Invalid add to `axis` option"); - Ok((Axis(axis_inner), axis_inner)) - } - }?; + let (axis, axis_inner) = check_and_get_axis_dyn(axis, &x)?; if a.is_empty() { return Err(Error::InvalidArg { From 61d4392ae8d33c2d63d5821683c7eb8dbaa65f7e Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 26 Oct 2025 15:16:36 +0800 Subject: [PATCH 42/68] Change check_and_get_axis_* to return only `usize` Callee should have the responsibility of constructing the `Axis` object. --- sci-rs/src/signal/filter/lfilter.rs | 33 ++++++++++++++++------------- 1 file changed, 18 insertions(+), 15 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index c93c9912..355266ab 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -24,7 +24,7 @@ type LFilterDynResult = (Array, Option>); const fn check_and_get_axis_st<'a, T, S, const N: usize>( axis: Option, x: &ArrayBase>, -) -> core::result::Result<(Axis, usize), ()> +) -> core::result::Result where S: Data + 'a, { @@ -50,12 +50,12 @@ where None => -1, }; if axis_inner >= 0 { - Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) + Ok(axis_inner.unsigned_abs()) } else { let axis_inner = N .checked_add_signed(axis_inner) .expect("Invalid add to `axis` option"); - Ok((Axis(axis_inner), axis_inner)) + Ok(axis_inner) } } @@ -66,10 +66,7 @@ where /// axis: The user-specificed axis which filter is to be applied on. /// x: The input-data whose axis object that will be manipulated against. #[inline] -fn check_and_get_axis_dyn<'a, T, S, D>( - axis: Option, - x: &ArrayBase, -) -> Result<(Axis, usize)> +fn check_and_get_axis_dyn<'a, T, S, D>(axis: Option, x: &ArrayBase) -> Result where D: Dimension, S: Data + 'a, @@ -93,12 +90,12 @@ where // We make a best effort to convert into appropriate axis object. let axis_inner: isize = axis.unwrap_or(-1); if axis_inner >= 0 { - Ok((Axis(axis_inner as usize), axis_inner.unsigned_abs())) + Ok(axis_inner.unsigned_abs()) } else { let axis_inner = ndim .checked_add_signed(axis_inner) .expect("Invalid add to `axis` option"); - Ok((Axis(axis_inner), axis_inner)) + Ok(axis_inner) } } @@ -212,11 +209,14 @@ macro_rules! lfilter_for_dim { return linear_filter(b, a, x, axis, zi); }; - let (axis, axis_inner) = check_and_get_axis_st(axis, &x) - .map_err(|_| Error::InvalidArg { - arg: "axis".into(), - reason: "index out of range.".into(), - })?; + let (axis, axis_inner) = { + let ax = check_and_get_axis_st(axis, &x) + .map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + (Axis(ax), ax) + }; if a.is_empty() { return Err(Error::InvalidArg { @@ -518,7 +518,10 @@ where todo!(); }; - let (axis, axis_inner) = check_and_get_axis_dyn(axis, &x)?; + let (axis, axis_inner) = { + let ax = check_and_get_axis_dyn(axis, &x)?; + (Axis(ax), ax) + }; if a.is_empty() { return Err(Error::InvalidArg { From 07ba98a516f9c8f7e398561bf0d044f6cbb4bfcf Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 27 Oct 2025 16:18:42 +0800 Subject: [PATCH 43/68] Move `lfilter::check_and_get_axis*` to `arraytools::` These are fairly generic functions that parse user input into valid values of usizes (in their respective context being the arrays they act on), and thus are more suited to be in array tools. --- sci-rs/src/signal/filter/arraytools.rs | 90 ++++++++++++++++++++++++++ sci-rs/src/signal/filter/lfilter.rs | 89 +------------------------ 2 files changed, 91 insertions(+), 88 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 77fe666f..029ef172 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -5,6 +5,96 @@ use ndarray::{ArrayBase, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis}; use sci_rs_core::{Error, Result}; +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +/// +/// # Notes +/// Const nature of this function means error has to be manually created. +#[inline] +pub(crate) const fn check_and_get_axis_st<'a, T, S, const N: usize>( + axis: Option, + x: &ArrayBase>, +) -> core::result::Result +where + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + match axis { + None => (), + Some(axis) if axis.is_negative() => { + if axis.unsigned_abs() > N { + return Err(()); + } + } + Some(axis) => { + if axis.unsigned_abs() >= N { + return Err(()); + } + } + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = match axis { + Some(axis) => axis, + None => -1, + }; + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = N + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +#[inline] +pub(crate) fn check_and_get_axis_dyn<'a, T, S, D>( + axis: Option, + x: &ArrayBase, +) -> Result +where + D: Dimension, + S: Data + 'a, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = ndim + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + /// Internal function for obtaining length of all axis as array from input from input. /// /// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a slice `&[T]`. diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index 355266ab..e8dfb047 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -1,4 +1,4 @@ -use super::arraytools::ndarray_shape_as_array_st; +use super::arraytools::{check_and_get_axis_dyn, check_and_get_axis_st, ndarray_shape_as_array_st}; use alloc::{vec, vec::Vec}; use core::marker::Copy; use ndarray::{ @@ -12,93 +12,6 @@ use sci_rs_core::{Error, Result}; type LFilterResult = (Array>, Option>>); type LFilterDynResult = (Array, Option>); -/// Internal function for casting into [Axis] and appropriate usize from isize. -/// -/// # Parameters -/// axis: The user-specificed axis which filter is to be applied on. -/// x: The input-data whose axis object that will be manipulated against. -/// -/// # Notes -/// Const nature of this function means error has to be manually created. -#[inline] -const fn check_and_get_axis_st<'a, T, S, const N: usize>( - axis: Option, - x: &ArrayBase>, -) -> core::result::Result -where - S: Data + 'a, -{ - // Before we convert into the appropriate axis object, we have to check at runtime that the - // axis value specified is within -N <= axis < N. - match axis { - None => (), - Some(axis) if axis.is_negative() => { - if axis.unsigned_abs() > N { - return Err(()); - } - } - Some(axis) => { - if axis.unsigned_abs() >= N { - return Err(()); - } - } - } - - // We make a best effort to convert into appropriate axis object. - let axis_inner: isize = match axis { - Some(axis) => axis, - None => -1, - }; - if axis_inner >= 0 { - Ok(axis_inner.unsigned_abs()) - } else { - let axis_inner = N - .checked_add_signed(axis_inner) - .expect("Invalid add to `axis` option"); - Ok(axis_inner) - } -} - -/// Internal function for casting into [Axis] and appropriate usize from isize. -/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. -/// -/// # Parameters -/// axis: The user-specificed axis which filter is to be applied on. -/// x: The input-data whose axis object that will be manipulated against. -#[inline] -fn check_and_get_axis_dyn<'a, T, S, D>(axis: Option, x: &ArrayBase) -> Result -where - D: Dimension, - S: Data + 'a, -{ - let ndim = D::NDIM.unwrap_or(x.ndim()); - // Before we convert into the appropriate axis object, we have to check at runtime that the - // axis value specified is within -N <= axis < N. - if axis.is_some_and(|axis| { - !(if axis < 0 { - axis.unsigned_abs() <= ndim - } else { - axis.unsigned_abs() < ndim - }) - }) { - return Err(Error::InvalidArg { - arg: "axis".into(), - reason: "index out of range.".into(), - }); - } - - // We make a best effort to convert into appropriate axis object. - let axis_inner: isize = axis.unwrap_or(-1); - if axis_inner >= 0 { - Ok(axis_inner.unsigned_abs()) - } else { - let axis_inner = ndim - .checked_add_signed(axis_inner) - .expect("Invalid add to `axis` option"); - Ok(axis_inner) - } -} - /// Implement lfilter for fixed dimension of input array `x`. /// /// Valid only from 1 to 6 dimensional arrays. From d100c408c00f99fcc2698e11c9ed638ed162af48 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 28 Oct 2025 14:56:50 +0800 Subject: [PATCH 44/68] Remove unnecessary trait requirements of LFilter Some of the trait requirements bound to lfilter as a function can be removed. --- sci-rs/src/signal/filter/lfilter.rs | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs index e8dfb047..32abdfb5 100644 --- a/sci-rs/src/signal/filter/lfilter.rs +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -2,9 +2,8 @@ use super::arraytools::{check_and_get_axis_dyn, check_and_get_axis_st, ndarray_s use alloc::{vec, vec::Vec}; use core::marker::Copy; use ndarray::{ - Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, ArrayViewMut1, Axis, Data, Dim, - Dimension, IntoDimension, Ix, IxDyn, RemoveAxis, ShapeBuilder, SliceArg, SliceInfo, - SliceInfoElem, + Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, + IntoDimension, Ix, IxDyn, ShapeBuilder, SliceArg, SliceInfo, SliceInfoElem, }; use num_traits::{FromPrimitive, Num, NumAssign}; use sci_rs_core::{Error, Result}; @@ -93,8 +92,6 @@ where zi: Option>>, ) -> Result> where - [Ix; N]: IntoDimension>, - Dim<[Ix; N]>: RemoveAxis, T: NumAssign + FromPrimitive + Copy + 'a, S: Data + 'a; } @@ -113,8 +110,6 @@ macro_rules! lfilter_for_dim { zi: Option>>, ) -> Result<(Array>, Option>>)> where - [Ix; $N]: IntoDimension>, - Dim<[Ix; $N]>: RemoveAxis, T: NumAssign + FromPrimitive + Copy + 'a, S: 'a, { @@ -415,7 +410,7 @@ pub fn lfilter<'a, T, S, D>( where S: Data + 'a, T: NumAssign + FromPrimitive + Copy + 'a, - D: Dimension + RemoveAxis, + D: Dimension, SliceInfo, D, D>: SliceArg, { let ndim = D::NDIM.unwrap_or(x.ndim()); @@ -655,8 +650,8 @@ fn linear_filter<'a, T, S, D>( zi: Option>, ) -> Result> where - D: Dimension + RemoveAxis, - T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension, + T: 'a, S: Data + 'a, { todo!() From 858974a07ffef705ed5368e1d5e418484216f537 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 18 Aug 2025 15:39:23 +0800 Subject: [PATCH 45/68] Update lfilter_zi_dyn signature to use Array1 This is necessary for the return type to be consistent with lfilter. --- sci-rs/src/signal/filter/lfilter_zi.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 48296030..560e207d 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -1,5 +1,6 @@ use core::{iter::Sum, ops::SubAssign}; use nalgebra::{DMatrix, OMatrix, RealField, SMatrix, Scalar}; +use ndarray::Array1; use num_traits::{Float, One, Zero}; use crate::linalg::companion_dyn; @@ -21,7 +22,7 @@ use alloc::vec::Vec; /// /// #[inline] -pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Vec +pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Array1 where F: RealField + Copy + PartialEq + Scalar + Zero + One + Sum + SubAssign, { @@ -76,7 +77,7 @@ where } } - zi + zi.into() } #[cfg(test)] From 7d5b796c8fdcb9b634681742d571938e850b9c07 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 18 Oct 2025 19:39:29 +0800 Subject: [PATCH 46/68] Remove wrong lfilter_zi assumption There is no requirement that the arguments passed into lfilter has to be IIR in the way where the denominator length is identical to the numerator length. FIR can be treated as a special case of IIR as a matter of fact. --- sci-rs/src/signal/filter/lfilter_zi.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 560e207d..24d0dc83 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -26,7 +26,6 @@ pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Array1 where F: RealField + Copy + PartialEq + Scalar + Zero + One + Sum + SubAssign, { - assert!(b.len() == a.len()); let m = b.len(); let ai0 = a From c72cbd6950637bcf5cae6e1af5f10110497f69ba Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 29 Sep 2025 16:54:19 +0800 Subject: [PATCH 47/68] Add axis_slice function with bad signature The functionality currently works, but the signature cannot explicitly specify the return type to be `D`. This is made possible by using heap allocation of the `SliceInfo` through `Vec` instead of using the macro/trait method in Lfilter, in an attempt have another method to circumvent the need for a nightly compiler that uses similar trait bounds as NDArray without needing a nightly compiler. This function is eventually used in even/odd/const extensions (which has only been implemented on nalgebra arrays and not NDArrays) so far. --- sci-rs/src/signal/filter/arraytools.rs | 106 ++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 1 deletion(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 029ef172..145fc93e 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -2,7 +2,11 @@ //! //! Designed for ndarrays; with scipy's internal nomenclature. -use ndarray::{ArrayBase, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis}; +use alloc::{vec, vec::Vec}; +use ndarray::{ + ArrayBase, ArrayView, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis, SliceArg, + SliceInfo, SliceInfoElem, +}; use sci_rs_core::{Error, Result}; /// Internal function for casting into [Axis] and appropriate usize from isize. @@ -111,3 +115,103 @@ where { a.shape().try_into().expect("Could not cast shape to array") } + +/// Takes a slice along `axis` from `a`. +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `start`: `Option`. None defaults to 0. +/// * `end`: `Option`. +/// * `step`: `Option`. None default to 1. +/// * `axis`: `Option`. None defaults to -1. +/// +/// # Errors +/// - Axis is out of bounds. +/// - Start/stop elements are out of bounds. +/// +pub fn axis_slice( + a: &ArrayBase, + start: Option, + end: Option, + step: Option, + axis: Option, +) -> Result, D, D> as SliceArg>::OutDim>> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + if D::NDIM.is_none() { + return Err(Error::InvalidArg { + arg: 'a'.into(), + reason: "IxDyn array is not supported".into(), + }); + } + #[allow(non_snake_case)] + let N = D::NDIM.unwrap(); + + // Axis object and its corresponding usize internal. + let (axis, axis_inner) = { + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= D::NDIM.unwrap() + } else { + axis.unsigned_abs() < D::NDIM.unwrap() + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + (Axis(axis_inner as usize), axis_inner.unsigned_abs()) + } else { + let axis_inner = a + .ndim() + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + (Axis(axis_inner), axis_inner) + } + }; + + let sl = SliceInfo::<_, D, D>::try_from({ + let mut tmp = vec![ + SliceInfoElem::Slice { + start: 0, + end: None, + step: 1, + }; + N + ]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: start.unwrap_or(0), + end, + step: step.unwrap_or(1), + }; + + tmp + }) + .unwrap(); + + Ok(a.slice(&sl)) +} + +#[cfg(test)] +mod test { + use super::*; + use ndarray::array; + + #[test] + fn axis_slice_doc() { + let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + + assert_eq!( + axis_slice(&a, Some(0), Some(1), Some(1), Some(1)).unwrap(), + array![[1], [4], [7]] + ); + } +} From 0e301b134a42d714fc846f760345f372aaf6a60c Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:53:31 +0800 Subject: [PATCH 48/68] Fix axis_slice function signature This will avoid having `, _, _>>` everywhere. --- sci-rs/src/signal/filter/arraytools.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 145fc93e..1c4e8aa5 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -135,11 +135,11 @@ pub fn axis_slice( end: Option, step: Option, axis: Option, -) -> Result, D, D> as SliceArg>::OutDim>> +) -> Result> where S: Data, D: Dimension, - SliceInfo, D, D>: SliceArg, + SliceInfo, D, D>: SliceArg, { if D::NDIM.is_none() { return Err(Error::InvalidArg { From cf93afc6ed5312e2aa9fe99e3d96603e84123a0f Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:12:38 +0800 Subject: [PATCH 49/68] Add IxDyn tests for axis slice Tests should fail for now --- sci-rs/src/signal/filter/arraytools.rs | 31 +++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 1c4e8aa5..bbf1daf7 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -203,8 +203,9 @@ where #[cfg(test)] mod test { use super::*; - use ndarray::array; + use ndarray::{array, Array, ArrayD, IxDyn}; + /// Tests on IxN arrays. #[test] fn axis_slice_doc() { let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; @@ -213,5 +214,33 @@ mod test { axis_slice(&a, Some(0), Some(1), Some(1), Some(1)).unwrap(), array![[1], [4], [7]] ); + assert_eq!( + axis_slice(&a, Some(0), Some(2), Some(1), Some(0)).unwrap(), + array![[1, 2, 3], [4, 5, 6]] + ); + } + + /// Test on IxDyn Arrays. + #[test] + fn axis_slice_doc_dyn() { + let a = { + let mut y: Array<_, IxDyn> = ArrayD::::zeros(IxDyn(&[2, 3])); + y[[0, 0]] = 5; + y[[0, 1]] = 6; + y[[0, 2]] = 7; + y[[1, 0]] = 1; + y[[1, 1]] = 2; + y[[1, 2]] = 3; + + y + }; + + assert_eq!( + axis_slice(&a, Some(0), Some(1), Some(1), Some(1)) + .unwrap() + .into_dimensionality() + .unwrap(), + array![[5], [1]] + ); } } From 19398ff3b1a49085bb0bda8d071ba39ce1d35ee4 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:42:54 +0800 Subject: [PATCH 50/68] Add support for IxDyn arrays in axis slice --- sci-rs/src/signal/filter/arraytools.rs | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index bbf1daf7..5e3ab6fe 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -141,22 +141,15 @@ where D: Dimension, SliceInfo, D, D>: SliceArg, { - if D::NDIM.is_none() { - return Err(Error::InvalidArg { - arg: 'a'.into(), - reason: "IxDyn array is not supported".into(), - }); - } - #[allow(non_snake_case)] - let N = D::NDIM.unwrap(); + let ndim = D::NDIM.unwrap_or(a.ndim()); // Axis object and its corresponding usize internal. let (axis, axis_inner) = { if axis.is_some_and(|axis| { !(if axis < 0 { - axis.unsigned_abs() <= D::NDIM.unwrap() + axis.unsigned_abs() <= ndim } else { - axis.unsigned_abs() < D::NDIM.unwrap() + axis.unsigned_abs() < ndim }) }) { return Err(Error::InvalidArg { @@ -185,7 +178,7 @@ where end: None, step: 1, }; - N + ndim ]; tmp[axis_inner] = SliceInfoElem::Slice { start: start.unwrap_or(0), From 7ac0e3b0cfd58f993c40d7438380a4bbb848fb6a Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 6 Oct 2025 18:20:26 +0800 Subject: [PATCH 51/68] Add FiltFilt function related enums Also updates the corresponding `ext` that is used by `sosfiltfilt` to denote the difference. --- sci-rs/src/signal/filter/ext.rs | 11 +++++------ sci-rs/src/signal/filter/filtfilt.rs | 12 ++++++++++++ 2 files changed, 17 insertions(+), 6 deletions(-) create mode 100644 sci-rs/src/signal/filter/filtfilt.rs diff --git a/sci-rs/src/signal/filter/ext.rs b/sci-rs/src/signal/filter/ext.rs index 26654625..b9b16a13 100644 --- a/sci-rs/src/signal/filter/ext.rs +++ b/sci-rs/src/signal/filter/ext.rs @@ -3,9 +3,10 @@ use core::ops::Sub; use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, Dyn, OMatrix, Scalar}; use num_traits::{One, Zero}; -/// /// Pad types. /// +/// Used by [super::sosfiltfilt]. +/// This differs from [super::FiltFiltPadType] which has different semantics. pub enum Pad { /// No padding. None, @@ -20,9 +21,7 @@ pub enum Pad { Constant, } -/// -/// Pad an array. -/// +/// Pad an [nalgebra] array. pub fn pad( padtype: Pad, mut padlen: Option, @@ -89,9 +88,9 @@ where } } +/// Pad an [nalgebra] array with odd extension. /// -/// Pad an array with odd extension. -/// +// This differs from [super::FiltFiltPadType]'s ext that acts on [ndarray]. pub fn odd_ext_dyn(x: OMatrix, n: usize, axis: usize) -> OMatrix where T: Scalar + Copy + Zero + One + Sub, diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs new file mode 100644 index 00000000..2d884e10 --- /dev/null +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -0,0 +1,12 @@ +/// Padding utilised in [FiltFilt::filtfilt]. +// WARN: Related/Duplicate: [super::Pad]. +#[derive(Copy, Clone, Default)] +pub enum FiltFiltPadType { + /// Odd extensions + #[default] + Odd, + /// Even extensions + Even, + /// Constant extensions + Const, +} From 9ea208228aae8656557e362c9a8fe9ee787a344b Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 7 Oct 2025 11:51:47 +0800 Subject: [PATCH 52/68] Split axis_slice into unsafe to allow for caller to skip dim-safety It is possible that the calling function already does some safety checks on the relevant axis checks, and having axis_slice repeatedly run this is unnecessary. --- sci-rs/src/signal/filter/arraytools.rs | 68 ++++++++++++++++++-------- 1 file changed, 47 insertions(+), 21 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 5e3ab6fe..40ee5089 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -128,7 +128,6 @@ where /// # Errors /// - Axis is out of bounds. /// - Start/stop elements are out of bounds. -/// pub fn axis_slice( a: &ArrayBase, start: Option, @@ -143,8 +142,7 @@ where { let ndim = D::NDIM.unwrap_or(a.ndim()); - // Axis object and its corresponding usize internal. - let (axis, axis_inner) = { + let axis = { if axis.is_some_and(|axis| { !(if axis < 0 { axis.unsigned_abs() <= ndim @@ -158,29 +156,57 @@ where }); } - // We make a best effort to convert into appropriate axis object. - let axis_inner: isize = axis.unwrap_or(-1); - if axis_inner >= 0 { - (Axis(axis_inner as usize), axis_inner.unsigned_abs()) + // We make a best effort to convert into appropriate usize. + let axis: isize = axis.unwrap_or(-1); + if axis >= 0 { + axis.unsigned_abs() } else { - let axis_inner = a - .ndim() - .checked_add_signed(axis_inner) - .expect("Invalid add to `axis` option"); - (Axis(axis_inner), axis_inner) + a.ndim() + .checked_add_signed(axis) + .expect("Invalid add to `axis` option") } }; + unsafe { axis_slice_unsafe(a, start, end, step, axis, ndim) } +} + +/// Takes a slice along `axis` from `a`. +/// +/// Assumes that the specified axis is within bounds. +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `start`: `Option`. None defaults to 0. +/// * `end`: `Option`. +/// * `step`: `Option`. None default to 1. +/// * `axis`: `usize`. +/// * `a_ndim`: Dimensionality of `a`. This strictly has to be `a.ndim()`. +/// +/// # Panics +/// - Axis is out of bounds. +/// - Start/stop elements are out of bounds. +pub(crate) fn axis_slice_unsafe( + a: &ArrayBase, + start: Option, + end: Option, + step: Option, + axis: usize, + a_ndim: usize, +) -> Result> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + debug_assert!(if a_ndim != 0 { + axis < a_ndim // Eg: A 1D-array should only have axis = 0 + } else { + axis <= a_ndim // Allow for axis = 0 when ndim = 0. + }); + let sl = SliceInfo::<_, D, D>::try_from({ - let mut tmp = vec![ - SliceInfoElem::Slice { - start: 0, - end: None, - step: 1, - }; - ndim - ]; - tmp[axis_inner] = SliceInfoElem::Slice { + let mut tmp = vec![SliceInfoElem::from(..); a_ndim]; + tmp[axis] = SliceInfoElem::Slice { start: start.unwrap_or(0), end, step: step.unwrap_or(1), From e12975cd9adf879e0e92ef9784488d52949eaa02 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 8 Oct 2025 13:50:24 +0800 Subject: [PATCH 53/68] Change axis_slice to be functionally equivalent for negative steps This makes it easier than having to fiddle with indices and manually reversing on the caller's end. --- sci-rs/src/signal/filter/arraytools.rs | 46 +++++++++++++++++++++++--- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 40ee5089..25f9d298 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -204,13 +204,35 @@ where axis <= a_ndim // Allow for axis = 0 when ndim = 0. }); + let axis_len = a.shape()[axis] as isize; + let step = step.unwrap_or(1); + + let coerce = |idx: Option, def_pos: isize, def_neg: isize| -> isize { + match idx { + Some(i) if i < 0 => (axis_len + i).max(0), + Some(i) => i.min(axis_len), + None => { + if !step.is_negative() { + def_pos + } else { + def_neg + } + } + } + }; + let (start, end) = { + let mut start = coerce(start, 0, axis_len - 1); + let mut end = coerce(end, axis_len, -1); + if step.is_negative() { + (end + 1, Some(start + 1)) + } else { + (start, Some(end)) + } + }; + let sl = SliceInfo::<_, D, D>::try_from({ let mut tmp = vec![SliceInfoElem::from(..); a_ndim]; - tmp[axis] = SliceInfoElem::Slice { - start: start.unwrap_or(0), - end, - step: step.unwrap_or(1), - }; + tmp[axis] = SliceInfoElem::Slice { start, end, step }; tmp }) @@ -239,6 +261,20 @@ mod test { ); } + /// Tests on IxN arrays with negative step. + #[test] + fn axis_slice_neg_step() { + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + assert_eq!( + axis_slice(&a, Some(2), Some(0), Some(-1), None).unwrap(), + array![[3, 2], [4, 1]] + ); + assert_eq!( + axis_slice(&a, Some(-2), Some(-4), Some(-1), None).unwrap(), + array![[4, 3], [9, 4]] + ); + } + /// Test on IxDyn Arrays. #[test] fn axis_slice_doc_dyn() { From ff70552ba7f52fdbea028e963fa6665e13ad831d Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 9 Oct 2025 15:17:24 +0800 Subject: [PATCH 54/68] Add odd_extension function Specify under impl of FiltFiltPadType instead of making it its own explicit function as there is no where other function that seems to make use of `odd_ext`. Implement specifically for IxN arrays first and not for IxDyn arrays. --- sci-rs/src/signal/filter/filtfilt.rs | 103 +++++++++++++++++++++++++++ sci-rs/src/signal/filter/mod.rs | 4 ++ 2 files changed, 107 insertions(+) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 2d884e10..0e5d136c 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -1,3 +1,12 @@ +use super::{axis_slice_unsafe, check_and_get_axis_dyn}; +use alloc::vec::Vec; +use core::ops::{Add, Sub}; +use ndarray::{ + Array, ArrayBase, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, Ix, RawData, RemoveAxis, + SliceArg, SliceInfo, SliceInfoElem, +}; +use sci_rs_core::{Error, Result}; + /// Padding utilised in [FiltFilt::filtfilt]. // WARN: Related/Duplicate: [super::Pad]. #[derive(Copy, Clone, Default)] @@ -10,3 +19,97 @@ pub enum FiltFiltPadType { /// Constant extensions Const, } + +impl FiltFiltPadType { + /// Extensions on ndarrays. + /// + /// # Parameters + /// `self`: Type of extension. + /// `x`: Array to extend on. + /// `n`: The number of elements by which to extend `x` at each end of the axis. + /// `axis`: The axis along which to extend `x`. + /// + /// ## Type of extension + /// * odd: Odd extension at the boundaries of an array, generating a new ndarray by making an + /// odd extension of `x` along the specified axis. + fn ext<'a, T, S, D>( + &'a self, + x: ArrayBase, + n: usize, + axis: Option, + ) -> Result> + where + T: Clone + Add + Sub + 'a, + S: Data, + D: Dimension + RemoveAxis, + SliceInfo, D, D>: SliceArg, + { + if D::NDIM.is_none() { + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "IxDyn is not supported".into(), + }); + } + + let ndim = D::NDIM.unwrap(); + + let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + + match self { + FiltFiltPadType::Odd => { + if n < 1 { + return Ok(x.to_owned()); + } + + let left_end = + unsafe { axis_slice_unsafe(&x, Some(0), Some(1), None, axis, ndim) }?; + let left_ext = unsafe { + axis_slice_unsafe(&x, Some(n as isize), Some(0), Some(-1), axis, ndim) + }?; + let right_end = unsafe { axis_slice_unsafe(&x, Some(-1), None, None, axis, ndim) }?; + let right_ext = unsafe { + axis_slice_unsafe(&x, Some(-2), Some(-2 - (n as isize)), Some(-1), axis, ndim) + }?; + + let ll = left_end.to_owned().add(left_end).sub(left_ext); + let rr = right_end.to_owned().add(right_end).sub(right_ext); + + ndarray::concatenate(Axis(axis), &[ll.view(), x.view(), rr.view()]).map_err(|_| { + Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + } + }) + } + FiltFiltPadType::Even => todo!(), + FiltFiltPadType::Const => todo!(), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use ndarray::array; + + /// Test odd_ext as from documentation. + #[test] + fn odd_ext_doc() { + let odd = FiltFiltPadType::Odd; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = odd.ext(a, 2, None).expect("Could not get odd_ext."); + let expected = array![ + [-1, 0, 1, 2, 3, 4, 5, 6, 7], + [-4, -1, 0, 1, 4, 9, 16, 23, 28] + ]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + } +} diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index ec0c5f2c..0c0dc4b5 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -29,6 +29,8 @@ mod arraytools; #[cfg(feature = "alloc")] use arraytools::*; +#[cfg(feature = "alloc")] +mod filtfilt; #[cfg(feature = "alloc")] mod lfilter; #[cfg(feature = "alloc")] @@ -40,6 +42,8 @@ mod sosfilt_zi; #[cfg(feature = "alloc")] mod sosfiltfilt; +#[cfg(feature = "alloc")] +pub use filtfilt::*; #[cfg(feature = "alloc")] pub use lfilter::*; #[cfg(feature = "alloc")] From cb39a747b7f0846b7b98ed969aa998d3d107f2c6 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 10 Oct 2025 15:37:38 +0800 Subject: [PATCH 55/68] Add even_extension function Implements even_ext under impl of FiltFiltPadType for IxN arrays. --- sci-rs/src/signal/filter/filtfilt.rs | 35 +++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 0e5d136c..65616522 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -32,6 +32,8 @@ impl FiltFiltPadType { /// ## Type of extension /// * odd: Odd extension at the boundaries of an array, generating a new ndarray by making an /// odd extension of `x` along the specified axis. + /// * even: Even extension at the boundaries of an array, generating a new ndarray by making an + /// even extension of `x` along the specified axis. fn ext<'a, T, S, D>( &'a self, x: ArrayBase, @@ -84,7 +86,24 @@ impl FiltFiltPadType { } }) } - FiltFiltPadType::Even => todo!(), + FiltFiltPadType::Even => { + if n < 1 { + return Ok(x.to_owned()); + } + + let left_ext = unsafe { + axis_slice_unsafe(&x, Some(n as isize), Some(0), Some(-1), axis, ndim) + }?; + let right_ext = unsafe { + axis_slice_unsafe(&x, Some(-2), Some(-2 - (n as isize)), Some(-1), axis, ndim) + }?; + + ndarray::concatenate(Axis(axis), &[left_ext.view(), x.view(), right_ext.view()]) + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + }) + } FiltFiltPadType::Const => todo!(), } } @@ -112,4 +131,18 @@ mod test { .and(&expected) .for_each(|&r, &e| assert_eq!(r, e)); } + + /// Test odd_ext as from documentation. + #[test] + fn even_ext_doc() { + let even = FiltFiltPadType::Even; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = even.ext(a, 2, None).expect("Could not get even_ext."); + let expected = array![[3, 2, 1, 2, 3, 4, 5, 4, 3], [4, 1, 0, 1, 4, 9, 16, 9, 4]]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + } } From 3ef63a857bf1bc5aa599272bef11a47a9a354bd0 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sat, 11 Oct 2025 15:12:04 +0800 Subject: [PATCH 56/68] Add constant_extension function Implements const_ext under impl of FiltFiltPadType for IxN arrays. --- sci-rs/src/signal/filter/filtfilt.rs | 55 ++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 65616522..922f936e 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -1,5 +1,5 @@ use super::{axis_slice_unsafe, check_and_get_axis_dyn}; -use alloc::vec::Vec; +use alloc::{vec, vec::Vec}; use core::ops::{Add, Sub}; use ndarray::{ Array, ArrayBase, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, Ix, RawData, RemoveAxis, @@ -34,6 +34,8 @@ impl FiltFiltPadType { /// odd extension of `x` along the specified axis. /// * even: Even extension at the boundaries of an array, generating a new ndarray by making an /// even extension of `x` along the specified axis. + /// * const: Constant extension at the boundaries of an array, generating a new ndarray by + /// making an constant extension of `x` along the specified axis. fn ext<'a, T, S, D>( &'a self, x: ArrayBase, @@ -41,7 +43,7 @@ impl FiltFiltPadType { axis: Option, ) -> Result> where - T: Clone + Add + Sub + 'a, + T: Clone + Add + Sub + num_traits::One + 'a, S: Data, D: Dimension + RemoveAxis, SliceInfo, D, D>: SliceArg, @@ -104,7 +106,40 @@ impl FiltFiltPadType { reason: "Shape Error".into(), }) } - FiltFiltPadType::Const => todo!(), + FiltFiltPadType::Const => { + if n < 1 { + return Ok(x.to_owned()); + } + + let ones: Array = Array::ones({ + let mut t = vec![1; ndim]; + t[axis] = n; + ndarray::IxDyn(&t) + }) + .into_dimensionality() // This is needed for IxDyn -> IxN + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Coercing into identical dimensionality had issue".into(), + })?; + + let left_ext = { + let left_end = + unsafe { axis_slice_unsafe(&x, Some(0), Some(1), None, axis, ndim) }?; + ones.clone() * left_end + }; + + let right_ext = { + let right_end = + unsafe { axis_slice_unsafe(&x, Some(-1), None, None, axis, ndim) }?; + ones * right_end + }; + + ndarray::concatenate(Axis(axis), &[left_ext.view(), x.view(), right_ext.view()]) + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + }) + } } } } @@ -145,4 +180,18 @@ mod test { .and(&expected) .for_each(|&r, &e| assert_eq!(r, e)); } + + /// Test const_ext as from documentation. + #[test] + fn const_ext_doc() { + let const_ext = FiltFiltPadType::Const; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = const_ext.ext(a, 2, None).expect("Could not get even_ext."); + let expected = array![[1, 1, 1, 2, 3, 4, 5, 5, 5], [0, 0, 0, 1, 4, 9, 16, 16, 16]]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + } } From a6800befcf3e4c39d4a76a3fe8cb31a2a015a470 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 12 Oct 2025 16:29:37 +0800 Subject: [PATCH 57/68] Refactor out no extension argument for n Short circuit to return an owned variant if n = 0 for `*_ext`. --- sci-rs/src/signal/filter/filtfilt.rs | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 922f936e..e6ecc742 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -48,6 +48,10 @@ impl FiltFiltPadType { D: Dimension + RemoveAxis, SliceInfo, D, D>: SliceArg, { + if n < 1 { + return Ok(x.to_owned()); + } + if D::NDIM.is_none() { return Err(Error::InvalidArg { arg: "x".into(), @@ -64,10 +68,6 @@ impl FiltFiltPadType { match self { FiltFiltPadType::Odd => { - if n < 1 { - return Ok(x.to_owned()); - } - let left_end = unsafe { axis_slice_unsafe(&x, Some(0), Some(1), None, axis, ndim) }?; let left_ext = unsafe { @@ -89,10 +89,6 @@ impl FiltFiltPadType { }) } FiltFiltPadType::Even => { - if n < 1 { - return Ok(x.to_owned()); - } - let left_ext = unsafe { axis_slice_unsafe(&x, Some(n as isize), Some(0), Some(-1), axis, ndim) }?; @@ -107,10 +103,6 @@ impl FiltFiltPadType { }) } FiltFiltPadType::Const => { - if n < 1 { - return Ok(x.to_owned()); - } - let ones: Array = Array::ones({ let mut t = vec![1; ndim]; t[axis] = n; From ef0ffa7c85d3482645a543ae2650704aa559321a Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 13 Oct 2025 15:57:07 +0800 Subject: [PATCH 58/68] Add IxDyn support for `*_ext` functions Also added the corresponding test cases. --- sci-rs/src/signal/filter/filtfilt.rs | 40 ++++++++++++++++++++-------- 1 file changed, 29 insertions(+), 11 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index e6ecc742..5c54160d 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -52,14 +52,7 @@ impl FiltFiltPadType { return Ok(x.to_owned()); } - if D::NDIM.is_none() { - return Err(Error::InvalidArg { - arg: "x".into(), - reason: "IxDyn is not supported".into(), - }); - } - - let ndim = D::NDIM.unwrap(); + let ndim = D::NDIM.unwrap_or(x.ndim()); let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg { arg: "axis".into(), @@ -148,7 +141,7 @@ mod test { let odd = FiltFiltPadType::Odd; let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; - let result = odd.ext(a, 2, None).expect("Could not get odd_ext."); + let result = odd.ext(a.view(), 2, None).expect("Could not get odd_ext."); let expected = array![ [-1, 0, 1, 2, 3, 4, 5, 6, 7], [-4, -1, 0, 1, 4, 9, 16, 23, 28] @@ -157,6 +150,13 @@ mod test { ndarray::Zip::from(&result) .and(&expected) .for_each(|&r, &e| assert_eq!(r, e)); + + let result = odd + .ext(a.into_dyn(), 2, None) + .expect("Could not get odd_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); } /// Test odd_ext as from documentation. @@ -165,12 +165,21 @@ mod test { let even = FiltFiltPadType::Even; let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; - let result = even.ext(a, 2, None).expect("Could not get even_ext."); + let result = even + .ext(a.view(), 2, None) + .expect("Could not get even_ext."); let expected = array![[3, 2, 1, 2, 3, 4, 5, 4, 3], [4, 1, 0, 1, 4, 9, 16, 9, 4]]; ndarray::Zip::from(&result) .and(&expected) .for_each(|&r, &e| assert_eq!(r, e)); + + let result = even + .ext(a.into_dyn(), 2, None) + .expect("Could not get even_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); } /// Test const_ext as from documentation. @@ -179,11 +188,20 @@ mod test { let const_ext = FiltFiltPadType::Const; let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; - let result = const_ext.ext(a, 2, None).expect("Could not get even_ext."); + let result = const_ext + .ext(a.view(), 2, None) + .expect("Could not get even_ext."); let expected = array![[1, 1, 1, 2, 3, 4, 5, 5, 5], [0, 0, 0, 1, 4, 9, 16, 16, 16]]; ndarray::Zip::from(&result) .and(&expected) .for_each(|&r, &e| assert_eq!(r, e)); + + let result = const_ext + .ext(a.into_dyn(), 2, None) + .expect("Could not get even_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); } } From 82c739104584dac51a0e4fd3a37297e89750b686 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 14 Oct 2025 15:56:13 +0800 Subject: [PATCH 59/68] Add error to `*_ext` if `n` is too large for specified axis Rather than panicking, an error can be returned from the caller. --- sci-rs/src/signal/filter/arraytools.rs | 2 ++ sci-rs/src/signal/filter/filtfilt.rs | 47 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 25f9d298..2d58ec33 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -127,6 +127,8 @@ where /// /// # Errors /// - Axis is out of bounds. +/// +/// # Panics /// - Start/stop elements are out of bounds. pub fn axis_slice( a: &ArrayBase, diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 5c54160d..a95d4dbf 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -59,6 +59,17 @@ impl FiltFiltPadType { reason: "index out of range.".into(), })?; + { + let axis_len = x.shape()[axis]; + if n >= axis_len { + return Err(Error::InvalidArg { + arg: "n".into(), + reason: "Extension of array cannot be longer than array in specified axis." + .into(), + }); + } + } + match self { FiltFiltPadType::Odd => { let left_end = @@ -159,6 +170,18 @@ mod test { .for_each(|&r, &e| assert_eq!(r, e)); } + /// Test odd_ext's limits. + #[test] + fn odd_ext_limits() { + let odd = FiltFiltPadType::Odd; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = odd.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = odd.ext(a, 4, None); + assert!(result.is_err()); + } + /// Test odd_ext as from documentation. #[test] fn even_ext_doc() { @@ -182,6 +205,18 @@ mod test { .for_each(|&r, &e| assert_eq!(r, e)); } + /// Test even_ext's limits. + #[test] + fn even_ext_limits() { + let even = FiltFiltPadType::Even; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = even.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = even.ext(a, 4, None); + assert!(result.is_err()); + } + /// Test const_ext as from documentation. #[test] fn const_ext_doc() { @@ -204,4 +239,16 @@ mod test { .and(&expected.into_dyn()) .for_each(|&r, &e| assert_eq!(r, e)); } + + /// Test const_ext's limits. + #[test] + fn const_ext_limits() { + let const_ext = FiltFiltPadType::Const; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = const_ext.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = const_ext.ext(a, 4, None); + assert!(result.is_err()); + } } From 3ac853067fe0bd21ee2d603d49c652ee67796a78 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 15 Oct 2025 11:21:58 +0800 Subject: [PATCH 60/68] Remove unnecessary lifetime in `*_ext`'s trait bounds The `'a` lifetime can be elided. --- sci-rs/src/signal/filter/filtfilt.rs | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index a95d4dbf..3f8310d4 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -36,14 +36,9 @@ impl FiltFiltPadType { /// even extension of `x` along the specified axis. /// * const: Constant extension at the boundaries of an array, generating a new ndarray by /// making an constant extension of `x` along the specified axis. - fn ext<'a, T, S, D>( - &'a self, - x: ArrayBase, - n: usize, - axis: Option, - ) -> Result> + fn ext(&self, x: ArrayBase, n: usize, axis: Option) -> Result> where - T: Clone + Add + Sub + num_traits::One + 'a, + T: Clone + Add + Sub + num_traits::One, S: Data, D: Dimension + RemoveAxis, SliceInfo, D, D>: SliceArg, From 2b335500c2446c2588b02ea86ae59be2840815b5 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 16 Oct 2025 15:03:24 +0800 Subject: [PATCH 61/68] Add validate_pad This is used in filtfilt. --- sci-rs/src/signal/filter/filtfilt.rs | 60 +++++++++++++++++++++++++++- 1 file changed, 58 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 3f8310d4..6c32ba93 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -2,8 +2,8 @@ use super::{axis_slice_unsafe, check_and_get_axis_dyn}; use alloc::{vec, vec::Vec}; use core::ops::{Add, Sub}; use ndarray::{ - Array, ArrayBase, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, Ix, RawData, RemoveAxis, - SliceArg, SliceInfo, SliceInfoElem, + Array, ArrayBase, ArrayView, ArrayView1, Axis, CowArray, Data, Dim, Dimension, Ix, RawData, + RemoveAxis, SliceArg, SliceInfo, SliceInfoElem, }; use sci_rs_core::{Error, Result}; @@ -135,6 +135,62 @@ impl FiltFiltPadType { } } +/// Arguments for [FiltFilt::filtfilt]. +#[derive(Copy, Clone, Default)] +pub struct FiltFiltPad { + /// Padding type. + pub pad_type: FiltFiltPadType, + /// Length of padding. + pub len: Option, +} + +/// Helper for validating padding of filtfilt. +/// +/// # Parameters +/// `pad`: `Option` +/// A none value from the user specifies no padding, which implies that the pad len is also 0. +/// Otherwise, the user specifies a specific padding and pad_len. +/// `x`: NDArray +/// Array that is being filtered. +/// `axis`: usize +/// Axis of `x` which is being filtered. +/// `ntaps`: usize +/// This simply is `max(a.len(), b.len())`. +/// +/// # Panics +/// `axis` is as acting on `x` is assumed to be valid, otherwise panics. +fn validate_pad( + pad: Option, + x: ArrayView, + axis: usize, + ntaps: usize, +) -> Result<(usize, CowArray)> +where + T: Clone + Add + Sub + num_traits::One, + D: Dimension + RemoveAxis, + SliceInfo, D, D>: SliceArg, +{ + let edge = match pad { + None => 0, + Some(FiltFiltPad { len, .. }) => len.unwrap_or(ntaps * 3), + }; + + { + x.shape().get(axis).ok_or(Error::InvalidArg { + arg: "axis".into(), + reason: "The length of the input vector x must be greater than padlen.".into(), + })?; + } + + let ext = if let Some(FiltFiltPad { pad_type, .. }) = pad { + CowArray::from(pad_type.ext(x, edge, Some(axis as _))?) + } else { + CowArray::from(x) + }; + + Ok((edge, ext)) +} + #[cfg(test)] mod test { use super::*; From 55cec7bb275969a6988b93d6380703589aab2d91 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Fri, 17 Oct 2025 17:56:49 +0800 Subject: [PATCH 62/68] Add tests for validate_pad Obtained from python. --- sci-rs/src/signal/filter/filtfilt.rs | 75 ++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 6c32ba93..989d9067 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -156,6 +156,7 @@ pub struct FiltFiltPad { /// Axis of `x` which is being filtered. /// `ntaps`: usize /// This simply is `max(a.len(), b.len())`. +// In an ideal world, this would be shoe-horned into FiltFiltPad's len parameter. /// /// # Panics /// `axis` is as acting on `x` is assumed to be valid, otherwise panics. @@ -302,4 +303,78 @@ mod test { let result = const_ext.ext(a, 4, None); assert!(result.is_err()); } + + /// Tests for when there is no padding. + #[test] + fn pad_none() { + let p = None; + let x = array![1]; + let result = validate_pad(p, x.view(), 0, 0).expect("Could not pad with none."); + + assert_eq!(result.0, 0); + assert_eq!(result.1, x); + } + + /// Tests for when there is even padding. + #[test] + fn pad_even() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Even, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with even."); + let expected_edge = 2; + let expected = array![ + [3, 2, 1, 2, 3, 4, 5, 6, 7, 8, 7, 6], + [4, 1, 0, 1, 4, 9, 16, 25, 36, 49, 36, 25] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|r, e| assert_eq!(r, e)); + } + + /// Tests for when there is even padding. + #[test] + fn pad_odd() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Odd, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with odd."); + let expected_edge = 2; + let expected = array![ + [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [-4, -1, 0, 1, 4, 9, 16, 25, 36, 49, 62, 73] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + } + + /// Tests for when there is const padding. + #[test] + fn pad_const() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Const, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with odd."); + let expected_edge = 2; + let expected = array![ + [1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8], + [0, 0, 0, 1, 4, 9, 16, 25, 36, 49, 49, 49] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + } } From 04b4feb2d1bf52e2022a70778a2debf22f14b7f2 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Sun, 19 Oct 2025 17:09:31 +0800 Subject: [PATCH 63/68] Add axis_reverse and axis_reverse_unsafe This are thin wrappers around axis_slice_unsafe with and without checks. --- sci-rs/src/signal/filter/arraytools.rs | 75 ++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 2d58ec33..22e5389b 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -243,6 +243,81 @@ where Ok(a.slice(&sl)) } +/// Reverse the 1-D slices (aka lanes) of `a` along axis `axis`. +/// +/// Returns axis_slice(a, step=-1, axis=axis). +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `axis`: `Option`. None defaults to -1. +pub fn axis_reverse( + a: &ArrayBase, + axis: Option, +) -> Result> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(a.ndim()); + + let axis = { + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate usize. + let axis: isize = axis.unwrap_or(-1); + if axis >= 0 { + axis.unsigned_abs() + } else { + a.ndim() + .checked_add_signed(axis) + .expect("Invalid add to `axis` option") + } + }; + + unsafe { axis_slice_unsafe(a, None, None, Some(-1), axis, ndim) } +} + +/// Reverse the 1-D slices (aka lanes) of `a` along axis `axis`. +/// +/// Returns axis_slice(a, step=-1, axis=axis). +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `axis`: `usize`. +/// * `a_ndim`: Dimensionality of `a`. This strictly has to be `a.ndim()`. +/// +/// # Panics +/// If axis is out of bounds, and dimensions are wrong. +#[inline] +pub(crate) unsafe fn axis_reverse_unsafe( + a: &ArrayBase, + axis: usize, + a_ndim: usize, +) -> ArrayView<'_, A, D> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + unsafe { + let r = axis_slice_unsafe(a, None, None, Some(-1), axis, a_ndim); + debug_assert!(r.is_ok()); + r.unwrap_unchecked() + } +} + #[cfg(test)] mod test { use super::*; From 5645f50b6d6d534a1453417604aa625f6c773813 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 20 Oct 2025 10:29:56 +0800 Subject: [PATCH 64/68] Fix axis_slice for larger indices Previously, it was assumed that the only valid values of start and end must be contained within `0 <= idx < x.shape()[axis]` after unwrapping the negative and counting from the back (starting with -1), hence the use of a max clamp. However, this was wrong, as can be validated against scipy internal. Thus, a case for this "more negative" than `axis_len` is spelt out, with the appropriate fix (removing clamp to 0) is done. --- sci-rs/src/signal/filter/arraytools.rs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs index 22e5389b..c86e71ec 100644 --- a/sci-rs/src/signal/filter/arraytools.rs +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -211,7 +211,7 @@ where let coerce = |idx: Option, def_pos: isize, def_neg: isize| -> isize { match idx { - Some(i) if i < 0 => (axis_len + i).max(0), + Some(i) if i.is_negative() => (axis_len + i), Some(i) => i.min(axis_len), None => { if !step.is_negative() { @@ -228,7 +228,7 @@ where if step.is_negative() { (end + 1, Some(start + 1)) } else { - (start, Some(end)) + (start, Some(end)) // No + 1 breaking into axis_len } }; @@ -352,6 +352,20 @@ mod test { ); } + /// Tests on IxN arrays with negative indices. + #[test] + fn axis_slice_neg_indices_weird() { + let a = array![1, 2, 3, 4]; + assert_eq!( + unsafe { axis_slice(&a, Some(-2), Some(-5), Some(-1), None) }.unwrap(), + array![3, 2, 1] + ); + assert_eq!( + unsafe { axis_slice_unsafe(&a, Some(-2), Some(-5), Some(-1), 0, a.ndim()) }.unwrap(), + array![3, 2, 1] + ); + } + /// Test on IxDyn Arrays. #[test] fn axis_slice_doc_dyn() { From 9a7c9e38498c64db8926a0780ca9b584a05e4e97 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:45:27 +0800 Subject: [PATCH 65/68] Scope FiltFilt into existence Also specify that filtfilt_gust is a separate method that is not yet implemented. As this strongly depends on `lfilter` as a function, with `lfilter` currently only supported for FIR, it does not yet work on IIR. --- sci-rs/src/signal/filter/filtfilt.rs | 83 ++++++++++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index 989d9067..e83b4d26 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -192,6 +192,89 @@ where Ok((edge, ext)) } +/// Implement filtfilt for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +/// +/// Note: FiltFilt gust is a separate function not yet implemented. +// Note: Usage of trait and macro for implementation is an inherited from LFilter. +// LFilter for supertrait? +pub trait FiltFilt +where + S: Data, +{ + /// Apply a digital filter forward and backward to a signal. + /// + /// This function applies a linear digital filter twice, once forward and + /// once backwards. The combined filter has zero phase and a filter order + /// twice that of the original. + /// + /// The function provides options for handling the edges of the signal. + /// + /// The function `sosfiltfilt` (and filter design using ``output='sos'``) + /// should be preferred over `filtfilt` for most filtering tasks, as + /// second-order sections have fewer numerical problems. + /// + /// # Parameters + /// * `b`: (N,) array_like + /// The numerator coefficient vector of the filter. + /// * `a`: (N,) array_like + /// The denominator coefficient vector of the filter. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized b:y ``a[0]``. + /// * `x`: array_like + /// The array of data to be filtered. + /// * `axis`: int, optional + /// The axis of `x` to which the filter is applied. + /// Default is -1. + /// * `pad` + /// [Option::None] here denotes a deliberate absence of padding. + /// * `padtype` [FiltFiltPadType] + /// Must be 'odd', 'even', 'constant', or None. + /// This determines the type of extension to use for the padded signal to which the filter is applied. + /// The default is 'odd'. + /// * `padlen` int or None, optional + /// The number of elements by which to extend `x` at both ends of `axis` before applying + /// the filter. + /// This value must be less than ``x.shape[axis] - 1``. ``padlen=0`` implies no padding. [Option::None] here denotes the default value. + /// The default value is ``3 * max(len(a), len(b))``. + /// + /// # Returns + /// * y : `Array` + /// The filtered output with the same shape as `x`. + /// + /// # See Also + /// sosfiltfilt, lfilter_zi, lfilter, lfiltic, savgol_filter, sosfilt, filtfilt_gust + /// + /// # Notes + /// When `method` is "pad", the function pads the data along the given axis in one of three + /// ways: odd, even or constant. The odd and even extensions have the corresponding symmetry + /// about the end point of the data. The constant extension extends the data with the values + /// at the end points. On both the forward and backward passes, the initial condition of the + /// filter is found by using `lfilter_zi` and scaling it by the end point of the extended data. + fn filtfilt<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + padding: Option, + ) -> Result>>; + + /// Forward-back IIR filter that uses Gustafsson's method. + /// + /// Not yet implemented. + fn filtfilt_gust<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + irlen: Option, + ) -> Result>> + where + Self: Sized, + { + todo!("Gust method of FiltFilt is not yet implemented."); + } +} #[cfg(test)] mod test { use super::*; From 448d3adc06971bfed4ab428334f580e80b8606aa Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Tue, 21 Oct 2025 17:09:09 +0800 Subject: [PATCH 66/68] Add 2-dimensional input tests for FIR filtfilt The values were generated with python. --- sci-rs/src/signal/filter/filtfilt.rs | 148 ++++++++++++++++++++++++++- 1 file changed, 147 insertions(+), 1 deletion(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index e83b4d26..e6b64562 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -279,7 +279,8 @@ where mod test { use super::*; use alloc::vec; - use ndarray::array; + use approx::assert_relative_eq; + use ndarray::{array, Zip}; /// Test odd_ext as from documentation. #[test] @@ -460,4 +461,149 @@ mod test { assert_eq!(result_edge, expected_edge); assert_eq!(result, expected); } + + /// Tests that filtfilt works with no padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_none_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt(b.view(), a.view(), x, Some(1), None) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 2.14, 2.84, 3.67, 4.43, 5.2, 6.06, 7.01, 8., 9., 10., 11., 12., 13., 14., 15., 16., + 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 33.9, 34.6, 34.9, 35., 35.4, 35.7, 35.8 + ], + [ + 5.56, 8.54, 13.05, 19.15, 26.78, 36.04, 47.11, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1093.06, 1138.68, 1157.46, 1162.72, 1189.36, 1209.74, 1216.6 + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + } + + /// Tests that filtfilt works with some padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_some_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let pad_arg = FiltFiltPad { + pad_type: FiltFiltPadType::default(), + len: Some(4), + }; + let result = + Array::<_, Dim<[_; 2]>>::filtfilt(b.view(), a.view(), x, Some(1), Some(pad_arg)) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 1.2, 2.06, 3.01, 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 37.9, 38.6, 38.9 + ], + [ + 1.94, 5.52, 11.07, 18.18, 26.44, 35.96, 47.1, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1100.1, 1166.96, 1235.44, 1305.18, 1368.54, 1418.2, 1439.28 + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + } + + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_default_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt( + b.view(), + a.view(), + x, + Some(1), + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., 37., 38., 39., 40. + ], + [ + 0., 4.96, 10.98, 18.18, 26.44, 35.96, 47.1, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1100.1, 1166.96, 1235.44, 1305.18, 1375.98, 1447.96, 1521. + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6, epsilon = 1e-10)); + } + + /// Tests that is an error if the specified padding is a lot longer than the array. + #[test] + fn filfilt_2d_fir_limit() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 4; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt( + b.view(), + a.view(), + x, + Some(1), + Some(FiltFiltPad::default()), + ); + + assert!(result.is_err()); + } } From bd5506bf6faac25962261f68886f7ec835c1326e Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Wed, 22 Oct 2025 14:42:02 +0800 Subject: [PATCH 67/68] Add filtfilt implementation specifically for 2-dimensional input This is still valid only for FIR as lfilter is designed specifically for FIR windows so far. --- sci-rs/src/signal/filter/filtfilt.rs | 138 ++++++++++++++++++++++++++- 1 file changed, 136 insertions(+), 2 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index e6b64562..b0741113 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -1,4 +1,8 @@ -use super::{axis_slice_unsafe, check_and_get_axis_dyn}; +use super::arraytools::{ + axis_reverse_unsafe, axis_slice_unsafe, check_and_get_axis_dyn, ndarray_shape_as_array_st, +}; +use super::lfilter::LFilter; +use super::lfilter_zi::lfilter_zi_dyn; use alloc::{vec, vec::Vec}; use core::ops::{Add, Sub}; use ndarray::{ @@ -242,6 +246,41 @@ where /// * y : `Array` /// The filtered output with the same shape as `x`. /// + /// # Example + /// The following examples shows how to use an arbitrary FIR filter on a 2-dimensional input + /// `x`. + /// ``` + /// use sci_rs::signal::filter::{FiltFilt, FiltFiltPad}; + /// use ndarray::{array, Array2, ArrayView2}; + /// let x = array![ + /// [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + /// [0., 1., 4., 9., 16., 25., 36., 49., 64., 81.] + /// ]; + /// let b = array![0.5, 0.4, 0.1]; + /// let a = array![1.]; + /// let _ = ArrayView2::filtfilt( + /// b.view(), + /// a.view(), + /// x.view(), // Pass x by reference + /// Some(1), + /// Some(FiltFiltPad::default())).unwrap(); + /// let result = Array2::filtfilt( + /// b.view(), + /// a.view(), + /// x, // Pass x by value + /// Some(1), + /// Some(FiltFiltPad::default())).unwrap(); + /// + /// use approx::assert_relative_eq; + /// use ndarray::Zip; + /// let expected = array![ + /// [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + /// [0., 1.78, 4.88, 9.88, 16.88, 25.88, 36.88, 49.88, 64.78, 81.] + /// ]; + /// Zip::from(&result).and(&expected) + /// .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + /// ``` + /// /// # See Also /// sosfiltfilt, lfilter_zi, lfilter, lfiltic, savgol_filter, sosfilt, filtfilt_gust /// @@ -257,7 +296,11 @@ where x: Self, axis: Option, padding: Option, - ) -> Result>>; + ) -> Result>> + where + T: Clone + Add + Sub + num_traits::One, + Dim<[Ix; N]>: Dimension, + T: nalgebra::RealField + Copy + core::iter::Sum; // From lfilter_zi_dyn /// Forward-back IIR filter that uses Gustafsson's method. /// @@ -275,6 +318,97 @@ where todo!("Gust method of FiltFilt is not yet implemented."); } } + +impl FiltFilt for ArrayBase> +where + S: Data, +{ + fn filtfilt<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + padding: Option, + ) -> Result>> + where + T: Clone + Add + Sub + num_traits::One, + Dim<[Ix; 2]>: Dimension, + T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn + { + let axis = { + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= 2 + } else { + axis.unsigned_abs() < 2 + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate usize. + let axis: isize = axis.unwrap_or(-1); + if axis >= 0 { + axis.unsigned_abs() + } else { + a.ndim() + .checked_add_signed(axis) + .expect("Invalid add to `axis` option") + } + }; + let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?; + + let zi: Array> = { + let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap()); + let mut sh = [1; 2]; + sh[axis] = zi.len(); // .size()? + + zi.into_shape_with_order(sh) + .map_err(|_| Error::InvalidArg { + arg: "b/a".into(), + reason: "Generated lfilter_zi from given b or a resulted in an error.".into(), + })? + }; + let (y, _) = { + let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?; + let zi_arg = zi.clone() * x0; // Is it possible to not need to clone? + ArrayBase::<_, Dim<[Ix; 2]>>::lfilter( + b.view(), + a.view(), + ext, + Some(axis as _), + Some(zi_arg.view()), + )? + }; + + let (y, _) = { + let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?; + let zi_arg = zi * y0; // originally zi * y0 + ArrayView::>::lfilter( + b.view(), + a.view(), + unsafe { axis_reverse_unsafe(&y, axis, 2) }, + Some(axis as _), + Some(zi_arg.view()), + )? + }; + + let y = unsafe { axis_reverse_unsafe(&y, axis, 2) }; + + if edge > 0 { + let y = unsafe { + axis_slice_unsafe(&y, Some(edge as _), Some(-(edge as isize)), None, axis, 2) + }?; + Ok(y.to_owned()) + } else { + Ok(y.to_owned()) + } + } +} + #[cfg(test)] mod test { use super::*; From 53b7e8ed638dacb6248aab81d2142ab85e5a7121 Mon Sep 17 00:00:00 2001 From: SpookyYomo <48710653+SpookyYomo@users.noreply.github.com> Date: Thu, 23 Oct 2025 16:29:29 +0800 Subject: [PATCH 68/68] Change Filtfilt to use macros for 1 to 6 dimensions Similar to lfilter, using macros to help generalize to multiple dimensions. Correspondingly add a test case for 1-dim input. --- sci-rs/src/signal/filter/filtfilt.rs | 226 +++++++++++++++++---------- 1 file changed, 143 insertions(+), 83 deletions(-) diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs index b0741113..fc41d9be 100644 --- a/sci-rs/src/signal/filter/filtfilt.rs +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -319,96 +319,95 @@ where } } -impl FiltFilt for ArrayBase> -where - S: Data, -{ - fn filtfilt<'a>( - b: ArrayView1<'a, T>, - a: ArrayView1<'a, T>, - x: Self, - axis: Option, - padding: Option, - ) -> Result>> - where - T: Clone + Add + Sub + num_traits::One, - Dim<[Ix; 2]>: Dimension, - T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn - { - let axis = { - if axis.is_some_and(|axis| { - !(if axis < 0 { - axis.unsigned_abs() <= 2 - } else { - axis.unsigned_abs() < 2 - }) - }) { - return Err(Error::InvalidArg { +macro_rules! filtfilt_for_dim { + ($N: literal) => { + impl FiltFilt for ArrayBase> + where + S: Data, + { + fn filtfilt<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + padding: Option, + ) -> Result>> + where + T: Clone + Add + Sub + num_traits::One, + Dim<[Ix; $N]>: Dimension, + T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn + { + let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg { arg: "axis".into(), reason: "index out of range.".into(), - }); - } - - // We make a best effort to convert into appropriate usize. - let axis: isize = axis.unwrap_or(-1); - if axis >= 0 { - axis.unsigned_abs() - } else { - a.ndim() - .checked_add_signed(axis) - .expect("Invalid add to `axis` option") - } - }; - let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?; - - let zi: Array> = { - let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap()); - let mut sh = [1; 2]; - sh[axis] = zi.len(); // .size()? - - zi.into_shape_with_order(sh) - .map_err(|_| Error::InvalidArg { - arg: "b/a".into(), - reason: "Generated lfilter_zi from given b or a resulted in an error.".into(), - })? - }; - let (y, _) = { - let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?; - let zi_arg = zi.clone() * x0; // Is it possible to not need to clone? - ArrayBase::<_, Dim<[Ix; 2]>>::lfilter( - b.view(), - a.view(), - ext, - Some(axis as _), - Some(zi_arg.view()), - )? - }; - - let (y, _) = { - let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?; - let zi_arg = zi * y0; // originally zi * y0 - ArrayView::>::lfilter( - b.view(), - a.view(), - unsafe { axis_reverse_unsafe(&y, axis, 2) }, - Some(axis as _), - Some(zi_arg.view()), - )? - }; + })?; + let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?; + + let zi: Array> = { + let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap()); + let mut sh = [1; $N]; + sh[axis] = zi.len(); // .size()? + + zi.into_shape_with_order(sh) + .map_err(|_| Error::InvalidArg { + arg: "b/a".into(), + reason: "Generated lfilter_zi from given b or a resulted in an error." + .into(), + })? + }; + let (y, _) = { + let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?; + let zi_arg = zi.clone() * x0; // Is it possible to not need to clone? + ArrayBase::<_, Dim<[Ix; $N]>>::lfilter( + b.view(), + a.view(), + ext, + Some(axis as _), + Some(zi_arg.view()), + )? + }; - let y = unsafe { axis_reverse_unsafe(&y, axis, 2) }; + let (y, _) = { + let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?; + let zi_arg = zi * y0; // originally zi * y0 + ArrayView::>::lfilter( + b.view(), + a.view(), + unsafe { axis_reverse_unsafe(&y, axis, $N) }, + Some(axis as _), + Some(zi_arg.view()), + )? + }; - if edge > 0 { - let y = unsafe { - axis_slice_unsafe(&y, Some(edge as _), Some(-(edge as isize)), None, axis, 2) - }?; - Ok(y.to_owned()) - } else { - Ok(y.to_owned()) + let y = unsafe { axis_reverse_unsafe(&y, axis, $N) }; + + if edge > 0 { + let y = unsafe { + axis_slice_unsafe( + &y, + Some(edge as _), + Some(-(edge as isize)), + None, + axis, + $N, + ) + }?; + Ok(y.to_owned()) + } else { + Ok(y.to_owned()) + } + } } - } + }; } +filtfilt_for_dim!(1); +filtfilt_for_dim!(2); +filtfilt_for_dim!(3); +filtfilt_for_dim!(4); +filtfilt_for_dim!(5); +filtfilt_for_dim!(6); + #[cfg(test)] mod test { use super::*; @@ -596,6 +595,67 @@ mod test { assert_eq!(result, expected); } + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_1d_fir_default_pad_small() { + let x = + array![0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051]; + let b = array![0.5, 0.5]; + let a = array![1.]; + let result = Array::<_, Dim<[_; 1]>>::filtfilt( + b.view(), + a.view(), + x, + None, + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = + array![0., 0.5421249, 0.8507858, 0.9639715, 0.9893054, 0.9702733, 0.9275059, 0.8734051]; + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6, epsilon = 1e-10)); + } + + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_1d_fir_default_pad_big() { + // n_elems = 25 + // x = np.sin(np.log(np.linspace(1., n_elems, n_elems))) + // b = firwin(8, 0.2) + // a = np.array([1.]) + // expected = filtfilt(b, a, x) + + let x = array![ + 0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051, + 0.8101266, 0.7439803, 0.6770137, 0.6104955, 0.5452131, 0.481649, 0.4200881, 0.3606866, + 0.3035148, 0.2485867, 0.1958789, 0.1453437, 0.0969178, 0.0505287, 0.0060984, + -0.0364531, -0.0772063 + ]; + let b = array![ + 0.0087547, 0.0479489, 0.1640244, 0.279272, 0.279272, 0.1640244, 0.0479489, 0.0087547 + ]; + let a = array![1.]; + let result = Array::<_, Dim<[_; 1]>>::filtfilt( + b.view(), + a.view(), + x, + None, + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = array![ + 0., 0.3503788, 0.6340265, 0.8172474, 0.9055143, 0.9253101, 0.9036955, 0.8594274, + 0.8033733, 0.7414859, 0.6771011, 0.6121664, 0.5478511, 0.4848631, 0.4236259, 0.3643826, + 0.3072603, 0.25231, 0.1995331, 0.1488972, 0.1003401, 0.0537529, 0.0089268, -0.0345238, + -0.0772063 + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-5, epsilon = 1e-10)); + } + /// Tests that filtfilt works with no padding with a FIR filter. #[test] fn filtfilt_2d_fir_none_pad() {