diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml new file mode 100644 index 00000000..26fe7017 --- /dev/null +++ b/sci-rs-core/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "sci-rs-core" +version = "0.0.0" +edition = "2021" +authors = ["Jacob Trueb "] +description = "Core library for sci-rs internals." +license = "MIT OR Apache-2.0" +repository = "https://github.com/qsib-cbie/sci-rs.git" +homepage = "https://github.com/qsib-cbie/sci-rs.git" +readme = "../README.md" +keywords = ["scipy", "dsp", "signal", "filter", "design"] +categories = ["science", "mathematics", "no-std", "embedded"] + + +[package.metadata.docs.rs] +all-features = true + +[features] +default = ['alloc'] + +# Allow allocating vecs, matrices, etc. +alloc = [] + +# Enable FFT and standard library features +std = ['alloc'] + +[dependencies] +ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.5.0" } +num-traits = { version = "0.2.15", default-features = false } diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs new file mode 100644 index 00000000..f52e09a4 --- /dev/null +++ b/sci-rs-core/src/lib.rs @@ -0,0 +1,79 @@ +//! Core library for sci-rs. + +#![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "alloc")] +use alloc::format; + +use core::{error, fmt}; + +pub type Result = core::result::Result; + +/// Errors raised whilst running sci-rs. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Argument parsed into function were invalid. + #[cfg(feature = "alloc")] + InvalidArg { + /// The invalid arg + arg: alloc::string::String, + /// Explaining why arg is invalid. + reason: alloc::string::String, + }, + /// Argument parsed into function were invalid. + #[cfg(not(feature = "alloc"))] + InvalidArg, + /// Two or more optional arguments passed into functions conflict. + #[cfg(feature = "alloc")] + ConflictArg { + /// Explaining what arg is invalid. + reason: alloc::string::String, + }, + /// Two or more optional arguments passed into functions conflict. + #[cfg(not(feature = "alloc"))] + ConflictArg, + /// Errors raised by [ndarray_conv::Error] + #[cfg(feature = "alloc")] + Conv { reason: alloc::string::String }, + /// Errors raised by [ndarray_conv::Error] + #[cfg(not(feature = "alloc"))] + Conv, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "alloc")] + Error::InvalidArg { arg, reason } => + format!("Invalid Argument on arg = {} with reason = {}", arg, reason), + #[cfg(not(feature = "alloc"))] + Error::InvalidArg => + "There were invalid arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::ConflictArg { reason } => + format!("Conflicting Arguments with reason = {}", reason), + #[cfg(not(feature = "alloc"))] + Error::ConflictArg => + "There were conflicting arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::Conv { reason } => format!( + "An error occurred during the convolution from ndarray_conv with reason {}.", + reason + ), + #[cfg(not(feature = "alloc"))] + Error::Conv => "An error occurred during the convolution from ndarray_conv. Reasons not shown without `alloc` feature.", + } + ) + } +} + +impl error::Error for Error {} + +/// Collection of numpy-like functions for use by sci-rs. +/// Provide behaviour parity against Numpy, even if the types are not identical. +pub mod num_rs; diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs new file mode 100644 index 00000000..2d80c7f7 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -0,0 +1,135 @@ +mod ndarray_conv_binds; + +use crate::{Error, Result}; +use alloc::string::ToString; +use ndarray::{Array1, ArrayView1}; +use ndarray_conv::{ConvExt, PaddingMode}; + +/// Convolution mode determines behavior near edges and output size +pub enum ConvolveMode { + /// Full convolution, output size is `in1.len() + in2.len() - 1` + Full, + /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` + Valid, + /// Same convolution, output size is `in1.len()` + Same, +} + +/// Best effort parallel behaviour with numpy's convolve method. We take `v` as the convolution +/// kernel. +/// +/// Returns the discrete, linear convolution of two one-dimensional sequences. +/// +/// # Parameters +/// * `a` : (N,) [[array_like]]([ndarray::Array1]) +/// Signal to be (linearly) convolved. +/// * `v` : (M,) [[array_like]]([ndarray::Array1]) +/// Second one-dimensional input array. +/// * `mode` : [ConvolveMode] +/// [ConvolveMode::Full]: +/// By default, mode is 'full'. This returns the convolution at each point of overlap, with an +/// output shape of (N+M-1,). At the end-points of the convolution, the signals do not overlap +/// completely, and boundary effects may be seen. +/// +/// [ConvolveMode::Same]: +/// Mode 'same' returns output of length ``max(M, N)``. Boundary effects are still visible. +/// +/// [ConvolveMode::Valid]: +/// Mode 'valid' returns output of length ``max(M, N) - min(M, N) + 1``. The convolution +/// product is only given for points where the signals overlap completely. Values outside the +/// signal boundary have no effect. +/// +/// # Panics +/// We assume that `v` is shorter than `a`. +/// +/// # Examples +/// With [ConvolveMode::Full]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![0., 1., 2.5, 4., 1.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![1., 2.5, 4.]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![2.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn convolve(a: ArrayView1, v: ArrayView1, mode: ConvolveMode) -> Result> +where + T: num_traits::NumAssign + core::marker::Copy, +{ + // Convolve + let result = a.conv(&v, mode.into(), PaddingMode::Zeros); + #[cfg(feature = "alloc")] + { + result.map_err(|e| Error::Conv { + reason: e.to_string(), + }) + } + #[cfg(not(feature = "alloc"))] + { + result.map_err({ Error::Conv }) + } +} + +#[cfg(test)] +mod linear_convolve { + use super::*; + use alloc::vec; + use ndarray::array; + + #[test] + fn full() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![0., 1., 2.5, 4., 1.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn same() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![1., 2.5, 4.]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn valid() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![2.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); + assert_eq!(result, expected); + } +} diff --git a/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs new file mode 100644 index 00000000..b0f38c79 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs @@ -0,0 +1,12 @@ +use super::ConvolveMode; +use ndarray_conv::ConvMode; + +impl From for ConvMode { + fn from(value: ConvolveMode) -> Self { + match value { + ConvolveMode::Full => ConvMode::Full, + ConvolveMode::Same => ConvMode::Same, + ConvolveMode::Valid => ConvMode::Valid, + } + } +} diff --git a/sci-rs-core/src/num_rs/mod.rs b/sci-rs-core/src/num_rs/mod.rs new file mode 100644 index 00000000..e6e9679c --- /dev/null +++ b/sci-rs-core/src/num_rs/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "alloc")] +mod convolve; +#[cfg(feature = "alloc")] +pub use convolve::*; diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 697e7933..99647819 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -19,10 +19,10 @@ all-features = true default = ['alloc'] # Allow allocating vecs, matrices, etc. -alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc'] +alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc', 'sci-rs-core/alloc'] # Enable FFT and standard library features -std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc'] +std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc','sci-rs-core/std'] # Enable debug plotting through python system calls plot = ['std'] @@ -36,6 +36,7 @@ lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } +sci-rs-core = { path = "../sci-rs-core", default-features = false } [dev-dependencies] approx = "0.5.1" diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs index b527805e..7e0486b5 100644 --- a/sci-rs/src/signal/convolve.rs +++ b/sci-rs/src/signal/convolve.rs @@ -2,15 +2,7 @@ use nalgebra::Complex; use num_traits::{Float, FromPrimitive, Signed, Zero}; use rustfft::{FftNum, FftPlanner}; -/// Convolution mode determines behavior near edges and output size -pub enum ConvolveMode { - /// Full convolution, output size is `in1.len() + in2.len() - 1` - Full, - /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` - Valid, - /// Same convolution, output size is `in1.len()` - Same, -} +pub use sci_rs_core::num_rs::ConvolveMode; /// Performs FFT-based convolution on two slices of floating point values. /// diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs new file mode 100644 index 00000000..c86e71ec --- /dev/null +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -0,0 +1,392 @@ +//! Functions for acting on a axis of an array. +//! +//! Designed for ndarrays; with scipy's internal nomenclature. + +use alloc::{vec, vec::Vec}; +use ndarray::{ + ArrayBase, ArrayView, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis, SliceArg, + SliceInfo, SliceInfoElem, +}; +use sci_rs_core::{Error, Result}; + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +/// +/// # Notes +/// Const nature of this function means error has to be manually created. +#[inline] +pub(crate) const fn check_and_get_axis_st<'a, T, S, const N: usize>( + axis: Option, + x: &ArrayBase>, +) -> core::result::Result +where + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + match axis { + None => (), + Some(axis) if axis.is_negative() => { + if axis.unsigned_abs() > N { + return Err(()); + } + } + Some(axis) => { + if axis.unsigned_abs() >= N { + return Err(()); + } + } + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = match axis { + Some(axis) => axis, + None => -1, + }; + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = N + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +#[inline] +pub(crate) fn check_and_get_axis_dyn<'a, T, S, D>( + axis: Option, + x: &ArrayBase, +) -> Result +where + D: Dimension, + S: Data + 'a, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = ndim + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for obtaining length of all axis as array from input from input. +/// +/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a slice `&[T]`. +/// +/// # Parameters +/// `a`: Array whose shape is needed as a slice. +pub(crate) fn ndarray_shape_as_array_st<'a, S, T, const N: usize>( + a: &ArrayBase>, +) -> [Ix; N] +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + S: Data + 'a, +{ + a.shape().try_into().expect("Could not cast shape to array") +} + +/// Takes a slice along `axis` from `a`. +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `start`: `Option`. None defaults to 0. +/// * `end`: `Option`. +/// * `step`: `Option`. None default to 1. +/// * `axis`: `Option`. None defaults to -1. +/// +/// # Errors +/// - Axis is out of bounds. +/// +/// # Panics +/// - Start/stop elements are out of bounds. +pub fn axis_slice( + a: &ArrayBase, + start: Option, + end: Option, + step: Option, + axis: Option, +) -> Result> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(a.ndim()); + + let axis = { + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate usize. + let axis: isize = axis.unwrap_or(-1); + if axis >= 0 { + axis.unsigned_abs() + } else { + a.ndim() + .checked_add_signed(axis) + .expect("Invalid add to `axis` option") + } + }; + + unsafe { axis_slice_unsafe(a, start, end, step, axis, ndim) } +} + +/// Takes a slice along `axis` from `a`. +/// +/// Assumes that the specified axis is within bounds. +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `start`: `Option`. None defaults to 0. +/// * `end`: `Option`. +/// * `step`: `Option`. None default to 1. +/// * `axis`: `usize`. +/// * `a_ndim`: Dimensionality of `a`. This strictly has to be `a.ndim()`. +/// +/// # Panics +/// - Axis is out of bounds. +/// - Start/stop elements are out of bounds. +pub(crate) fn axis_slice_unsafe( + a: &ArrayBase, + start: Option, + end: Option, + step: Option, + axis: usize, + a_ndim: usize, +) -> Result> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + debug_assert!(if a_ndim != 0 { + axis < a_ndim // Eg: A 1D-array should only have axis = 0 + } else { + axis <= a_ndim // Allow for axis = 0 when ndim = 0. + }); + + let axis_len = a.shape()[axis] as isize; + let step = step.unwrap_or(1); + + let coerce = |idx: Option, def_pos: isize, def_neg: isize| -> isize { + match idx { + Some(i) if i.is_negative() => (axis_len + i), + Some(i) => i.min(axis_len), + None => { + if !step.is_negative() { + def_pos + } else { + def_neg + } + } + } + }; + let (start, end) = { + let mut start = coerce(start, 0, axis_len - 1); + let mut end = coerce(end, axis_len, -1); + if step.is_negative() { + (end + 1, Some(start + 1)) + } else { + (start, Some(end)) // No + 1 breaking into axis_len + } + }; + + let sl = SliceInfo::<_, D, D>::try_from({ + let mut tmp = vec![SliceInfoElem::from(..); a_ndim]; + tmp[axis] = SliceInfoElem::Slice { start, end, step }; + + tmp + }) + .unwrap(); + + Ok(a.slice(&sl)) +} + +/// Reverse the 1-D slices (aka lanes) of `a` along axis `axis`. +/// +/// Returns axis_slice(a, step=-1, axis=axis). +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `axis`: `Option`. None defaults to -1. +pub fn axis_reverse( + a: &ArrayBase, + axis: Option, +) -> Result> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(a.ndim()); + + let axis = { + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate usize. + let axis: isize = axis.unwrap_or(-1); + if axis >= 0 { + axis.unsigned_abs() + } else { + a.ndim() + .checked_add_signed(axis) + .expect("Invalid add to `axis` option") + } + }; + + unsafe { axis_slice_unsafe(a, None, None, Some(-1), axis, ndim) } +} + +/// Reverse the 1-D slices (aka lanes) of `a` along axis `axis`. +/// +/// Returns axis_slice(a, step=-1, axis=axis). +/// +/// # Parameters +/// * `a`: Array being sliced from. +/// * `axis`: `usize`. +/// * `a_ndim`: Dimensionality of `a`. This strictly has to be `a.ndim()`. +/// +/// # Panics +/// If axis is out of bounds, and dimensions are wrong. +#[inline] +pub(crate) unsafe fn axis_reverse_unsafe( + a: &ArrayBase, + axis: usize, + a_ndim: usize, +) -> ArrayView<'_, A, D> +where + S: Data, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + unsafe { + let r = axis_slice_unsafe(a, None, None, Some(-1), axis, a_ndim); + debug_assert!(r.is_ok()); + r.unwrap_unchecked() + } +} + +#[cfg(test)] +mod test { + use super::*; + use ndarray::{array, Array, ArrayD, IxDyn}; + + /// Tests on IxN arrays. + #[test] + fn axis_slice_doc() { + let a = array![[1, 2, 3], [4, 5, 6], [7, 8, 9]]; + + assert_eq!( + axis_slice(&a, Some(0), Some(1), Some(1), Some(1)).unwrap(), + array![[1], [4], [7]] + ); + assert_eq!( + axis_slice(&a, Some(0), Some(2), Some(1), Some(0)).unwrap(), + array![[1, 2, 3], [4, 5, 6]] + ); + } + + /// Tests on IxN arrays with negative step. + #[test] + fn axis_slice_neg_step() { + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + assert_eq!( + axis_slice(&a, Some(2), Some(0), Some(-1), None).unwrap(), + array![[3, 2], [4, 1]] + ); + assert_eq!( + axis_slice(&a, Some(-2), Some(-4), Some(-1), None).unwrap(), + array![[4, 3], [9, 4]] + ); + } + + /// Tests on IxN arrays with negative indices. + #[test] + fn axis_slice_neg_indices_weird() { + let a = array![1, 2, 3, 4]; + assert_eq!( + unsafe { axis_slice(&a, Some(-2), Some(-5), Some(-1), None) }.unwrap(), + array![3, 2, 1] + ); + assert_eq!( + unsafe { axis_slice_unsafe(&a, Some(-2), Some(-5), Some(-1), 0, a.ndim()) }.unwrap(), + array![3, 2, 1] + ); + } + + /// Test on IxDyn Arrays. + #[test] + fn axis_slice_doc_dyn() { + let a = { + let mut y: Array<_, IxDyn> = ArrayD::::zeros(IxDyn(&[2, 3])); + y[[0, 0]] = 5; + y[[0, 1]] = 6; + y[[0, 2]] = 7; + y[[1, 0]] = 1; + y[[1, 1]] = 2; + y[[1, 2]] = 3; + + y + }; + + assert_eq!( + axis_slice(&a, Some(0), Some(1), Some(1), Some(1)) + .unwrap() + .into_dimensionality() + .unwrap(), + array![[5], [1]] + ); + } +} diff --git a/sci-rs/src/signal/filter/ext.rs b/sci-rs/src/signal/filter/ext.rs index 26654625..b9b16a13 100644 --- a/sci-rs/src/signal/filter/ext.rs +++ b/sci-rs/src/signal/filter/ext.rs @@ -3,9 +3,10 @@ use core::ops::Sub; use nalgebra::{allocator::Allocator, DefaultAllocator, Dim, Dyn, OMatrix, Scalar}; use num_traits::{One, Zero}; -/// /// Pad types. /// +/// Used by [super::sosfiltfilt]. +/// This differs from [super::FiltFiltPadType] which has different semantics. pub enum Pad { /// No padding. None, @@ -20,9 +21,7 @@ pub enum Pad { Constant, } -/// -/// Pad an array. -/// +/// Pad an [nalgebra] array. pub fn pad( padtype: Pad, mut padlen: Option, @@ -89,9 +88,9 @@ where } } +/// Pad an [nalgebra] array with odd extension. /// -/// Pad an array with odd extension. -/// +// This differs from [super::FiltFiltPadType]'s ext that acts on [ndarray]. pub fn odd_ext_dyn(x: OMatrix, n: usize, axis: usize) -> OMatrix where T: Scalar + Copy + Zero + One + Sub, diff --git a/sci-rs/src/signal/filter/filtfilt.rs b/sci-rs/src/signal/filter/filtfilt.rs new file mode 100644 index 00000000..fc41d9be --- /dev/null +++ b/sci-rs/src/signal/filter/filtfilt.rs @@ -0,0 +1,803 @@ +use super::arraytools::{ + axis_reverse_unsafe, axis_slice_unsafe, check_and_get_axis_dyn, ndarray_shape_as_array_st, +}; +use super::lfilter::LFilter; +use super::lfilter_zi::lfilter_zi_dyn; +use alloc::{vec, vec::Vec}; +use core::ops::{Add, Sub}; +use ndarray::{ + Array, ArrayBase, ArrayView, ArrayView1, Axis, CowArray, Data, Dim, Dimension, Ix, RawData, + RemoveAxis, SliceArg, SliceInfo, SliceInfoElem, +}; +use sci_rs_core::{Error, Result}; + +/// Padding utilised in [FiltFilt::filtfilt]. +// WARN: Related/Duplicate: [super::Pad]. +#[derive(Copy, Clone, Default)] +pub enum FiltFiltPadType { + /// Odd extensions + #[default] + Odd, + /// Even extensions + Even, + /// Constant extensions + Const, +} + +impl FiltFiltPadType { + /// Extensions on ndarrays. + /// + /// # Parameters + /// `self`: Type of extension. + /// `x`: Array to extend on. + /// `n`: The number of elements by which to extend `x` at each end of the axis. + /// `axis`: The axis along which to extend `x`. + /// + /// ## Type of extension + /// * odd: Odd extension at the boundaries of an array, generating a new ndarray by making an + /// odd extension of `x` along the specified axis. + /// * even: Even extension at the boundaries of an array, generating a new ndarray by making an + /// even extension of `x` along the specified axis. + /// * const: Constant extension at the boundaries of an array, generating a new ndarray by + /// making an constant extension of `x` along the specified axis. + fn ext(&self, x: ArrayBase, n: usize, axis: Option) -> Result> + where + T: Clone + Add + Sub + num_traits::One, + S: Data, + D: Dimension + RemoveAxis, + SliceInfo, D, D>: SliceArg, + { + if n < 1 { + return Ok(x.to_owned()); + } + + let ndim = D::NDIM.unwrap_or(x.ndim()); + + let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + + { + let axis_len = x.shape()[axis]; + if n >= axis_len { + return Err(Error::InvalidArg { + arg: "n".into(), + reason: "Extension of array cannot be longer than array in specified axis." + .into(), + }); + } + } + + match self { + FiltFiltPadType::Odd => { + let left_end = + unsafe { axis_slice_unsafe(&x, Some(0), Some(1), None, axis, ndim) }?; + let left_ext = unsafe { + axis_slice_unsafe(&x, Some(n as isize), Some(0), Some(-1), axis, ndim) + }?; + let right_end = unsafe { axis_slice_unsafe(&x, Some(-1), None, None, axis, ndim) }?; + let right_ext = unsafe { + axis_slice_unsafe(&x, Some(-2), Some(-2 - (n as isize)), Some(-1), axis, ndim) + }?; + + let ll = left_end.to_owned().add(left_end).sub(left_ext); + let rr = right_end.to_owned().add(right_end).sub(right_ext); + + ndarray::concatenate(Axis(axis), &[ll.view(), x.view(), rr.view()]).map_err(|_| { + Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + } + }) + } + FiltFiltPadType::Even => { + let left_ext = unsafe { + axis_slice_unsafe(&x, Some(n as isize), Some(0), Some(-1), axis, ndim) + }?; + let right_ext = unsafe { + axis_slice_unsafe(&x, Some(-2), Some(-2 - (n as isize)), Some(-1), axis, ndim) + }?; + + ndarray::concatenate(Axis(axis), &[left_ext.view(), x.view(), right_ext.view()]) + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + }) + } + FiltFiltPadType::Const => { + let ones: Array = Array::ones({ + let mut t = vec![1; ndim]; + t[axis] = n; + ndarray::IxDyn(&t) + }) + .into_dimensionality() // This is needed for IxDyn -> IxN + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Coercing into identical dimensionality had issue".into(), + })?; + + let left_ext = { + let left_end = + unsafe { axis_slice_unsafe(&x, Some(0), Some(1), None, axis, ndim) }?; + ones.clone() * left_end + }; + + let right_ext = { + let right_end = + unsafe { axis_slice_unsafe(&x, Some(-1), None, None, axis, ndim) }?; + ones * right_end + }; + + ndarray::concatenate(Axis(axis), &[left_ext.view(), x.view(), right_ext.view()]) + .map_err(|_| Error::InvalidArg { + arg: "x".into(), + reason: "Shape Error".into(), + }) + } + } + } +} + +/// Arguments for [FiltFilt::filtfilt]. +#[derive(Copy, Clone, Default)] +pub struct FiltFiltPad { + /// Padding type. + pub pad_type: FiltFiltPadType, + /// Length of padding. + pub len: Option, +} + +/// Helper for validating padding of filtfilt. +/// +/// # Parameters +/// `pad`: `Option` +/// A none value from the user specifies no padding, which implies that the pad len is also 0. +/// Otherwise, the user specifies a specific padding and pad_len. +/// `x`: NDArray +/// Array that is being filtered. +/// `axis`: usize +/// Axis of `x` which is being filtered. +/// `ntaps`: usize +/// This simply is `max(a.len(), b.len())`. +// In an ideal world, this would be shoe-horned into FiltFiltPad's len parameter. +/// +/// # Panics +/// `axis` is as acting on `x` is assumed to be valid, otherwise panics. +fn validate_pad( + pad: Option, + x: ArrayView, + axis: usize, + ntaps: usize, +) -> Result<(usize, CowArray)> +where + T: Clone + Add + Sub + num_traits::One, + D: Dimension + RemoveAxis, + SliceInfo, D, D>: SliceArg, +{ + let edge = match pad { + None => 0, + Some(FiltFiltPad { len, .. }) => len.unwrap_or(ntaps * 3), + }; + + { + x.shape().get(axis).ok_or(Error::InvalidArg { + arg: "axis".into(), + reason: "The length of the input vector x must be greater than padlen.".into(), + })?; + } + + let ext = if let Some(FiltFiltPad { pad_type, .. }) = pad { + CowArray::from(pad_type.ext(x, edge, Some(axis as _))?) + } else { + CowArray::from(x) + }; + + Ok((edge, ext)) +} + +/// Implement filtfilt for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +/// +/// Note: FiltFilt gust is a separate function not yet implemented. +// Note: Usage of trait and macro for implementation is an inherited from LFilter. +// LFilter for supertrait? +pub trait FiltFilt +where + S: Data, +{ + /// Apply a digital filter forward and backward to a signal. + /// + /// This function applies a linear digital filter twice, once forward and + /// once backwards. The combined filter has zero phase and a filter order + /// twice that of the original. + /// + /// The function provides options for handling the edges of the signal. + /// + /// The function `sosfiltfilt` (and filter design using ``output='sos'``) + /// should be preferred over `filtfilt` for most filtering tasks, as + /// second-order sections have fewer numerical problems. + /// + /// # Parameters + /// * `b`: (N,) array_like + /// The numerator coefficient vector of the filter. + /// * `a`: (N,) array_like + /// The denominator coefficient vector of the filter. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized b:y ``a[0]``. + /// * `x`: array_like + /// The array of data to be filtered. + /// * `axis`: int, optional + /// The axis of `x` to which the filter is applied. + /// Default is -1. + /// * `pad` + /// [Option::None] here denotes a deliberate absence of padding. + /// * `padtype` [FiltFiltPadType] + /// Must be 'odd', 'even', 'constant', or None. + /// This determines the type of extension to use for the padded signal to which the filter is applied. + /// The default is 'odd'. + /// * `padlen` int or None, optional + /// The number of elements by which to extend `x` at both ends of `axis` before applying + /// the filter. + /// This value must be less than ``x.shape[axis] - 1``. ``padlen=0`` implies no padding. [Option::None] here denotes the default value. + /// The default value is ``3 * max(len(a), len(b))``. + /// + /// # Returns + /// * y : `Array` + /// The filtered output with the same shape as `x`. + /// + /// # Example + /// The following examples shows how to use an arbitrary FIR filter on a 2-dimensional input + /// `x`. + /// ``` + /// use sci_rs::signal::filter::{FiltFilt, FiltFiltPad}; + /// use ndarray::{array, Array2, ArrayView2}; + /// let x = array![ + /// [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + /// [0., 1., 4., 9., 16., 25., 36., 49., 64., 81.] + /// ]; + /// let b = array![0.5, 0.4, 0.1]; + /// let a = array![1.]; + /// let _ = ArrayView2::filtfilt( + /// b.view(), + /// a.view(), + /// x.view(), // Pass x by reference + /// Some(1), + /// Some(FiltFiltPad::default())).unwrap(); + /// let result = Array2::filtfilt( + /// b.view(), + /// a.view(), + /// x, // Pass x by value + /// Some(1), + /// Some(FiltFiltPad::default())).unwrap(); + /// + /// use approx::assert_relative_eq; + /// use ndarray::Zip; + /// let expected = array![ + /// [1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], + /// [0., 1.78, 4.88, 9.88, 16.88, 25.88, 36.88, 49.88, 64.78, 81.] + /// ]; + /// Zip::from(&result).and(&expected) + /// .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + /// ``` + /// + /// # See Also + /// sosfiltfilt, lfilter_zi, lfilter, lfiltic, savgol_filter, sosfilt, filtfilt_gust + /// + /// # Notes + /// When `method` is "pad", the function pads the data along the given axis in one of three + /// ways: odd, even or constant. The odd and even extensions have the corresponding symmetry + /// about the end point of the data. The constant extension extends the data with the values + /// at the end points. On both the forward and backward passes, the initial condition of the + /// filter is found by using `lfilter_zi` and scaling it by the end point of the extended data. + fn filtfilt<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + padding: Option, + ) -> Result>> + where + T: Clone + Add + Sub + num_traits::One, + Dim<[Ix; N]>: Dimension, + T: nalgebra::RealField + Copy + core::iter::Sum; // From lfilter_zi_dyn + + /// Forward-back IIR filter that uses Gustafsson's method. + /// + /// Not yet implemented. + fn filtfilt_gust<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + irlen: Option, + ) -> Result>> + where + Self: Sized, + { + todo!("Gust method of FiltFilt is not yet implemented."); + } +} + +macro_rules! filtfilt_for_dim { + ($N: literal) => { + impl FiltFilt for ArrayBase> + where + S: Data, + { + fn filtfilt<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + padding: Option, + ) -> Result>> + where + T: Clone + Add + Sub + num_traits::One, + Dim<[Ix; $N]>: Dimension, + T: nalgebra::RealField + Copy + core::iter::Sum, // From lfilter_zi_dyn + { + let axis = check_and_get_axis_dyn(axis, &x).map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + let (edge, ext) = validate_pad(padding, x.view(), axis, a.len().max(b.len()))?; + + let zi: Array> = { + let mut zi = lfilter_zi_dyn(b.as_slice().unwrap(), a.as_slice().unwrap()); + let mut sh = [1; $N]; + sh[axis] = zi.len(); // .size()? + + zi.into_shape_with_order(sh) + .map_err(|_| Error::InvalidArg { + arg: "b/a".into(), + reason: "Generated lfilter_zi from given b or a resulted in an error." + .into(), + })? + }; + let (y, _) = { + let x0 = axis_slice_unsafe(&ext, None, Some(1), None, axis, ext.ndim())?; + let zi_arg = zi.clone() * x0; // Is it possible to not need to clone? + ArrayBase::<_, Dim<[Ix; $N]>>::lfilter( + b.view(), + a.view(), + ext, + Some(axis as _), + Some(zi_arg.view()), + )? + }; + + let (y, _) = { + let y0 = axis_slice_unsafe(&y, Some(-1), None, None, axis, y.ndim())?; + let zi_arg = zi * y0; // originally zi * y0 + ArrayView::>::lfilter( + b.view(), + a.view(), + unsafe { axis_reverse_unsafe(&y, axis, $N) }, + Some(axis as _), + Some(zi_arg.view()), + )? + }; + + let y = unsafe { axis_reverse_unsafe(&y, axis, $N) }; + + if edge > 0 { + let y = unsafe { + axis_slice_unsafe( + &y, + Some(edge as _), + Some(-(edge as isize)), + None, + axis, + $N, + ) + }?; + Ok(y.to_owned()) + } else { + Ok(y.to_owned()) + } + } + } + }; +} + +filtfilt_for_dim!(1); +filtfilt_for_dim!(2); +filtfilt_for_dim!(3); +filtfilt_for_dim!(4); +filtfilt_for_dim!(5); +filtfilt_for_dim!(6); + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use approx::assert_relative_eq; + use ndarray::{array, Zip}; + + /// Test odd_ext as from documentation. + #[test] + fn odd_ext_doc() { + let odd = FiltFiltPadType::Odd; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = odd.ext(a.view(), 2, None).expect("Could not get odd_ext."); + let expected = array![ + [-1, 0, 1, 2, 3, 4, 5, 6, 7], + [-4, -1, 0, 1, 4, 9, 16, 23, 28] + ]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + + let result = odd + .ext(a.into_dyn(), 2, None) + .expect("Could not get odd_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); + } + + /// Test odd_ext's limits. + #[test] + fn odd_ext_limits() { + let odd = FiltFiltPadType::Odd; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = odd.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = odd.ext(a, 4, None); + assert!(result.is_err()); + } + + /// Test odd_ext as from documentation. + #[test] + fn even_ext_doc() { + let even = FiltFiltPadType::Even; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = even + .ext(a.view(), 2, None) + .expect("Could not get even_ext."); + let expected = array![[3, 2, 1, 2, 3, 4, 5, 4, 3], [4, 1, 0, 1, 4, 9, 16, 9, 4]]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + + let result = even + .ext(a.into_dyn(), 2, None) + .expect("Could not get even_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); + } + + /// Test even_ext's limits. + #[test] + fn even_ext_limits() { + let even = FiltFiltPadType::Even; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = even.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = even.ext(a, 4, None); + assert!(result.is_err()); + } + + /// Test const_ext as from documentation. + #[test] + fn const_ext_doc() { + let const_ext = FiltFiltPadType::Const; + let a = array![[1, 2, 3, 4, 5], [0, 1, 4, 9, 16]]; + + let result = const_ext + .ext(a.view(), 2, None) + .expect("Could not get even_ext."); + let expected = array![[1, 1, 1, 2, 3, 4, 5, 5, 5], [0, 0, 0, 1, 4, 9, 16, 16, 16]]; + + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_eq!(r, e)); + + let result = const_ext + .ext(a.into_dyn(), 2, None) + .expect("Could not get even_ext."); + ndarray::Zip::from(&result) + .and(&expected.into_dyn()) + .for_each(|&r, &e| assert_eq!(r, e)); + } + + /// Test const_ext's limits. + #[test] + fn const_ext_limits() { + let const_ext = FiltFiltPadType::Const; + let a = array![[1, 2, 3, 4], [0, 1, 4, 9]]; + + let result = const_ext.ext(a.view(), 3, None); + assert!(result.is_ok()); + let result = const_ext.ext(a, 4, None); + assert!(result.is_err()); + } + + /// Tests for when there is no padding. + #[test] + fn pad_none() { + let p = None; + let x = array![1]; + let result = validate_pad(p, x.view(), 0, 0).expect("Could not pad with none."); + + assert_eq!(result.0, 0); + assert_eq!(result.1, x); + } + + /// Tests for when there is even padding. + #[test] + fn pad_even() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Even, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with even."); + let expected_edge = 2; + let expected = array![ + [3, 2, 1, 2, 3, 4, 5, 6, 7, 8, 7, 6], + [4, 1, 0, 1, 4, 9, 16, 25, 36, 49, 36, 25] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + ndarray::Zip::from(&result) + .and(&expected) + .for_each(|r, e| assert_eq!(r, e)); + } + + /// Tests for when there is even padding. + #[test] + fn pad_odd() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Odd, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with odd."); + let expected_edge = 2; + let expected = array![ + [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + [-4, -1, 0, 1, 4, 9, 16, 25, 36, 49, 62, 73] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + } + + /// Tests for when there is const padding. + #[test] + fn pad_const() { + let p = Some(FiltFiltPad { + pad_type: FiltFiltPadType::Const, + len: Some(2), + }); + let x = array![[1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 4, 9, 16, 25, 36, 49]]; + let (result_edge, result) = + validate_pad(p, x.view(), 1, 2).expect("Could not pad with odd."); + let expected_edge = 2; + let expected = array![ + [1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 8, 8], + [0, 0, 0, 1, 4, 9, 16, 25, 36, 49, 49, 49] + ]; + + assert_eq!(result_edge, expected_edge); + assert_eq!(result, expected); + } + + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_1d_fir_default_pad_small() { + let x = + array![0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051]; + let b = array![0.5, 0.5]; + let a = array![1.]; + let result = Array::<_, Dim<[_; 1]>>::filtfilt( + b.view(), + a.view(), + x, + None, + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = + array![0., 0.5421249, 0.8507858, 0.9639715, 0.9893054, 0.9702733, 0.9275059, 0.8734051]; + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6, epsilon = 1e-10)); + } + + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_1d_fir_default_pad_big() { + // n_elems = 25 + // x = np.sin(np.log(np.linspace(1., n_elems, n_elems))) + // b = firwin(8, 0.2) + // a = np.array([1.]) + // expected = filtfilt(b, a, x) + + let x = array![ + 0., 0.6389613, 0.890577, 0.9830277, 0.9992535, 0.9756868, 0.9304659, 0.8734051, + 0.8101266, 0.7439803, 0.6770137, 0.6104955, 0.5452131, 0.481649, 0.4200881, 0.3606866, + 0.3035148, 0.2485867, 0.1958789, 0.1453437, 0.0969178, 0.0505287, 0.0060984, + -0.0364531, -0.0772063 + ]; + let b = array![ + 0.0087547, 0.0479489, 0.1640244, 0.279272, 0.279272, 0.1640244, 0.0479489, 0.0087547 + ]; + let a = array![1.]; + let result = Array::<_, Dim<[_; 1]>>::filtfilt( + b.view(), + a.view(), + x, + None, + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = array![ + 0., 0.3503788, 0.6340265, 0.8172474, 0.9055143, 0.9253101, 0.9036955, 0.8594274, + 0.8033733, 0.7414859, 0.6771011, 0.6121664, 0.5478511, 0.4848631, 0.4236259, 0.3643826, + 0.3072603, 0.25231, 0.1995331, 0.1488972, 0.1003401, 0.0537529, 0.0089268, -0.0345238, + -0.0772063 + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-5, epsilon = 1e-10)); + } + + /// Tests that filtfilt works with no padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_none_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt(b.view(), a.view(), x, Some(1), None) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 2.14, 2.84, 3.67, 4.43, 5.2, 6.06, 7.01, 8., 9., 10., 11., 12., 13., 14., 15., 16., + 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., + 33., 33.9, 34.6, 34.9, 35., 35.4, 35.7, 35.8 + ], + [ + 5.56, 8.54, 13.05, 19.15, 26.78, 36.04, 47.11, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1093.06, 1138.68, 1157.46, 1162.72, 1189.36, 1209.74, 1216.6 + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + } + + /// Tests that filtfilt works with some padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_some_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let pad_arg = FiltFiltPad { + pad_type: FiltFiltPadType::default(), + len: Some(4), + }; + let result = + Array::<_, Dim<[_; 2]>>::filtfilt(b.view(), a.view(), x, Some(1), Some(pad_arg)) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 1.2, 2.06, 3.01, 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., + 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., + 34., 35., 36., 37., 37.9, 38.6, 38.9 + ], + [ + 1.94, 5.52, 11.07, 18.18, 26.44, 35.96, 47.1, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1100.1, 1166.96, 1235.44, 1305.18, 1368.54, 1418.2, 1439.28 + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6)); + } + + /// Tests that filtfilt works with default padding with a FIR filter. + #[test] + fn filtfilt_2d_fir_default_pad() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 40; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt( + b.view(), + a.view(), + x, + Some(1), + Some(FiltFiltPad::default()), + ) + .expect("Could not filtfilt none_pad"); + let expected = array![ + [ + 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., + 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34., + 35., 36., 37., 38., 39., 40. + ], + [ + 0., 4.96, 10.98, 18.18, 26.44, 35.96, 47.1, 60.12, 75.12, 92.12, 111.12, 132.12, + 155.12, 180.12, 207.12, 236.12, 267.12, 300.12, 335.12, 372.12, 411.12, 452.12, + 495.12, 540.12, 587.12, 636.12, 687.12, 740.12, 795.12, 852.12, 911.12, 972.12, + 1035.12, 1100.1, 1166.96, 1235.44, 1305.18, 1375.98, 1447.96, 1521. + ] + ]; + + Zip::from(&result) + .and(&expected) + .for_each(|&r, &e| assert_relative_eq!(r, e, max_relative = 1e-6, epsilon = 1e-10)); + } + + /// Tests that is an error if the specified padding is a lot longer than the array. + #[test] + fn filfilt_2d_fir_limit() { + let b = array![0.1, 0.2, 0.1, -0.3, 0.2, 0.4, 0.2, 0.1]; + let a = array![1.]; + let x = { + let rows_n = 4; + let mut x = Array::zeros((2, rows_n)); + x.row_mut(0) + .assign(&Array::linspace(1 as _, rows_n as _, rows_n)); + x.row_mut(1) + .assign(&Array::from_iter((0..rows_n).map(|i| (i as f64).powi(2)))); + + x + }; + let result = Array::<_, Dim<[_; 2]>>::filtfilt( + b.view(), + a.view(), + x, + Some(1), + Some(FiltFiltPad::default()), + ); + + assert!(result.is_err()); + } +} diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs new file mode 100644 index 00000000..32abdfb5 --- /dev/null +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -0,0 +1,992 @@ +use super::arraytools::{check_and_get_axis_dyn, check_and_get_axis_st, ndarray_shape_as_array_st}; +use alloc::{vec, vec::Vec}; +use core::marker::Copy; +use ndarray::{ + Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, + IntoDimension, Ix, IxDyn, ShapeBuilder, SliceArg, SliceInfo, SliceInfoElem, +}; +use num_traits::{FromPrimitive, Num, NumAssign}; +use sci_rs_core::{Error, Result}; + +type LFilterResult = (Array>, Option>>); +type LFilterDynResult = (Array, Option>); + +/// Implement lfilter for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +pub trait LFilter +where + S: Data, +{ + /// Filter data `x` along one-dimension with an IIR or FIR filter. + /// + /// Filter a data sequence, `x`, using a digital filter. This works for many + /// fundamental data types (including Object type). The filter is a direct + /// form II transposed implementation of the standard difference equation + /// (see Notes). + /// + /// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be + /// preferred over `lfilter` for most filtering tasks, as second-order sections + /// have fewer numerical problems. + /// + /// ## Parameters + /// * `b` : array_like + /// The numerator coefficient vector in a 1-D sequence. + /// * `a` : array_like + /// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized by ``a[0]``. + /// * `x` : array_like + /// An N-dimensional input array. + /// * `axis`: `Option` + /// Default to `-1` if `None`. + /// Panics in accordance with [ndarray::ArrayBase::axis_iter]. + /// * `zi`: array_like + /// Currently not implemented. + /// Initial conditions for filter delays. It is a vector + /// (or array of vectors for an N-dimensional input) of length + /// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then + /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. + /// + /// ## Returns + /// * `y` : array + /// The output of the digital filter. + /// * `zf` : array, optional + /// If `zi` is None, this is not returned, otherwise, `zf` holds the + /// final filter delay values. + /// + /// # See Also + /// * [super::lfilter_zi_dyn] + /// + /// # Notes + /// For compile time reasons, lfilter is implemented per ArrayN at the moment. + /// + /// # Examples + /// On a 1-dimensional signal: + /// ``` + /// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; + /// use sci_rs::signal::filter::LFilter; + /// + /// let b = array![5., 4., 1., 2.]; + /// let a = array![1.]; + /// let x = array![1., 2., 3., 4., 3., 5., 6.]; + /// let expected = array![5., 14., 24., 36., 38., 47., 61.]; + /// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref + /// + /// assert_eq!(result.len(), expected.len()); + /// result.into_iter().zip(expected).for_each(|(r, e)| { + /// assert_eq!(r, e); + /// }); + /// + /// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value + /// ``` + /// + /// # Panics + /// Currently yet to implement for `a.len() > 1`. + // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of + // documentation, both lfilter_zi and this should eventually support NDArray. + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a; +} + +macro_rules! lfilter_for_dim { + ($N:literal) => { + impl LFilter for ArrayBase> + where + S: Data, + { + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result<(Array>, Option>>)> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: 'a, + { + if a.len() > 1 { + return linear_filter(b, a, x, axis, zi); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_st(axis, &x) + .map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.reborrow(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: [usize; $N] = x.shape().try_into().unwrap(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: [Ix; $N] = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..$N) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap?.try_into().unwrap() + }; + + zi = ArrayView::from_shape(expected_shape.strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp), tmp) + }; + + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out_full = unsafe { Array::uninit(out_full_dim).assume_init() }; + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)? + .assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = zi.shape()[axis_inner]; + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + let tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } + } + } + }; +} + +lfilter_for_dim!(1); +lfilter_for_dim!(2); +lfilter_for_dim!(3); +lfilter_for_dim!(4); +lfilter_for_dim!(5); +lfilter_for_dim!(6); + +/// Filter data `x` along one-dimension with an IIR or FIR filter. +/// +/// Filter a data sequence, `x`, using a digital filter. This works for many +/// fundamental data types (including Object type). The filter is a direct +/// form II transposed implementation of the standard difference equation +/// (see Notes). +/// +/// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be +/// preferred over `lfilter` for most filtering tasks, as second-order sections +/// have fewer numerical problems. +/// +/// ## Parameters +/// * `b` : array_like +/// The numerator coefficient vector in a 1-D sequence. +/// * `a` : array_like +/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` +/// is not 1, then both `a` and `b` are normalized by ``a[0]``. +/// * `x` : array_like +/// An N-dimensional input array. +/// * `axis`: `Option` +/// Default to `-1` if `None`. +/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. +/// * `zi`: array_like +/// Currently not implemented. +/// Initial conditions for filter delays. It is a vector +/// (or array of vectors for an N-dimensional input) of length +/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then +/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. +/// +/// ## Returns +/// * `y` : array +/// The output of the digital filter. +/// * `zf` : array, optional +/// If `zi` is None, this is not returned, otherwise, `zf` holds the +/// final filter delay values. +/// +/// # See Also +/// * [super::lfilter_zi_dyn] +/// +/// # Notes +/// If Array<_, IxDyn as provided by this function is not desired, consider using [LFilter]. +/// +/// # Examples +/// On a 1-dimensional signal: +/// ``` +/// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; +/// use sci_rs::signal::filter::lfilter; +/// +/// let b = array![5., 4., 1., 2.]; +/// let a = array![1.]; +/// let x = array![1., 2., 3., 4., 3., 5., 6.]; +/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.view(), None, None).unwrap(); // By ref +/// +/// assert_eq!(result.len(), expected.len()); +/// result.into_iter().zip(expected).for_each(|(r, e)| { +/// assert_eq!(r, e); +/// }); +/// +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.clone().into_dyn(), None, None).unwrap(); // Dynamic arrays +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value +/// ``` +/// +/// # Panics +/// Currently yet to implement for `a.len() > 1`. +// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of +// documentation, both lfilter_zi and this should eventually support NDArray. +pub fn lfilter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + S: Data + 'a, + T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + + if ndim == 0 { + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "Linear filter requires at least 1-dimensional `x`.".into(), + }); + } + + if a.len() > 1 { + todo!(); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_dyn(axis, &x)?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.clone().reborrow().into_dyn(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: Vec = x.shape().to_vec(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: Vec = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..ndim) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap? + }; + + // ArrayView::from_shape(strides, + // zi.as_slice_memory_order().unwrap()).unwrap().to_owned() + zi = ArrayView::from_shape((expected_shape).strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, Vec) = { + let mut tmp = x.shape().to_vec(); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp.as_ref()), tmp) + }; + + let mut out_full = ArrayD::::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, D, D> = { + let t = zi.shape()[axis_inner]; + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + // let mut out_dim_inner = out_full_dim_inner; + // if let Some(inner) = out_dim_inner.get_mut(axis_inner) { + // *inner = inner + // .checked_sub({ + // // Safety: b is Array1 + // *b.shape().first().unwrap() + // }) + // // Safety: inner is defined by having added b.len() + // .unwrap() + // + 1; + // } else { + // unsafe { unreachable_unchecked() }; + // }; + // (IntoDimension::into_dimension(out_dim_inner), out_dim_inner) + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, D, IxDyn> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner) = { + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; // Safety: All elements are overwritten by convolve in subsequent step. + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } +} + +/// Internal function called by [LFilter::lfilter] for situation a.len() > 1. +fn linear_filter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + D: Dimension, + T: 'a, + S: Data + 'a, +{ + todo!() +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use approx::assert_relative_eq; + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr, ViewRepr}; + + // Tests that have a = [1.] with zi = None on input x with dim = 1. + #[test] + fn one_dim_fir_no_zi() { + { + // Tests for b.sum() > 1. + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![5., 14., 24., 36., 38., 47., 61.]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_eq!(r, e); + }) + } + { + // Tests for b[i] < 0 for some i, such that b.sum() = 1. + let b = array![0.7, -0.3, 0.6]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![0.7, 1.1, 2.1, 3.1, 2.7, 5., 4.5]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn one_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + None, + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(0), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 3]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(1), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn invalid_axis() { + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + + let result = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), Some(2), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x, Some(-2), None); + assert!(result.is_err()); + } + + #[test] + fn dyn_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + // Test static dim input + let Ok((result, Some(r_zi))) = + lfilter((&b).into(), (&a).into(), x.view(), None, Some((&zi).into())) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(&expected).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(&expected_zi).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + + // Test dyn input + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + None, + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(0), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(1), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } +} diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 6c92df5a..24d0dc83 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -1,5 +1,6 @@ use core::{iter::Sum, ops::SubAssign}; use nalgebra::{DMatrix, OMatrix, RealField, SMatrix, Scalar}; +use ndarray::Array1; use num_traits::{Float, One, Zero}; use crate::linalg::companion_dyn; @@ -9,26 +10,22 @@ use alloc::vec; #[cfg(feature = "alloc")] use alloc::vec::Vec; +/// Construct initial conditions for [lfilter][super::lfilter::LFilter] for step response +/// steady-state. /// -/// Construct initial conditions for lfilter for step response steady-state. -/// -/// Compute an initial state `zi` for the `lfilter` function that corresponds -/// to the steady state of the step response. +/// Compute an initial state `zi` for the `lfilter` function that corresponds to the steady state +/// of the step response. /// /// A typical use of this function is to set the initial state so that the /// output of the filter starts at the same value as the first element of /// the signal to be filtered. /// -/// /// -/// -/// #[inline] -pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Vec +pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Array1 where F: RealField + Copy + PartialEq + Scalar + Zero + One + Sum + SubAssign, { - assert!(b.len() == a.len()); let m = b.len(); let ai0 = a @@ -38,7 +35,7 @@ where .expect("There must be at least one nonzero `a` coefficient.") .0; - // Mormalize to a[0] == 1 + // Normalize to a[0] == 1 let mut a = a.iter().skip(ai0).cloned().collect::>(); let mut b = b.to_vec(); let a0 = a[0]; @@ -79,7 +76,7 @@ where } } - zi + zi.into() } #[cfg(test)] diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index 805c65b7..0c0dc4b5 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -13,7 +13,9 @@ pub use kalmanfilt::kalman::kalman_filter; /// pub use gaussfilt as gaussian_filter; -/// Digital IIR/FIR filter design +/// Digital IIR/FIR filter design +/// Functions located in the [`Filter design` section of +/// `scipy.signal`](https://docs.scipy.org/doc/scipy/reference/signal.html#filter-design). pub mod design; mod ext; @@ -22,6 +24,15 @@ mod sosfilt; pub use ext::*; pub use sosfilt::*; +#[cfg(feature = "alloc")] +mod arraytools; +#[cfg(feature = "alloc")] +use arraytools::*; + +#[cfg(feature = "alloc")] +mod filtfilt; +#[cfg(feature = "alloc")] +mod lfilter; #[cfg(feature = "alloc")] mod lfilter_zi; #[cfg(feature = "alloc")] @@ -31,6 +42,10 @@ mod sosfilt_zi; #[cfg(feature = "alloc")] mod sosfiltfilt; +#[cfg(feature = "alloc")] +pub use filtfilt::*; +#[cfg(feature = "alloc")] +pub use lfilter::*; #[cfg(feature = "alloc")] pub use lfilter_zi::*; #[cfg(feature = "alloc")]