diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml new file mode 100644 index 00000000..26fe7017 --- /dev/null +++ b/sci-rs-core/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "sci-rs-core" +version = "0.0.0" +edition = "2021" +authors = ["Jacob Trueb "] +description = "Core library for sci-rs internals." +license = "MIT OR Apache-2.0" +repository = "https://github.com/qsib-cbie/sci-rs.git" +homepage = "https://github.com/qsib-cbie/sci-rs.git" +readme = "../README.md" +keywords = ["scipy", "dsp", "signal", "filter", "design"] +categories = ["science", "mathematics", "no-std", "embedded"] + + +[package.metadata.docs.rs] +all-features = true + +[features] +default = ['alloc'] + +# Allow allocating vecs, matrices, etc. +alloc = [] + +# Enable FFT and standard library features +std = ['alloc'] + +[dependencies] +ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.5.0" } +num-traits = { version = "0.2.15", default-features = false } diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs new file mode 100644 index 00000000..f52e09a4 --- /dev/null +++ b/sci-rs-core/src/lib.rs @@ -0,0 +1,79 @@ +//! Core library for sci-rs. + +#![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "alloc")] +use alloc::format; + +use core::{error, fmt}; + +pub type Result = core::result::Result; + +/// Errors raised whilst running sci-rs. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Argument parsed into function were invalid. + #[cfg(feature = "alloc")] + InvalidArg { + /// The invalid arg + arg: alloc::string::String, + /// Explaining why arg is invalid. + reason: alloc::string::String, + }, + /// Argument parsed into function were invalid. + #[cfg(not(feature = "alloc"))] + InvalidArg, + /// Two or more optional arguments passed into functions conflict. + #[cfg(feature = "alloc")] + ConflictArg { + /// Explaining what arg is invalid. + reason: alloc::string::String, + }, + /// Two or more optional arguments passed into functions conflict. + #[cfg(not(feature = "alloc"))] + ConflictArg, + /// Errors raised by [ndarray_conv::Error] + #[cfg(feature = "alloc")] + Conv { reason: alloc::string::String }, + /// Errors raised by [ndarray_conv::Error] + #[cfg(not(feature = "alloc"))] + Conv, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "alloc")] + Error::InvalidArg { arg, reason } => + format!("Invalid Argument on arg = {} with reason = {}", arg, reason), + #[cfg(not(feature = "alloc"))] + Error::InvalidArg => + "There were invalid arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::ConflictArg { reason } => + format!("Conflicting Arguments with reason = {}", reason), + #[cfg(not(feature = "alloc"))] + Error::ConflictArg => + "There were conflicting arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::Conv { reason } => format!( + "An error occurred during the convolution from ndarray_conv with reason {}.", + reason + ), + #[cfg(not(feature = "alloc"))] + Error::Conv => "An error occurred during the convolution from ndarray_conv. Reasons not shown without `alloc` feature.", + } + ) + } +} + +impl error::Error for Error {} + +/// Collection of numpy-like functions for use by sci-rs. +/// Provide behaviour parity against Numpy, even if the types are not identical. +pub mod num_rs; diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs new file mode 100644 index 00000000..2d80c7f7 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -0,0 +1,135 @@ +mod ndarray_conv_binds; + +use crate::{Error, Result}; +use alloc::string::ToString; +use ndarray::{Array1, ArrayView1}; +use ndarray_conv::{ConvExt, PaddingMode}; + +/// Convolution mode determines behavior near edges and output size +pub enum ConvolveMode { + /// Full convolution, output size is `in1.len() + in2.len() - 1` + Full, + /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` + Valid, + /// Same convolution, output size is `in1.len()` + Same, +} + +/// Best effort parallel behaviour with numpy's convolve method. We take `v` as the convolution +/// kernel. +/// +/// Returns the discrete, linear convolution of two one-dimensional sequences. +/// +/// # Parameters +/// * `a` : (N,) [[array_like]]([ndarray::Array1]) +/// Signal to be (linearly) convolved. +/// * `v` : (M,) [[array_like]]([ndarray::Array1]) +/// Second one-dimensional input array. +/// * `mode` : [ConvolveMode] +/// [ConvolveMode::Full]: +/// By default, mode is 'full'. This returns the convolution at each point of overlap, with an +/// output shape of (N+M-1,). At the end-points of the convolution, the signals do not overlap +/// completely, and boundary effects may be seen. +/// +/// [ConvolveMode::Same]: +/// Mode 'same' returns output of length ``max(M, N)``. Boundary effects are still visible. +/// +/// [ConvolveMode::Valid]: +/// Mode 'valid' returns output of length ``max(M, N) - min(M, N) + 1``. The convolution +/// product is only given for points where the signals overlap completely. Values outside the +/// signal boundary have no effect. +/// +/// # Panics +/// We assume that `v` is shorter than `a`. +/// +/// # Examples +/// With [ConvolveMode::Full]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![0., 1., 2.5, 4., 1.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![1., 2.5, 4.]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![2.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn convolve(a: ArrayView1, v: ArrayView1, mode: ConvolveMode) -> Result> +where + T: num_traits::NumAssign + core::marker::Copy, +{ + // Convolve + let result = a.conv(&v, mode.into(), PaddingMode::Zeros); + #[cfg(feature = "alloc")] + { + result.map_err(|e| Error::Conv { + reason: e.to_string(), + }) + } + #[cfg(not(feature = "alloc"))] + { + result.map_err({ Error::Conv }) + } +} + +#[cfg(test)] +mod linear_convolve { + use super::*; + use alloc::vec; + use ndarray::array; + + #[test] + fn full() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![0., 1., 2.5, 4., 1.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn same() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![1., 2.5, 4.]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn valid() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![2.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); + assert_eq!(result, expected); + } +} diff --git a/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs new file mode 100644 index 00000000..b0f38c79 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs @@ -0,0 +1,12 @@ +use super::ConvolveMode; +use ndarray_conv::ConvMode; + +impl From for ConvMode { + fn from(value: ConvolveMode) -> Self { + match value { + ConvolveMode::Full => ConvMode::Full, + ConvolveMode::Same => ConvMode::Same, + ConvolveMode::Valid => ConvMode::Valid, + } + } +} diff --git a/sci-rs-core/src/num_rs/mod.rs b/sci-rs-core/src/num_rs/mod.rs new file mode 100644 index 00000000..e6e9679c --- /dev/null +++ b/sci-rs-core/src/num_rs/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "alloc")] +mod convolve; +#[cfg(feature = "alloc")] +pub use convolve::*; diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 697e7933..99647819 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -19,10 +19,10 @@ all-features = true default = ['alloc'] # Allow allocating vecs, matrices, etc. -alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc'] +alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc', 'sci-rs-core/alloc'] # Enable FFT and standard library features -std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc'] +std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc','sci-rs-core/std'] # Enable debug plotting through python system calls plot = ['std'] @@ -36,6 +36,7 @@ lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } +sci-rs-core = { path = "../sci-rs-core", default-features = false } [dev-dependencies] approx = "0.5.1" diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs index b527805e..7e0486b5 100644 --- a/sci-rs/src/signal/convolve.rs +++ b/sci-rs/src/signal/convolve.rs @@ -2,15 +2,7 @@ use nalgebra::Complex; use num_traits::{Float, FromPrimitive, Signed, Zero}; use rustfft::{FftNum, FftPlanner}; -/// Convolution mode determines behavior near edges and output size -pub enum ConvolveMode { - /// Full convolution, output size is `in1.len() + in2.len() - 1` - Full, - /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` - Valid, - /// Same convolution, output size is `in1.len()` - Same, -} +pub use sci_rs_core::num_rs::ConvolveMode; /// Performs FFT-based convolution on two slices of floating point values. /// diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs new file mode 100644 index 00000000..029ef172 --- /dev/null +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -0,0 +1,113 @@ +//! Functions for acting on a axis of an array. +//! +//! Designed for ndarrays; with scipy's internal nomenclature. + +use ndarray::{ArrayBase, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis}; +use sci_rs_core::{Error, Result}; + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +/// +/// # Notes +/// Const nature of this function means error has to be manually created. +#[inline] +pub(crate) const fn check_and_get_axis_st<'a, T, S, const N: usize>( + axis: Option, + x: &ArrayBase>, +) -> core::result::Result +where + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + match axis { + None => (), + Some(axis) if axis.is_negative() => { + if axis.unsigned_abs() > N { + return Err(()); + } + } + Some(axis) => { + if axis.unsigned_abs() >= N { + return Err(()); + } + } + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = match axis { + Some(axis) => axis, + None => -1, + }; + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = N + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +#[inline] +pub(crate) fn check_and_get_axis_dyn<'a, T, S, D>( + axis: Option, + x: &ArrayBase, +) -> Result +where + D: Dimension, + S: Data + 'a, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = ndim + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for obtaining length of all axis as array from input from input. +/// +/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a slice `&[T]`. +/// +/// # Parameters +/// `a`: Array whose shape is needed as a slice. +pub(crate) fn ndarray_shape_as_array_st<'a, S, T, const N: usize>( + a: &ArrayBase>, +) -> [Ix; N] +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + S: Data + 'a, +{ + a.shape().try_into().expect("Could not cast shape to array") +} diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs new file mode 100644 index 00000000..32abdfb5 --- /dev/null +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -0,0 +1,992 @@ +use super::arraytools::{check_and_get_axis_dyn, check_and_get_axis_st, ndarray_shape_as_array_st}; +use alloc::{vec, vec::Vec}; +use core::marker::Copy; +use ndarray::{ + Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, + IntoDimension, Ix, IxDyn, ShapeBuilder, SliceArg, SliceInfo, SliceInfoElem, +}; +use num_traits::{FromPrimitive, Num, NumAssign}; +use sci_rs_core::{Error, Result}; + +type LFilterResult = (Array>, Option>>); +type LFilterDynResult = (Array, Option>); + +/// Implement lfilter for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +pub trait LFilter +where + S: Data, +{ + /// Filter data `x` along one-dimension with an IIR or FIR filter. + /// + /// Filter a data sequence, `x`, using a digital filter. This works for many + /// fundamental data types (including Object type). The filter is a direct + /// form II transposed implementation of the standard difference equation + /// (see Notes). + /// + /// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be + /// preferred over `lfilter` for most filtering tasks, as second-order sections + /// have fewer numerical problems. + /// + /// ## Parameters + /// * `b` : array_like + /// The numerator coefficient vector in a 1-D sequence. + /// * `a` : array_like + /// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized by ``a[0]``. + /// * `x` : array_like + /// An N-dimensional input array. + /// * `axis`: `Option` + /// Default to `-1` if `None`. + /// Panics in accordance with [ndarray::ArrayBase::axis_iter]. + /// * `zi`: array_like + /// Currently not implemented. + /// Initial conditions for filter delays. It is a vector + /// (or array of vectors for an N-dimensional input) of length + /// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then + /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. + /// + /// ## Returns + /// * `y` : array + /// The output of the digital filter. + /// * `zf` : array, optional + /// If `zi` is None, this is not returned, otherwise, `zf` holds the + /// final filter delay values. + /// + /// # See Also + /// * [super::lfilter_zi_dyn] + /// + /// # Notes + /// For compile time reasons, lfilter is implemented per ArrayN at the moment. + /// + /// # Examples + /// On a 1-dimensional signal: + /// ``` + /// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; + /// use sci_rs::signal::filter::LFilter; + /// + /// let b = array![5., 4., 1., 2.]; + /// let a = array![1.]; + /// let x = array![1., 2., 3., 4., 3., 5., 6.]; + /// let expected = array![5., 14., 24., 36., 38., 47., 61.]; + /// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref + /// + /// assert_eq!(result.len(), expected.len()); + /// result.into_iter().zip(expected).for_each(|(r, e)| { + /// assert_eq!(r, e); + /// }); + /// + /// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value + /// ``` + /// + /// # Panics + /// Currently yet to implement for `a.len() > 1`. + // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of + // documentation, both lfilter_zi and this should eventually support NDArray. + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a; +} + +macro_rules! lfilter_for_dim { + ($N:literal) => { + impl LFilter for ArrayBase> + where + S: Data, + { + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result<(Array>, Option>>)> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: 'a, + { + if a.len() > 1 { + return linear_filter(b, a, x, axis, zi); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_st(axis, &x) + .map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.reborrow(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: [usize; $N] = x.shape().try_into().unwrap(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: [Ix; $N] = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..$N) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap?.try_into().unwrap() + }; + + zi = ArrayView::from_shape(expected_shape.strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp), tmp) + }; + + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out_full = unsafe { Array::uninit(out_full_dim).assume_init() }; + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)? + .assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = zi.shape()[axis_inner]; + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + let tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } + } + } + }; +} + +lfilter_for_dim!(1); +lfilter_for_dim!(2); +lfilter_for_dim!(3); +lfilter_for_dim!(4); +lfilter_for_dim!(5); +lfilter_for_dim!(6); + +/// Filter data `x` along one-dimension with an IIR or FIR filter. +/// +/// Filter a data sequence, `x`, using a digital filter. This works for many +/// fundamental data types (including Object type). The filter is a direct +/// form II transposed implementation of the standard difference equation +/// (see Notes). +/// +/// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be +/// preferred over `lfilter` for most filtering tasks, as second-order sections +/// have fewer numerical problems. +/// +/// ## Parameters +/// * `b` : array_like +/// The numerator coefficient vector in a 1-D sequence. +/// * `a` : array_like +/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` +/// is not 1, then both `a` and `b` are normalized by ``a[0]``. +/// * `x` : array_like +/// An N-dimensional input array. +/// * `axis`: `Option` +/// Default to `-1` if `None`. +/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. +/// * `zi`: array_like +/// Currently not implemented. +/// Initial conditions for filter delays. It is a vector +/// (or array of vectors for an N-dimensional input) of length +/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then +/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. +/// +/// ## Returns +/// * `y` : array +/// The output of the digital filter. +/// * `zf` : array, optional +/// If `zi` is None, this is not returned, otherwise, `zf` holds the +/// final filter delay values. +/// +/// # See Also +/// * [super::lfilter_zi_dyn] +/// +/// # Notes +/// If Array<_, IxDyn as provided by this function is not desired, consider using [LFilter]. +/// +/// # Examples +/// On a 1-dimensional signal: +/// ``` +/// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; +/// use sci_rs::signal::filter::lfilter; +/// +/// let b = array![5., 4., 1., 2.]; +/// let a = array![1.]; +/// let x = array![1., 2., 3., 4., 3., 5., 6.]; +/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.view(), None, None).unwrap(); // By ref +/// +/// assert_eq!(result.len(), expected.len()); +/// result.into_iter().zip(expected).for_each(|(r, e)| { +/// assert_eq!(r, e); +/// }); +/// +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.clone().into_dyn(), None, None).unwrap(); // Dynamic arrays +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value +/// ``` +/// +/// # Panics +/// Currently yet to implement for `a.len() > 1`. +// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of +// documentation, both lfilter_zi and this should eventually support NDArray. +pub fn lfilter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + S: Data + 'a, + T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + + if ndim == 0 { + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "Linear filter requires at least 1-dimensional `x`.".into(), + }); + } + + if a.len() > 1 { + todo!(); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_dyn(axis, &x)?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.clone().reborrow().into_dyn(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: Vec = x.shape().to_vec(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: Vec = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..ndim) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap? + }; + + // ArrayView::from_shape(strides, + // zi.as_slice_memory_order().unwrap()).unwrap().to_owned() + zi = ArrayView::from_shape((expected_shape).strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, Vec) = { + let mut tmp = x.shape().to_vec(); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp.as_ref()), tmp) + }; + + let mut out_full = ArrayD::::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, D, D> = { + let t = zi.shape()[axis_inner]; + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + // let mut out_dim_inner = out_full_dim_inner; + // if let Some(inner) = out_dim_inner.get_mut(axis_inner) { + // *inner = inner + // .checked_sub({ + // // Safety: b is Array1 + // *b.shape().first().unwrap() + // }) + // // Safety: inner is defined by having added b.len() + // .unwrap() + // + 1; + // } else { + // unsafe { unreachable_unchecked() }; + // }; + // (IntoDimension::into_dimension(out_dim_inner), out_dim_inner) + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, D, IxDyn> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner) = { + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; // Safety: All elements are overwritten by convolve in subsequent step. + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } +} + +/// Internal function called by [LFilter::lfilter] for situation a.len() > 1. +fn linear_filter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + D: Dimension, + T: 'a, + S: Data + 'a, +{ + todo!() +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use approx::assert_relative_eq; + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr, ViewRepr}; + + // Tests that have a = [1.] with zi = None on input x with dim = 1. + #[test] + fn one_dim_fir_no_zi() { + { + // Tests for b.sum() > 1. + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![5., 14., 24., 36., 38., 47., 61.]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_eq!(r, e); + }) + } + { + // Tests for b[i] < 0 for some i, such that b.sum() = 1. + let b = array![0.7, -0.3, 0.6]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![0.7, 1.1, 2.1, 3.1, 2.7, 5., 4.5]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn one_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + None, + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(0), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 3]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(1), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn invalid_axis() { + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + + let result = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), Some(2), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x, Some(-2), None); + assert!(result.is_err()); + } + + #[test] + fn dyn_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + // Test static dim input + let Ok((result, Some(r_zi))) = + lfilter((&b).into(), (&a).into(), x.view(), None, Some((&zi).into())) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(&expected).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(&expected_zi).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + + // Test dyn input + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + None, + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(0), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(1), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } +} diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 6c92df5a..48296030 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -9,20 +9,17 @@ use alloc::vec; #[cfg(feature = "alloc")] use alloc::vec::Vec; +/// Construct initial conditions for [lfilter][super::lfilter::LFilter] for step response +/// steady-state. /// -/// Construct initial conditions for lfilter for step response steady-state. -/// -/// Compute an initial state `zi` for the `lfilter` function that corresponds -/// to the steady state of the step response. +/// Compute an initial state `zi` for the `lfilter` function that corresponds to the steady state +/// of the step response. /// /// A typical use of this function is to set the initial state so that the /// output of the filter starts at the same value as the first element of /// the signal to be filtered. /// -/// /// -/// -/// #[inline] pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Vec where @@ -38,7 +35,7 @@ where .expect("There must be at least one nonzero `a` coefficient.") .0; - // Mormalize to a[0] == 1 + // Normalize to a[0] == 1 let mut a = a.iter().skip(ai0).cloned().collect::>(); let mut b = b.to_vec(); let a0 = a[0]; diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index 805c65b7..ec0c5f2c 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -13,7 +13,9 @@ pub use kalmanfilt::kalman::kalman_filter; /// pub use gaussfilt as gaussian_filter; -/// Digital IIR/FIR filter design +/// Digital IIR/FIR filter design +/// Functions located in the [`Filter design` section of +/// `scipy.signal`](https://docs.scipy.org/doc/scipy/reference/signal.html#filter-design). pub mod design; mod ext; @@ -22,6 +24,13 @@ mod sosfilt; pub use ext::*; pub use sosfilt::*; +#[cfg(feature = "alloc")] +mod arraytools; +#[cfg(feature = "alloc")] +use arraytools::*; + +#[cfg(feature = "alloc")] +mod lfilter; #[cfg(feature = "alloc")] mod lfilter_zi; #[cfg(feature = "alloc")] @@ -31,6 +40,8 @@ mod sosfilt_zi; #[cfg(feature = "alloc")] mod sosfiltfilt; +#[cfg(feature = "alloc")] +pub use lfilter::*; #[cfg(feature = "alloc")] pub use lfilter_zi::*; #[cfg(feature = "alloc")]