diff --git a/sci-rs-core/Cargo.toml b/sci-rs-core/Cargo.toml new file mode 100644 index 00000000..26fe7017 --- /dev/null +++ b/sci-rs-core/Cargo.toml @@ -0,0 +1,30 @@ +[package] +name = "sci-rs-core" +version = "0.0.0" +edition = "2021" +authors = ["Jacob Trueb "] +description = "Core library for sci-rs internals." +license = "MIT OR Apache-2.0" +repository = "https://github.com/qsib-cbie/sci-rs.git" +homepage = "https://github.com/qsib-cbie/sci-rs.git" +readme = "../README.md" +keywords = ["scipy", "dsp", "signal", "filter", "design"] +categories = ["science", "mathematics", "no-std", "embedded"] + + +[package.metadata.docs.rs] +all-features = true + +[features] +default = ['alloc'] + +# Allow allocating vecs, matrices, etc. +alloc = [] + +# Enable FFT and standard library features +std = ['alloc'] + +[dependencies] +ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.5.0" } +num-traits = { version = "0.2.15", default-features = false } diff --git a/sci-rs-core/src/lib.rs b/sci-rs-core/src/lib.rs new file mode 100644 index 00000000..f52e09a4 --- /dev/null +++ b/sci-rs-core/src/lib.rs @@ -0,0 +1,79 @@ +//! Core library for sci-rs. + +#![cfg_attr(not(feature = "std"), no_std)] + +#[cfg(feature = "alloc")] +extern crate alloc; +#[cfg(feature = "alloc")] +use alloc::format; + +use core::{error, fmt}; + +pub type Result = core::result::Result; + +/// Errors raised whilst running sci-rs. +#[derive(Debug, PartialEq, Eq)] +pub enum Error { + /// Argument parsed into function were invalid. + #[cfg(feature = "alloc")] + InvalidArg { + /// The invalid arg + arg: alloc::string::String, + /// Explaining why arg is invalid. + reason: alloc::string::String, + }, + /// Argument parsed into function were invalid. + #[cfg(not(feature = "alloc"))] + InvalidArg, + /// Two or more optional arguments passed into functions conflict. + #[cfg(feature = "alloc")] + ConflictArg { + /// Explaining what arg is invalid. + reason: alloc::string::String, + }, + /// Two or more optional arguments passed into functions conflict. + #[cfg(not(feature = "alloc"))] + ConflictArg, + /// Errors raised by [ndarray_conv::Error] + #[cfg(feature = "alloc")] + Conv { reason: alloc::string::String }, + /// Errors raised by [ndarray_conv::Error] + #[cfg(not(feature = "alloc"))] + Conv, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + #[cfg(feature = "alloc")] + Error::InvalidArg { arg, reason } => + format!("Invalid Argument on arg = {} with reason = {}", arg, reason), + #[cfg(not(feature = "alloc"))] + Error::InvalidArg => + "There were invalid arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::ConflictArg { reason } => + format!("Conflicting Arguments with reason = {}", reason), + #[cfg(not(feature = "alloc"))] + Error::ConflictArg => + "There were conflicting arguments. Reasons not shown without `alloc` feature.", + #[cfg(feature = "alloc")] + Error::Conv { reason } => format!( + "An error occurred during the convolution from ndarray_conv with reason {}.", + reason + ), + #[cfg(not(feature = "alloc"))] + Error::Conv => "An error occurred during the convolution from ndarray_conv. Reasons not shown without `alloc` feature.", + } + ) + } +} + +impl error::Error for Error {} + +/// Collection of numpy-like functions for use by sci-rs. +/// Provide behaviour parity against Numpy, even if the types are not identical. +pub mod num_rs; diff --git a/sci-rs-core/src/num_rs/convolve/mod.rs b/sci-rs-core/src/num_rs/convolve/mod.rs new file mode 100644 index 00000000..2d80c7f7 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/mod.rs @@ -0,0 +1,135 @@ +mod ndarray_conv_binds; + +use crate::{Error, Result}; +use alloc::string::ToString; +use ndarray::{Array1, ArrayView1}; +use ndarray_conv::{ConvExt, PaddingMode}; + +/// Convolution mode determines behavior near edges and output size +pub enum ConvolveMode { + /// Full convolution, output size is `in1.len() + in2.len() - 1` + Full, + /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` + Valid, + /// Same convolution, output size is `in1.len()` + Same, +} + +/// Best effort parallel behaviour with numpy's convolve method. We take `v` as the convolution +/// kernel. +/// +/// Returns the discrete, linear convolution of two one-dimensional sequences. +/// +/// # Parameters +/// * `a` : (N,) [[array_like]]([ndarray::Array1]) +/// Signal to be (linearly) convolved. +/// * `v` : (M,) [[array_like]]([ndarray::Array1]) +/// Second one-dimensional input array. +/// * `mode` : [ConvolveMode] +/// [ConvolveMode::Full]: +/// By default, mode is 'full'. This returns the convolution at each point of overlap, with an +/// output shape of (N+M-1,). At the end-points of the convolution, the signals do not overlap +/// completely, and boundary effects may be seen. +/// +/// [ConvolveMode::Same]: +/// Mode 'same' returns output of length ``max(M, N)``. Boundary effects are still visible. +/// +/// [ConvolveMode::Valid]: +/// Mode 'valid' returns output of length ``max(M, N) - min(M, N) + 1``. The convolution +/// product is only given for points where the signals overlap completely. Values outside the +/// signal boundary have no effect. +/// +/// # Panics +/// We assume that `v` is shorter than `a`. +/// +/// # Examples +/// With [ConvolveMode::Full]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![0., 1., 2.5, 4., 1.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![1., 2.5, 4.]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); +/// assert_eq!(result, expected); +/// ``` +/// With [ConvolveMode::Same]: +/// ``` +/// use ndarray::array; +/// use sci_rs_core::num_rs::{ConvolveMode, convolve}; +/// +/// let a = array![1., 2., 3.]; +/// let v = array![0., 1., 0.5]; +/// +/// let expected = array![2.5]; +/// let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); +/// assert_eq!(result, expected); +/// ``` +pub fn convolve(a: ArrayView1, v: ArrayView1, mode: ConvolveMode) -> Result> +where + T: num_traits::NumAssign + core::marker::Copy, +{ + // Convolve + let result = a.conv(&v, mode.into(), PaddingMode::Zeros); + #[cfg(feature = "alloc")] + { + result.map_err(|e| Error::Conv { + reason: e.to_string(), + }) + } + #[cfg(not(feature = "alloc"))] + { + result.map_err({ Error::Conv }) + } +} + +#[cfg(test)] +mod linear_convolve { + use super::*; + use alloc::vec; + use ndarray::array; + + #[test] + fn full() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![0., 1., 2.5, 4., 1.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Full).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn same() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![1., 2.5, 4.]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Same).unwrap(); + assert_eq!(result, expected); + } + + #[test] + fn valid() { + let a = array![1., 2., 3.]; + let v = array![0., 1., 0.5]; + + let expected = array![2.5]; + let result = convolve((&a).into(), (&v).into(), ConvolveMode::Valid).unwrap(); + assert_eq!(result, expected); + } +} diff --git a/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs new file mode 100644 index 00000000..b0f38c79 --- /dev/null +++ b/sci-rs-core/src/num_rs/convolve/ndarray_conv_binds.rs @@ -0,0 +1,12 @@ +use super::ConvolveMode; +use ndarray_conv::ConvMode; + +impl From for ConvMode { + fn from(value: ConvolveMode) -> Self { + match value { + ConvolveMode::Full => ConvMode::Full, + ConvolveMode::Same => ConvMode::Same, + ConvolveMode::Valid => ConvMode::Valid, + } + } +} diff --git a/sci-rs-core/src/num_rs/mod.rs b/sci-rs-core/src/num_rs/mod.rs new file mode 100644 index 00000000..e6e9679c --- /dev/null +++ b/sci-rs-core/src/num_rs/mod.rs @@ -0,0 +1,4 @@ +#[cfg(feature = "alloc")] +mod convolve; +#[cfg(feature = "alloc")] +pub use convolve::*; diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 697e7933..26db45e0 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -19,10 +19,10 @@ all-features = true default = ['alloc'] # Allow allocating vecs, matrices, etc. -alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc'] +alloc = ['nalgebra/alloc', 'nalgebra/libm', 'kalmanfilt/alloc', 'sci-rs-core/alloc'] # Enable FFT and standard library features -std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc'] +std = ['nalgebra/std', 'nalgebra/macros', 'rustfft', 'alloc','sci-rs-core/std'] # Enable debug plotting through python system calls plot = ['std'] @@ -36,12 +36,13 @@ lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } gaussfilt = { version = "0.1.3", default-features = false } +sci-rs-core = { path = "../sci-rs-core", default-features = false } [dev-dependencies] approx = "0.5.1" dasp_signal = { version = "0.11.0" } criterion = { version = "0.4", features = ["html_reports"] } -rand = "0.8.4" +rand = "0.9.2" [[bench]] name = "sosfilt" @@ -50,3 +51,7 @@ harness = false [[bench]] name = "sosfiltfilt" harness = false + +[[bench]] +name = "lfilter" +harness = false diff --git a/sci-rs/benches/lfilter.rs b/sci-rs/benches/lfilter.rs new file mode 100644 index 00000000..4d41ff1f --- /dev/null +++ b/sci-rs/benches/lfilter.rs @@ -0,0 +1,108 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use ndarray::{array, Array1, ArrayView1}; +use rand::rngs::ThreadRng; +use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +use sci_rs::signal::filter::LFilter; +use sci_rs::signal::windows::Hamming; +use std::num::NonZeroUsize; + +/// Get a randomized signal from instance of `rng`. +fn randomized_signal( + mut rng: ThreadRng, + num_freqs: NonZeroUsize, + num_data_points: NonZeroUsize, + time_seconds: f64, +) -> (Array1, Array1) { + use rand::Rng; + + let nf: usize = num_freqs.into(); // Num freqs + let n: usize = num_data_points.into(); // Num data points + let num_secs: f64 = time_seconds; // this corresponds to how much N data points corresponds to in time. + let nyq_freq: f64 = 0.5 * n as f64 / num_secs; // Do not generated frequency should be larger than this + + let t: Array1 = Array1::linspace(0.0, num_secs, n); // Time-axis + let global_shift = t.mapv(|ti| 7.0 * (-ti / 2.0).exp()); // Total offset independent of sinusoidal generation + + let ampl = { + let orig: Array1 = Array1::from_iter((0..nf).map(|_| rng.random_range(0.5..1.5))); + let decay: Array1 = Array1::from_iter((0..nf).map(|i| 1. / (1.1f64.powf(i as f64)))); // Weight + + decay * orig + }; + let freqs = { + let mut freqs = Vec::with_capacity(nf); + const INITIAL_FREQ: f64 = 1.2; + freqs.push(INITIAL_FREQ); + (1..nf).fold(INITIAL_FREQ, |acc, _| { + let next_freq = (acc * rng.random_range(1.01..2.01)) % nyq_freq; // Approximately double wrt the previous frequency. + freqs.push(next_freq); + next_freq + }); + freqs + }; + let phases: Vec<_> = (0..nf) + .map(|_| rng.random_range(0.0..std::f64::consts::PI)) + .collect(); + + let mut result: Array1 = Array1::zeros((n,)) + global_shift; + + for ((a, freq), p) in ampl + .into_iter() + .zip(freqs.into_iter()) + .zip(phases.into_iter()) + { + let wave = t.mapv(|ti| a * (freq * ti + p).sin()); + result += &wave; + } + + (t, result) +} + +/// Test the filter with zi. +/// +/// Use window asasociated with `decimate`'s default values, running at decimation factor = 500. +fn lfilter_dyn(c: &mut Criterion) { + const DECIMATION_FACTOR: usize = 50; + const FILTER_ORDER: usize = DECIMATION_FACTOR * 20; + + // Finite impulse response from Hamming window; + let b: Array1 = firwin_dyn( + FILTER_ORDER + 1, + &[1. / (DECIMATION_FACTOR as f64)], + None, + None::<&Hamming>, + &FilterBandType::Lowpass, + None, + None, + ) + .unwrap() + .into(); + let a = array![1.]; + + let (_, signal) = randomized_signal( + rand::rng(), + NonZeroUsize::new(14).unwrap(), + NonZeroUsize::new(1 << 16).unwrap(), + 15., + ); + + // Apply with criterion + c.bench_with_input( + BenchmarkId::new("lfilter_dyn", DECIMATION_FACTOR), + &signal, + |bench, sig| { + bench.iter(|| { + ArrayView1::lfilter( + black_box((&b).into()), + black_box((&a).into()), + black_box((sig).into()), + None, + None, + ) + }) + }, + ); +} + +criterion_group!(benches, lfilter_dyn); +criterion_main!(benches); diff --git a/sci-rs/src/error/mod.rs b/sci-rs/src/error/mod.rs new file mode 100644 index 00000000..c4fe9ade --- /dev/null +++ b/sci-rs/src/error/mod.rs @@ -0,0 +1,27 @@ +use core::{error, fmt}; + +/// Errors raised whilst running sci-rs. +#[derive(Debug, PartialEq, Eq)] +#[cfg(feature = "alloc")] +pub enum Error { + /// Argument parsed into function were invalid. + InvalidArg { + /// The invalid arg + arg: alloc::string::String, + /// Explaining why arg is invalid. + reason: alloc::string::String, + }, + /// Two or more optional arguments passed into functions conflict. + ConflictArg { + /// Explaining what arg is invalid. + reason: alloc::string::String, + }, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + todo!() + } +} + +impl error::Error for Error {} diff --git a/sci-rs/src/lib.rs b/sci-rs/src/lib.rs index 7e47de09..c19eec65 100644 --- a/sci-rs/src/lib.rs +++ b/sci-rs/src/lib.rs @@ -41,3 +41,6 @@ pub mod special; /// Debug plotting #[cfg(feature = "plot")] pub mod plot; + +/// Errors +pub mod error; diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs index b527805e..1d667168 100644 --- a/sci-rs/src/signal/convolve.rs +++ b/sci-rs/src/signal/convolve.rs @@ -2,15 +2,7 @@ use nalgebra::Complex; use num_traits::{Float, FromPrimitive, Signed, Zero}; use rustfft::{FftNum, FftPlanner}; -/// Convolution mode determines behavior near edges and output size -pub enum ConvolveMode { - /// Full convolution, output size is `in1.len() + in2.len() - 1` - Full, - /// Valid convolution, output size is `max(in1.len(), in2.len()) - min(in1.len(), in2.len()) + 1` - Valid, - /// Same convolution, output size is `in1.len()` - Same, -} +pub use sci_rs_core::num_rs::ConvolveMode; /// Performs FFT-based convolution on two slices of floating point values. /// @@ -173,12 +165,12 @@ mod tests { #[test] #[cfg(feature = "plot")] fn test_scipy_example() { - use rand::distributions::{Distribution, Standard}; - use rand::thread_rng; + use rand::distr::{Distribution, StandardUniform}; + use rand::rng; // Generate 1000 random samples from standard normal distribution - let mut rng = thread_rng(); - let sig: Vec = Standard.sample_iter(&mut rng).take(1000).collect(); + let mut rng = rng(); + let sig: Vec = StandardUniform.sample_iter(&mut rng).take(1000).collect(); // Compute autocorrelation using correlate directly let autocorr = correlate(&sig, &sig, ConvolveMode::Full); diff --git a/sci-rs/src/signal/filter/arraytools.rs b/sci-rs/src/signal/filter/arraytools.rs new file mode 100644 index 00000000..029ef172 --- /dev/null +++ b/sci-rs/src/signal/filter/arraytools.rs @@ -0,0 +1,113 @@ +//! Functions for acting on a axis of an array. +//! +//! Designed for ndarrays; with scipy's internal nomenclature. + +use ndarray::{ArrayBase, Axis, Data, Dim, Dimension, IntoDimension, Ix, RemoveAxis}; +use sci_rs_core::{Error, Result}; + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +/// +/// # Notes +/// Const nature of this function means error has to be manually created. +#[inline] +pub(crate) const fn check_and_get_axis_st<'a, T, S, const N: usize>( + axis: Option, + x: &ArrayBase>, +) -> core::result::Result +where + S: Data + 'a, +{ + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + match axis { + None => (), + Some(axis) if axis.is_negative() => { + if axis.unsigned_abs() > N { + return Err(()); + } + } + Some(axis) => { + if axis.unsigned_abs() >= N { + return Err(()); + } + } + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = match axis { + Some(axis) => axis, + None => -1, + }; + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = N + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for casting into [Axis] and appropriate usize from isize. +/// [check_and_get_axis_st] but without const, especially for IxDyn arrays. +/// +/// # Parameters +/// axis: The user-specificed axis which filter is to be applied on. +/// x: The input-data whose axis object that will be manipulated against. +#[inline] +pub(crate) fn check_and_get_axis_dyn<'a, T, S, D>( + axis: Option, + x: &ArrayBase, +) -> Result +where + D: Dimension, + S: Data + 'a, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + // Before we convert into the appropriate axis object, we have to check at runtime that the + // axis value specified is within -N <= axis < N. + if axis.is_some_and(|axis| { + !(if axis < 0 { + axis.unsigned_abs() <= ndim + } else { + axis.unsigned_abs() < ndim + }) + }) { + return Err(Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + }); + } + + // We make a best effort to convert into appropriate axis object. + let axis_inner: isize = axis.unwrap_or(-1); + if axis_inner >= 0 { + Ok(axis_inner.unsigned_abs()) + } else { + let axis_inner = ndim + .checked_add_signed(axis_inner) + .expect("Invalid add to `axis` option"); + Ok(axis_inner) + } +} + +/// Internal function for obtaining length of all axis as array from input from input. +/// +/// This is almost the same as `a.shape()`, but is a array `[T; N]` instead of a slice `&[T]`. +/// +/// # Parameters +/// `a`: Array whose shape is needed as a slice. +pub(crate) fn ndarray_shape_as_array_st<'a, S, T, const N: usize>( + a: &ArrayBase>, +) -> [Ix; N] +where + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + S: Data + 'a, +{ + a.shape().try_into().expect("Could not cast shape to array") +} diff --git a/sci-rs/src/signal/filter/design/firwin.rs b/sci-rs/src/signal/filter/design/firwin.rs new file mode 100644 index 00000000..125f6820 --- /dev/null +++ b/sci-rs/src/signal/filter/design/firwin.rs @@ -0,0 +1,883 @@ +use super::filter_type::FilterBandType; +use super::iirfilter_dyn; +use super::{kaiser_atten, kaiser_beta}; +use crate::signal::windows::get_window; +use crate::{error, error::Error, special}; +use core::cmp::Ord; +use nalgebra::RealField; +use num_traits::{real::Real, Float, MulAdd, Pow}; + +#[cfg(feature = "alloc")] +use crate::signal::{ + windows, + windows::{GetWindow, GetWindowBuilder}, + windows::{Hamming, Kaiser}, +}; +#[cfg(feature = "alloc")] +use alloc::{string::String, vec, vec::Vec}; + +/// Validation of [firwin_dyn] input. +fn firwin_dyn_validate( + numtaps: &usize, + cutoff: &[F], + pass_zero: &FilterBandType, + width: &Option, + window: &Option<&impl GetWindow>, +) -> Result<(), Error> { + if cutoff.is_empty() { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "At least one cutoff frequency must be given.".into(), + }); + } + if *numtaps == 0 { + return Err(Error::InvalidArg { + arg: "numtaps".into(), + reason: "Invalid numtaps: Nonzero-numtaps is expected!".into(), + }); + } + + // Whilst it may be faster to write + // if *cutoff.iter().min().unwrap() <= F::zero() || *cutoff.iter().max().unwrap() >= F::one() { + // + // vec.min() requires the Ord trait, which is very difficult to impose onto f64 + let minimal = cutoff.iter().fold(F::max_value(), |acc, &m| acc.min(m)); + let maximal = cutoff.iter().fold(F::min_value(), |acc, &m| acc.max(m)); + if minimal <= F::zero() || maximal >= F::one() { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: + "Invalid cutoff frequency: frequencies must be greater than 0 and less than fs/2." + .into(), + }); + } + if cutoff.windows(2).any(|x| x[0] >= x[1]) { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "Invalid cutoff frequencies: the frequencies must be strictly increasing." + .into(), + }); + } + + match pass_zero { + FilterBandType::Lowpass => { + if cutoff.len() != 1 { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have one element if pass_zero is Lowpass.".into(), + }); + } + } + FilterBandType::Bandstop => { + if cutoff.len() < 2 { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have at least two elements if pass_zero is Bandstop." + .into(), + }); + } + } + FilterBandType::Highpass => { + if cutoff.len() != 1 { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have one element if pass_zero is Highpass.".into(), + }); + } + } + FilterBandType::Bandpass => { + if cutoff.len() < 2 { + return Err(Error::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have at least two elements if pass_zero is Bandpass." + .into(), + }); + } + } + } + + // While this was silently ignored in Scipy, we make this explicit here. + // Cannot use != on &impl GetWindow. + if window.is_some() && width.is_some() { + return Err(Error::InvalidArg { + arg: "window/width".into(), + reason: "Setting both window and width to something is silently ignored only in Scipy." + .into(), + }); + } + + // This is here only because impl GetWindow on GetWindowBuilder is non-trivial to fetch numtaps + // outside the struct/enum-variants. + // if let Some(w) = *window { + // if w.m != numtaps { + // return Err(Error::ConfictArg { + // arg: "window".into(), + // reason: "Window has m value differing from numtaps", + // }); + // } + // } + Ok(()) +} + +/// FIR filter design using the window method. +/// +/// This function computes the coefficients of a finite impulse response filter. The filter will +/// have linear phase; it will be Type I if `numtaps` is odd and Type II if `numtaps` is even. +/// +/// Type II filters always have zero response at the Nyquist frequency, so a [error::Error] is +/// returned if firwin is called with `numtaps` even and having a passband whose right end is at +/// the Nyquist frequency. +/// +/// # Parameters +/// * `numtaps`: usize +/// Length of the filter (number of coefficients, i.e. the filter order + 1). `numtaps` must +/// be odd if a passband includes the Nyquist frequency. +/// * `cutoff`: 1-D array_like +/// Cutoff frequency of filter (expressed in the same units as `fs`) OR an array of cutoff +/// frequencies (that is, band edges). In the latter case, the frequencies in `cutoff` should +/// be positive and monotonically increasing between 0 and `fs/2`. The values 0 and `fs/2` must +/// not be included in `cutoff`. +/// * `width`: float or None, optional +/// If `width` is not None, then assume it is the approximate width of the transition region +/// (expressed in the same units as `fs`) for use in Kaiser FIR filter design. In this case, +/// the `window` argument is ignored. +/// * `window` : string or tuple of string and parameter values, optional +/// Desired window to use. See [GetWindow] for a list of windows and required parameters. +/// Defaults to Hamming. +/// Please set `sym=True` if you wish to have similar behaviour to scipy. +/// * `pass_zero` : [FilterBandType], optional +/// If True, the gain at the frequency 0 (i.e., the "DC gain") is 1. +/// If False, the DC gain is +/// 0. Can also be a string argument for the desired filter type (equivalent to ``btype`` in +/// [IIR design functions](iirfilter_dyn)). +/// +/// * `scale` : bool, optional +/// Set to True to scale the coefficients so that the frequency response is exactly unity at a +/// certain frequency. That frequency is either: +/// +/// * 0 (DC) if the first passband starts at 0 (i.e. pass_zero +/// is True) +/// * `fs/2` (the Nyquist frequency) if the first passband ends at +/// `fs/2` (i.e the filter is a single band highpass filter); +/// center of first passband otherwise +/// +/// * fs : float, optional +/// The sampling frequency of the signal. Each frequency in `cutoff` +/// must be between 0 and ``fs/2``. Default is 2. +/// +/// Returns +/// ------- +/// `h` : Result<`W`, _> +/// Coefficients of length `numtaps` FIR filter. +/// +/// Raises +/// ------ +/// Error : +/// If any value in `cutoff` is less than or equal to 0 or greater than or equal to ``fs/2``, +/// if the values in `cutoff` are not strictly monotonically increasing, or if `numtaps` is +/// even but a passband includes the Nyquist frequency. +/// +/// See Also +/// -------- +/// firwin2 +/// firls +/// minimum_phase +/// remez +/// +/// Examples +/// -------- +/// * Low-pass from 0 to f: +/// +/// ```custom,{class=language-python} +/// >>> from scipy import signal +/// >>> numtaps = 3 +/// >>> f = 0.1 +/// >>> signal.firwin(numtaps, f) +/// array([ 0.06799017, 0.86401967, 0.06799017]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx:: assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Lowpass, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![0.06799017, 0.86401967, 0.06799017]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// Use a specific window function: +/// +/// ```custom,{class=language-python} +/// >>> signal.firwin(numtaps, f, window='nuttall') +/// array([ 3.56607041e-04, 9.99286786e-01, 3.56607041e-04]) +/// ``` +/// Sci-rs: +// TODO: This still needs work so that its not Some(&Nuttall::new(...)) which might be mistakenly +// different from what's above it. +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Nuttall; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64], +/// None, +/// Some(&Nuttall::new(3, true)), +/// &FilterBandType::Lowpass, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![3.56607041e-04, 9.99286786e-01, 3.56607041e-04]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// High-pass ('stop' from 0 to f): +/// +/// ```custom,{class=language-python} +/// >>> signal.firwin(numtaps, f, pass_zero=False) +/// array([-0.00859313, 0.98281375, -0.00859313]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Highpass, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![-0.00859313, 0.98281375, -0.00859313]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// Band-pass: +/// +/// ```custom,{class=language-python} +/// >>> f1, f2 = 0.1, 0.2 +/// >>> signal.firwin(numtaps, [f1, f2], pass_zero=False) +/// array([ 0.06301614, 0.88770441, 0.06301614]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64, 0.2f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Bandpass, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![ 0.06301614, 0.88770441, 0.06301614]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// Band-stop: +/// +/// ```custom,{class=language-python} +/// >>> signal.firwin(numtaps, [f1, f2]) +/// array([-0.00801395, 1.0160279 , -0.00801395]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64, 0.2f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Bandstop, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![-0.00801395, 1.0160279 , -0.00801395]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// Multi-band (passbands are `[0, f1]`, `[f2, f3]` and `[f4, 1]`): +/// +/// ```custom,{class=language-python} +/// >>> f3, f4 = 0.3, 0.4 +/// >>> signal.firwin(numtaps, [f1, f2, f3, f4]) +/// array([-0.01376344, 1.02752689, -0.01376344]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64, 0.2f64, 0.3f64, 0.4f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Bandstop, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![-0.01376344, 1.02752689, -0.01376344]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// Multi-band (passbands are `[f1, f2]` and `[f3,f4]`): +/// +/// ```custom,{class=language-python} +/// >>> signal.firwin(numtaps, [f1, f2, f3, f4], pass_zero=False) +/// array([ 0.04890915, 0.91284326, 0.04890915]) +/// ``` +/// Sci-rs: +/// ``` +/// use approx::assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::Hamming; +/// +/// let window: Vec = firwin_dyn( +/// 3, +/// &[0.1f64, 0.2f64, 0.3f64, 0.4f64], +/// None, +/// None::<&Hamming>, +/// &FilterBandType::Bandpass, +/// None, +/// None, +/// ) +/// .unwrap(); +/// let expected = vec![0.04890915, 0.91284326, 0.04890915]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +// In accordance with https://github.com/scipy/scipy/pull/16315, nyq as an argument should not be +// provided. +#[cfg(feature = "alloc")] +pub fn firwin_dyn( + numtaps: usize, + cutoff: &[F], + width: Option, + window: Option<&impl GetWindow>, + pass_zero: &FilterBandType, // Union with bools to follow scipy? + scale: Option, + fs: Option, +) -> Result, Error> +where + W: Real + Float + RealField + special::Bessel, + F: Real + PartialOrd + Float + RealField + MulAdd + Pow, +{ + let fs = match fs { + None => F::from(2).unwrap(), + Some(x) => x, + }; + let nyq = fs / F::from(2).unwrap(); + let cutoff: Vec<_> = cutoff.iter().map(|&c| c / nyq).collect(); + + firwin_dyn_validate(&numtaps, &cutoff, pass_zero, &width, &window)?; + + // Get and apply the window function. + let win: Vec = if let Some(x) = window { + // ?warn if window.sym != true + x.get_window() + } else if let Some(width) = width { + let atten = kaiser_atten(numtaps.try_into().unwrap(), width / nyq); + let beta = kaiser_beta(atten); + let k = get_window(GetWindowBuilder::Kaiser { beta }, numtaps, Some(false)); + k.get_window() + } else { + let h = get_window(GetWindowBuilder::::Hamming, numtaps, Some(false)); + h.get_window() + }; + + let pass_zero: bool = match pass_zero { + FilterBandType::Lowpass => true, + FilterBandType::Bandstop => true, + FilterBandType::Highpass => false, + FilterBandType::Bandpass => false, + }; + let pass_nyquist = (cutoff.len() % 2 == 1) ^ pass_zero; + if pass_nyquist && numtaps.is_multiple_of(2) { + return Err(Error::InvalidArg { + arg: "numtaps".into(), + reason: "A filter with an even number of coefficients must have zero response at the Nyquist frequency." + .into(), + }); + } + + let cutoff: Vec = { + let mut tmp = Vec::::new(); + if pass_zero { + tmp.push(F::zero()); + } + tmp.extend_from_slice(&cutoff); + if pass_nyquist { + tmp.push(F::one()); + } + tmp + }; + if !cutoff.len().is_multiple_of(2) { + unreachable!(); + // return Err(Error::InvalidArg { + // arg: "cutoff".into(), + // reason: "Parity of cutoff given for type of Filter is incorrect.".into(), + // }); + } + let bands: Vec<_> = cutoff.chunks_exact(2).collect(); + let scale_frequency = { + // for use in scale branch later + let left = bands[0][0]; + let right = bands[0][1]; + if left == F::zero() { + F::zero() + } else if right == F::one() { + F::one() + } else { + F::from(0.5).unwrap() * (left + right) + } + }; + + // Build up the coefficients. + // alpha: scalar = 0.5 * (numtaps - 1) + let alpha = F::from(0.5).unwrap() * F::from(numtaps - 1).unwrap(); + // m: 1Darray[numtaps] = 0..numtaps - alpha // lifetimes + // h: 1Darray[numtaps] = sum([right * sinc(right *m) - left * sinc(left * m)) for (left, right) + // in bands]) + let h: Vec = bands + .into_iter() + .map(|b| { + let left = b[0]; + let right = b[1]; + let m = (0..numtaps).map(|mi| F::from(mi).unwrap() - alpha); + let h: Vec = m + .map(|mi| { + if !mi.is_zero() { + right * (right * mi * F::pi()).sinc() - left * (left * mi * F::pi()).sinc() + } else { + right - left + } + }) + .collect(); + h + }) + .fold(vec![W::zero(); numtaps], |mut acc, x| { + acc.iter_mut() + .zip(x) + .map(|(a, xi)| *a += W::from(xi).unwrap()) + .collect::<()>(); + acc + }); + let mut h: Vec = h.into_iter().zip(win).map(|(hi, wi)| hi * wi).collect(); + + let scale = scale.unwrap_or(true); + if scale { + // m: 1Darray[numtaps] = 0..numtaps - alpha // lifetimes + let m = (0..numtaps).map(|mi| F::from(mi).unwrap() - alpha); + // c: 1Darray[numtaps] = np.cos(np.pi * m * scale_frequency) + let c = m + .map(|mi| F::pi() * mi * scale_frequency) + .map(|x| W::from(Real::cos(x)).unwrap()); + // s: scalar = np.sum(h*c) + let s: W = h + .iter() + .zip(c) + .map(|(&hi, ci)| hi * ci) + .fold(W::zero(), |acc, x| acc + x); + // h: 1Darray[numtaps] = h/s + h.iter_mut().map(|hi| *hi /= s).collect::<()>(); + } + + Ok(h) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::signal::windows; + use approx::assert_abs_diff_eq; + + #[test] + fn invalid_args() { + type E = crate::error::Error; + { + let cutoff = Vec::::new(); + let empty_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Lowpass, + None, + None, + ); + assert_eq!( + empty_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "At least one cutoff frequency must be given.".into() + } + ); + } + { + let cutoff: Vec = vec![-0.1, 3.]; + let invalid_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandpass, + None, + None, + ); + assert_eq!( + invalid_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "Invalid cutoff frequency: frequencies must be greater than 0 and less than fs/2.".into() + } + ); + } + { + let cutoff: Vec = vec![0.2, 0.2]; + let decreasing_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandpass, + None, + None, + ); + assert_eq!( + decreasing_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: + "Invalid cutoff frequencies: the frequencies must be strictly increasing." + .into() + } + ); + } + { + // numtaps = 0 has ValueError in Python + let cutoff: Vec = vec![0.2, 0.2]; + let decreasing_cutoff: Result, E> = firwin_dyn( + 0, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandpass, + None, + None, + ); + assert_eq!( + decreasing_cutoff.unwrap_err(), + E::InvalidArg { + arg: "numtaps".into(), + reason: "Invalid numtaps: Nonzero-numtaps is expected!".into() + } + ); + } + { + // Lowpass + let cutoff: Vec = vec![0.2, 0.7]; + let lowpass_invalid_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Lowpass, + None, + None, + ); + assert_eq!( + lowpass_invalid_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have one element if pass_zero is Lowpass.".into() + } + ); + } + { + // Bandstop + let cutoff: Vec = vec![0.7]; + let bandstop_invalid_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandstop, + None, + None, + ); + assert_eq!( + bandstop_invalid_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have at least two elements if pass_zero is Bandstop." + .into() + } + ); + } + { + // Highpass + let cutoff: Vec = vec![0.2, 0.7]; + let highpass_invalid_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Highpass, + None, + None, + ); + assert_eq!( + highpass_invalid_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have one element if pass_zero is Highpass.".into() + } + ); + } + { + // Bandpass + let cutoff: Vec = vec![0.2]; + let bandpass_invalid_cutoff: Result, E> = firwin_dyn( + 3, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandpass, + None, + None, + ); + assert_eq!( + bandpass_invalid_cutoff.unwrap_err(), + E::InvalidArg { + arg: "cutoff".into(), + reason: "cutoff must have at least two elements if pass_zero is Bandpass." + .into() + } + ); + } + { + let cutoff: Vec = vec![0.2, 0.7]; + let invalid_numtaps: Result, E> = firwin_dyn( + 4, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandstop, + None, + None, + ); + assert_eq!( + invalid_numtaps.unwrap_err(), + E::InvalidArg { + arg: "numtaps".into(), + reason: "A filter with an even number of coefficients must have zero response at the Nyquist frequency.".into() + } + ); + } + } + + #[test] + fn conflicting_args() { + type E = crate::error::Error; + { + let cutoff: Vec = vec![0.2, 0.5, 0.7]; + let window = Hamming::new(12, true); + let conflicting: Result, E> = firwin_dyn( + 5, + &cutoff, + Some(0.3), + Some(&window), + &FilterBandType::Bandpass, + None, + None, + ); + assert_eq!( + conflicting.unwrap_err(), + E::InvalidArg { + arg: "window/width".into(), + reason: "Setting both window and width to something is silently ignored only in Scipy." + .into() + } + ); + } + } + + #[test] + fn bandstop() { + // from scipy.signal import firwin + // firwin(numtaps=5, cutoff=[0.2, 0.7], width=None, # window=None, + // pass_zero='bandstop', scale=True, fs=None) + let cutoff: Vec = vec![0.2, 0.7]; + let window: Result, _> = firwin_dyn( + 5, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandstop, + None, + None, + ); + let expected = vec![0.05126868, -0.08050021, 1.05846306, -0.08050021, 0.05126868]; + assert_vec_eq(expected, window.unwrap()); + } + + #[test] + fn bandpass() { + // from scipy.signal import firwin + // firwin(numtaps=5, cutoff=[0.2, 0.7], width=None, # window=None, + // pass_zero='bandpass', scale=True, fs=None) + let cutoff: Vec = vec![0.2, 0.7]; + let window: Result, _> = firwin_dyn( + 5, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Bandpass, + None, + None, + ); + let expected = vec![-0.04340507, 0.06815306, 0.89611567, 0.06815306, -0.04340507]; + assert_vec_eq(expected, window.unwrap()); + } + + #[test] + fn lowpass() { + // from scipy.signal import firwin + // firwin(numtaps=5, cutoff=[0.2], width=None, # window=None, + // pass_zero='lowpass', scale=True, fs=None) + let cutoff: Vec = vec![0.2]; + let window: Result, _> = firwin_dyn( + 5, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Lowpass, + None, + None, + ); + let expected = vec![0.02840647, 0.23700821, 0.46917063, 0.23700821, 0.02840647]; + assert_vec_eq(expected, window.unwrap()); + } + + #[test] + fn highpass() { + // from scipy.signal import firwin + // firwin(numtaps=5, cutoff=[0.2], width=None, # window=None, + // pass_zero='highpass', scale=True, fs=None) + let cutoff: Vec = vec![0.2]; + let window: Result, _> = firwin_dyn( + 5, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Highpass, + None, + None, + ); + let expected = vec![-0.01238356, -0.1033217, 0.81812371, -0.1033217, -0.01238356]; + assert_vec_eq(expected, window.unwrap()); + } + + #[test] + fn different_fs() { + // from scipy.signal import firwin + // firwin(numtaps=5, cutoff=[2], width=None, # window=None, + // pass_zero='lowpass', scale=True, fs=10) + let cutoff: Vec = vec![2.]; + let window: Result, _> = firwin_dyn( + 5, + &cutoff, + None, + None::<&Hamming>, + &FilterBandType::Lowpass, + None, + Some(10.), + ); + let expected = vec![0.01008727, 0.22034079, 0.53914388, 0.22034079, 0.01008727]; + assert_vec_eq(expected, window.unwrap()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/filter/design/kaiser.rs b/sci-rs/src/signal/filter/design/kaiser.rs new file mode 100644 index 00000000..c2c1fb43 --- /dev/null +++ b/sci-rs/src/signal/filter/design/kaiser.rs @@ -0,0 +1,244 @@ +use nalgebra::RealField; +use num_traits::{real::Real, MulAdd, Pow, ToPrimitive}; + +/// Compute the Kaiser parameter `beta`, given the attenuation `a`. +/// +/// # Parameters +/// * `a`: float +/// The desired attenuation in the stopband and maximum ripple in the passband, in dB. This +/// should be a *positive* number. +/// +/// # Returns +/// * `beta`: float +/// The `beta` parameter to be used in the formula for a Kaiser window. +/// +/// # References +/// Oppenheim, Schafer, "Discrete-Time Signal Processing", p.475-476. +/// +/// Examples +/// -------- +/// Suppose we want to design a lowpass filter, with 65 dB attenuation +/// in the stop band. The Kaiser window parameter to be used in the +/// window method is computed by ``kaiser_beta(65.)``: +/// +/// ``` +/// use sci_rs::signal::filter::design::kaiser_beta; +/// assert_eq!(6.20426, kaiser_beta(65.)); +/// ``` +/// # See Also +/// [kaiser_atten], [kaiserord] +pub fn kaiser_beta(a: F) -> F +where + F: Real + MulAdd + Pow, + >::Output: MulAdd, +{ + if a > F::from(50).unwrap() { + F::from(0.1102).unwrap() * (a - F::from(8.7).unwrap()) + } else if a > F::from(21).unwrap() { + let a = a - F::from(21).unwrap(); + + MulAdd::mul_add( + a.pow(F::from(0.4).unwrap()), + F::from(0.5842).unwrap(), + F::from(0.07886).unwrap() * a, + ) + } else if a > F::zero() { + F::zero() + } else { + panic!("Expected a positive input.") + } +} + +/// Compute the attenuation of a Kaiser FIR filter. +/// +/// Given the number of taps `N` and the transition width `width`, compute the +/// attenuation `a` in dB, given by Kaiser's formula: +/// ```custom +/// a = 2.285 * (N - 1) * pi * width + 7.95 +/// ``` +/// +/// # Parameters +/// * `numtaps`: int +/// The number of taps in the FIR filter. +/// * `width`: float +/// The desired width of the transition region between passband and +/// stopband (or, in general, at any discontinuity) for the filter, +/// expressed as a fraction of the Nyquist frequency. +/// +/// # Returns +/// `a`: The attenuation of the ripple, in dB. +/// +/// # Examples +/// Suppose we want to design a FIR filter using the Kaiser window method +/// that will have 211 taps and a transition width of 9 Hz for a signal that +/// is sampled at 480 Hz. Expressed as a fraction of the Nyquist frequency, +/// the width is 9/(0.5*480) = 0.0375. The approximate attenuation (in dB) +/// is computed as follows: +/// ``` +/// use sci_rs::signal::filter::design::kaiser_atten; +/// assert_eq!(64.48099630593983 , kaiser_atten(211, 0.0375)); +/// ``` +/// +/// # See Also +/// [kaiserord], [kaiser_beta] +pub fn kaiser_atten(numtaps: u32, width: F) -> F +where + F: Real + MulAdd + RealField, +{ + MulAdd::mul_add( + width, + F::from(numtaps - 1).unwrap() * F::from(2.285).unwrap() * F::pi(), + F::from(7.95).unwrap(), + ) +} + +/// Determine the filter window parameters for the Kaiser window method. +/// +/// The parameters returned by this function are generally used to create +/// a finite impulse response filter using the window method, with either +/// `firwin` or `firwin2`. +/// +/// # Parameters +/// * `ripple`: float +/// Upper bound for the deviation (in dB) of the magnitude of the +/// filter's frequency response from that of the desired filter (not +/// including frequencies in any transition intervals). That is, if w +/// is the frequency expressed as a fraction of the Nyquist frequency, +/// A(w) is the actual frequency response of the filter and D(w) is the +/// desired frequency response, the design requirement is that: +/// ```abs(A(w) - D(w))) < 10**(-ripple/20)``` +/// for 0 <= w <= 1 and w not in a transition interval. +/// * `width`: float +/// Width of transition region, normalized so that 1 corresponds to pi +/// radians / sample. That is, the frequency is expressed as a fraction +/// of the Nyquist frequency. +/// +/// # Returns +/// * `numtaps`: int +/// The length of the Kaiser window. +/// * `beta`: float +/// The beta parameter for the Kaiser window. +/// +/// # Notes +/// There are several ways to obtain the Kaiser window: +/// +/// - ``signal.windows.kaiser(numtaps, beta, sym=True)`` +/// - ``signal.get_window(beta, numtaps)`` +/// - ``signal.get_window(('kaiser', beta), numtaps)`` +/// +/// The empirical equations discovered by Kaiser are used. +/// +/// # References +/// Oppenheim, Schafer, "Discrete-Time Signal Processing", pp.475-476. +/// +/// # Scipy Example +/// We will use the Kaiser window method to design a lowpass FIR filter +/// for a signal that is sampled at 1000 Hz. +/// +/// We want at least 65 dB rejection in the stop band, and in the pass +/// band the gain should vary no more than 0.5%. +/// +/// We want a cutoff frequency of 175 Hz, with a transition between the +/// pass band and the stop band of 24 Hz. That is, in the band [0, 163], +/// the gain varies no more than 0.5%, and in the band [187, 500], the +/// signal is attenuated by at least 65 dB. +/// +/// ```custom,{class=language-python} +/// >>> import numpy as np +/// >>> from scipy.signal import kaiserord, firwin, freqz +/// >>> import matplotlib.pyplot as plt +/// >>> fs = 1000.0 +/// >>> cutoff = 175 +/// >>> width = 24 +/// ``` +/// +/// The Kaiser method accepts just a single parameter to control the pass +/// band ripple and the stop band rejection, so we use the more restrictive +/// of the two. In this case, the pass band ripple is 0.005, or 46.02 dB, +/// so we will use 65 dB as the design parameter. +/// +/// Use `kaiserord` to determine the length of the filter and the +/// parameter for the Kaiser window. +/// +/// ```custom,{class=language-python} +/// >>> numtaps, beta = kaiserord(65, width/(0.5*fs)) +/// >>> numtaps +/// 167 +/// >>> beta +/// 6.20426 +/// ``` +/// +/// Use `firwin` to create the FIR filter. +/// +/// ```custom,{class=language-python} +/// >>> taps = firwin(numtaps, cutoff, window=('kaiser', beta), +/// ... scale=False, fs=fs) +/// ``` +/// +/// Compute the frequency response of the filter. ``w`` is the array of +/// frequencies, and ``h`` is the corresponding complex array of frequency +/// responses. +/// +/// ```custom,{class=language-python} +/// >>> w, h = freqz(taps, worN=8000) +/// >>> w *= 0.5*fs/np.pi # Convert w to Hz. +/// ``` +/// +/// Compute the deviation of the magnitude of the filter's response from +/// that of the ideal lowpass filter. Values in the transition region are +/// set to ``nan``, so they won't appear in the plot. +/// +/// ```custom,{class=language-python} +/// >>> ideal = w < cutoff # The "ideal" frequency response. +/// >>> deviation = np.abs(np.abs(h) - ideal) +/// >>> deviation[(w > cutoff - 0.5*width) & (w < cutoff + 0.5*width)] = np.nan +/// ``` +/// +/// Plot the deviation. A close look at the left end of the stop band shows +/// that the requirement for 65 dB attenuation is violated in the first lobe +/// by about 0.125 dB. This is not unusual for the Kaiser window method. +/// +/// ```custom,{class=language-python} +/// >>> plt.plot(w, 20*np.log10(np.abs(deviation))) +/// >>> plt.xlim(0, 0.5*fs) +/// >>> plt.ylim(-90, -60) +/// >>> plt.grid(alpha=0.25) +/// >>> plt.axhline(-65, color='r', ls='--', alpha=0.3) +/// >>> plt.xlabel('Frequency (Hz)') +/// >>> plt.ylabel('Deviation from ideal (dB)') +/// >>> plt.title('Lowpass Filter Frequency Response') +/// >>> plt.show() +/// ``` +/// +/// See Also +/// -------- +/// [kaiser_beta], [kaiser_atten] +/// +pub fn kaiserord(ripple: F, width: F) -> (F, F) +where + F: Real + MulAdd + Pow + RealField, +{ + let a = Real::abs(ripple); + if a < F::from(8).unwrap() { + panic!("Requested maximum ripple attenuation is too small for the Kaiser formula."); + } + let beta = kaiser_beta(a); + let numtaps = + F::one() + (a - F::from(7.95).unwrap()) / (F::from(2.285).unwrap() * F::pi() * width); + (Real::ceil(numtaps), beta) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn kaiserord_test_a() { + let fs = 1000.; + let cutoff = 175.; + let width = 24.; + let ripple = 65.; + + assert_eq!((167., 6.20426), kaiserord(ripple, width / (0.5 * fs))); + } +} diff --git a/sci-rs/src/signal/filter/design/mod.rs b/sci-rs/src/signal/filter/design/mod.rs index 9079bec7..192bba47 100644 --- a/sci-rs/src/signal/filter/design/mod.rs +++ b/sci-rs/src/signal/filter/design/mod.rs @@ -3,7 +3,9 @@ mod butter; mod cplx; mod filter_output; mod filter_type; +mod firwin; mod iirfilter; +mod kaiser; mod lp2bp_zpk; mod lp2bs_zpk; mod lp2hp_zpk; @@ -18,7 +20,9 @@ pub use butter::*; use cplx::*; pub use filter_output::*; pub use filter_type::*; +pub use firwin::*; pub use iirfilter::*; +pub use kaiser::*; pub use lp2bp_zpk::*; pub use lp2bs_zpk::*; pub use lp2hp_zpk::*; diff --git a/sci-rs/src/signal/filter/lfilter.rs b/sci-rs/src/signal/filter/lfilter.rs new file mode 100644 index 00000000..32abdfb5 --- /dev/null +++ b/sci-rs/src/signal/filter/lfilter.rs @@ -0,0 +1,992 @@ +use super::arraytools::{check_and_get_axis_dyn, check_and_get_axis_st, ndarray_shape_as_array_st}; +use alloc::{vec, vec::Vec}; +use core::marker::Copy; +use ndarray::{ + Array, Array1, ArrayBase, ArrayD, ArrayView, ArrayView1, Axis, Data, Dim, Dimension, + IntoDimension, Ix, IxDyn, ShapeBuilder, SliceArg, SliceInfo, SliceInfoElem, +}; +use num_traits::{FromPrimitive, Num, NumAssign}; +use sci_rs_core::{Error, Result}; + +type LFilterResult = (Array>, Option>>); +type LFilterDynResult = (Array, Option>); + +/// Implement lfilter for fixed dimension of input array `x`. +/// +/// Valid only from 1 to 6 dimensional arrays. +pub trait LFilter +where + S: Data, +{ + /// Filter data `x` along one-dimension with an IIR or FIR filter. + /// + /// Filter a data sequence, `x`, using a digital filter. This works for many + /// fundamental data types (including Object type). The filter is a direct + /// form II transposed implementation of the standard difference equation + /// (see Notes). + /// + /// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be + /// preferred over `lfilter` for most filtering tasks, as second-order sections + /// have fewer numerical problems. + /// + /// ## Parameters + /// * `b` : array_like + /// The numerator coefficient vector in a 1-D sequence. + /// * `a` : array_like + /// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` + /// is not 1, then both `a` and `b` are normalized by ``a[0]``. + /// * `x` : array_like + /// An N-dimensional input array. + /// * `axis`: `Option` + /// Default to `-1` if `None`. + /// Panics in accordance with [ndarray::ArrayBase::axis_iter]. + /// * `zi`: array_like + /// Currently not implemented. + /// Initial conditions for filter delays. It is a vector + /// (or array of vectors for an N-dimensional input) of length + /// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then + /// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. + /// + /// ## Returns + /// * `y` : array + /// The output of the digital filter. + /// * `zf` : array, optional + /// If `zi` is None, this is not returned, otherwise, `zf` holds the + /// final filter delay values. + /// + /// # See Also + /// * [super::lfilter_zi_dyn] + /// + /// # Notes + /// For compile time reasons, lfilter is implemented per ArrayN at the moment. + /// + /// # Examples + /// On a 1-dimensional signal: + /// ``` + /// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; + /// use sci_rs::signal::filter::LFilter; + /// + /// let b = array![5., 4., 1., 2.]; + /// let a = array![1.]; + /// let x = array![1., 2., 3., 4., 3., 5., 6.]; + /// let expected = array![5., 14., 24., 36., 38., 47., 61.]; + /// let (result, _) = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), None, None).unwrap(); // By ref + /// + /// assert_eq!(result.len(), expected.len()); + /// result.into_iter().zip(expected).for_each(|(r, e)| { + /// assert_eq!(r, e); + /// }); + /// + /// let (result, _) = Array1::lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value + /// ``` + /// + /// # Panics + /// Currently yet to implement for `a.len() > 1`. + // NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of + // documentation, both lfilter_zi and this should eventually support NDArray. + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: Data + 'a; +} + +macro_rules! lfilter_for_dim { + ($N:literal) => { + impl LFilter for ArrayBase> + where + S: Data, + { + fn lfilter<'a>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: Self, + axis: Option, + zi: Option>>, + ) -> Result<(Array>, Option>>)> + where + T: NumAssign + FromPrimitive + Copy + 'a, + S: 'a, + { + if a.len() > 1 { + return linear_filter(b, a, x, axis, zi); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_st(axis, &x) + .map_err(|_| Error::InvalidArg { + arg: "axis".into(), + reason: "index out of range.".into(), + })?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.reborrow(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: [usize; $N] = x.shape().try_into().unwrap(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: [Ix; $N] = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..$N) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap?.try_into().unwrap() + }; + + zi = ArrayView::from_shape(expected_shape.strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp), tmp) + }; + + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out_full = unsafe { Array::uninit(out_full_dim).assume_init() }; + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)? + .assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = zi.shape()[axis_inner]; + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + let tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, Dim<[Ix; $N]>, Dim<[Ix; $N]>> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = [SliceInfoElem::from(..); $N]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner): (Dim<_>, [Ix; $N]) = { + let mut tmp: [Ix; $N] = ndarray_shape_as_array_st(&x); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } + } + } + }; +} + +lfilter_for_dim!(1); +lfilter_for_dim!(2); +lfilter_for_dim!(3); +lfilter_for_dim!(4); +lfilter_for_dim!(5); +lfilter_for_dim!(6); + +/// Filter data `x` along one-dimension with an IIR or FIR filter. +/// +/// Filter a data sequence, `x`, using a digital filter. This works for many +/// fundamental data types (including Object type). The filter is a direct +/// form II transposed implementation of the standard difference equation +/// (see Notes). +/// +/// The function [super::sosfilt_dyn] (and filter design using ``output='sos'``) should be +/// preferred over `lfilter` for most filtering tasks, as second-order sections +/// have fewer numerical problems. +/// +/// ## Parameters +/// * `b` : array_like +/// The numerator coefficient vector in a 1-D sequence. +/// * `a` : array_like +/// The denominator coefficient vector in a 1-D sequence. If ``a[0]`` +/// is not 1, then both `a` and `b` are normalized by ``a[0]``. +/// * `x` : array_like +/// An N-dimensional input array. +/// * `axis`: `Option` +/// Default to `-1` if `None`. +/// Panics in accordance with [ndarray::ArrayBase::axis_iter]. +/// * `zi`: array_like +/// Currently not implemented. +/// Initial conditions for filter delays. It is a vector +/// (or array of vectors for an N-dimensional input) of length +/// ``max(len(a), len(b)) - 1``. If `zi` is None or is not given then +/// initial rest is assumed. See `lfiltic` and [super::lfilter_zi_dyn] for more information. +/// +/// ## Returns +/// * `y` : array +/// The output of the digital filter. +/// * `zf` : array, optional +/// If `zi` is None, this is not returned, otherwise, `zf` holds the +/// final filter delay values. +/// +/// # See Also +/// * [super::lfilter_zi_dyn] +/// +/// # Notes +/// If Array<_, IxDyn as provided by this function is not desired, consider using [LFilter]. +/// +/// # Examples +/// On a 1-dimensional signal: +/// ``` +/// use ndarray::{array, ArrayBase, Array1, ArrayView1, Dim, Ix, OwnedRepr}; +/// use sci_rs::signal::filter::lfilter; +/// +/// let b = array![5., 4., 1., 2.]; +/// let a = array![1.]; +/// let x = array![1., 2., 3., 4., 3., 5., 6.]; +/// let expected = array![5., 14., 24., 36., 38., 47., 61.]; +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.view(), None, None).unwrap(); // By ref +/// +/// assert_eq!(result.len(), expected.len()); +/// result.into_iter().zip(expected).for_each(|(r, e)| { +/// assert_eq!(r, e); +/// }); +/// +/// let (result, _) = lfilter((&b).into(), (&a).into(), x.clone().into_dyn(), None, None).unwrap(); // Dynamic arrays +/// let (result, _) = lfilter((&b).into(), (&a).into(), x, None, None).unwrap(); // By value +/// ``` +/// +/// # Panics +/// Currently yet to implement for `a.len() > 1`. +// NOTE: zi's TypeSig inherits from lfilter's output, in accordance with examples section of +// documentation, both lfilter_zi and this should eventually support NDArray. +pub fn lfilter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + S: Data + 'a, + T: NumAssign + FromPrimitive + Copy + 'a, + D: Dimension, + SliceInfo, D, D>: SliceArg, +{ + let ndim = D::NDIM.unwrap_or(x.ndim()); + + if ndim == 0 { + return Err(Error::InvalidArg { + arg: "x".into(), + reason: "Linear filter requires at least 1-dimensional `x`.".into(), + }); + } + + if a.len() > 1 { + todo!(); + }; + + let (axis, axis_inner) = { + let ax = check_and_get_axis_dyn(axis, &x)?; + (Axis(ax), ax) + }; + + if a.is_empty() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: + "Empty 1D array will result in inf/nan result. Consider setting to `array![1.]`." + .into(), + }); + } else if a.first().unwrap().is_zero() { + return Err(Error::InvalidArg { + arg: "a".into(), + reason: "First element of a found to be zero.".into(), + }); + } + let b: Array1 = b.mapv(|bi| bi / a[0]); // b /= a[0] + + if let Some(zii) = zi { + // Use a separate branch to avoid unnecessary heap allocation of `out_full` in `zi` = None + // case. + let mut zi = zii.clone().reborrow().into_dyn(); + + // if zi.ndim != x.ndim { return Err(...) } is signature asserted. + + let mut expected_shape: Vec = x.shape().to_vec(); + *expected_shape // expected_shape[axis] = b.shape[0] - 1 + .get_mut(axis_inner) + .expect("invalid axis_inner") = b + .shape() + .first() + .expect("Could not get 0th axis len of b") + .checked_sub(1) + .expect("underflowing subtract"); + + if *zi.shape() != expected_shape { + let strides: Vec = { + let zi_shape = zi.shape(); + let zi_strides = zi.strides(); + + // Waiting for try_collect() from nightly... we use this Vec> -> Result> method.. + let tmp_heap: Vec> = (0..ndim) + .map(|k| { + if zi_shape[k] == expected_shape[k] { + zi_strides[k].try_into().map_err(|_| Error::InvalidArg { + arg: "zi".into(), + reason: "zi found with negative stride".into(), + }) + } else if k != axis_inner && zi_shape[k] == 1 { + Ok(0) + } else { + Err(Error::InvalidArg { + arg: "zi".into(), + reason: "Unexpected shape for parameter zi".into(), + }) + } + }) + .collect(); + let tmp_heap: Result> = tmp_heap.into_iter().collect(); + + tmp_heap? + }; + + // ArrayView::from_shape(strides, + // zi.as_slice_memory_order().unwrap()).unwrap().to_owned() + zi = ArrayView::from_shape((expected_shape).strides(strides), zii.as_slice().unwrap()) + .unwrap(); + }; + + let (out_full_dim, out_full_dim_inner): (Dim<_>, Vec) = { + let mut tmp = x.shape().to_vec(); + tmp[axis_inner] += b.len_of(Axis(0)) - 1; // From np.convolve(..., 'full') + (IntoDimension::into_dimension(tmp.as_ref()), tmp) + }; + + let mut out_full = ArrayD::::zeros(out_full_dim); + out_full + .lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_full_slice, y)| { + // np.convolve uses full mode by default + // ```py + // out_full = np.apply_along_axis(lambda y: np.convolve(b, y), axis, x) + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + convolve(y, (&b).into(), ConvolveMode::Full)?.assign_to(&mut out_full_slice); + Ok(()) + })?; + + // ```py + // ind[axis] = slice(zi.shape[axis]) + // out_full[tuple(ind)] += zi + // ``` + { + let slice_info: SliceInfo<_, D, D> = { + let t = zi.shape()[axis_inner]; + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: 0, + end: Some(t as isize), + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; // Does not work because unless N: N<=6 cannot be bounded on type_sig + let mut s = out_full.slice_mut(&slice_info); + s += &zi; + } + + let (out_dim, out_dim_inner) = { + // let mut out_dim_inner = out_full_dim_inner; + // if let Some(inner) = out_dim_inner.get_mut(axis_inner) { + // *inner = inner + // .checked_sub({ + // // Safety: b is Array1 + // *b.shape().first().unwrap() + // }) + // // Safety: inner is defined by having added b.len() + // .unwrap() + // + 1; + // } else { + // unsafe { unreachable_unchecked() }; + // }; + // (IntoDimension::into_dimension(out_dim_inner), out_dim_inner) + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + // Safety: All elements are overwritten by convolve in subsequent step. + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; + out.lanes_mut(axis) + .into_iter() + .zip(out_full.lanes(axis)) + .for_each(|(mut out_slice, out_full_slice)| { + // ```py + // # Create the [...; :out_full.shape[axis] - len(b) + 1; ...] at index=axis + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) + // out = out_full[tuple(ind)] + // ``` + out_full_slice + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + }); + + // ```py + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1, None) + // zf = out_full[tuple(ind)] + // ``` + let zf = { + let slice_info: SliceInfo<_, D, IxDyn> = { + let t = out_full.shape()[axis_inner] + .checked_add(1) + .unwrap() + .checked_sub(b.len()) + .unwrap(); + let mut tmp = vec![SliceInfoElem::from(..); ndim]; + tmp[axis_inner] = SliceInfoElem::Slice { + start: t as isize, + end: None, + step: 1, + }; + + SliceInfo::try_from(tmp).unwrap() + }; + out_full.slice(slice_info).to_owned() + }; + + Ok((out, Some(zf))) + } else { + // In contrast to the case where zi.is_some(), we can inline a slicing operation to reduce + // one extra heap allocation. + + let (out_dim, out_dim_inner) = { + let tmp = x.shape(); + (IntoDimension::into_dimension(tmp), tmp) + }; + let mut out = unsafe { Array::uninit(out_dim).assume_init() }; // Safety: All elements are overwritten by convolve in subsequent step. + + out.lanes_mut(axis) + .into_iter() + .zip(x.lanes(axis)) // Almost basically np.apply_along_axis + .try_for_each(|(mut out_slice, y)| { + // np.convolve uses full mode, but is eventually slices out with + // ```py + // ind = out_full.ndim * [slice(None)] # creates the "[:, :, ..., :]" slice r + // ind[axis] = slice(out_full.shape[axis] - len(b) + 1) # [:out_full.shape[ ..] - len(b) + 1] + // ``` + use sci_rs_core::num_rs::{convolve, ConvolveMode}; + let out_full = convolve(y, (&b).into(), ConvolveMode::Full)?; + out_full + .slice( + SliceInfo::try_from([SliceInfoElem::Slice { + start: 0, + end: Some(out_dim_inner[axis_inner] as isize), + step: 1, + }]) + .unwrap(), + ) + .assign_to(&mut out_slice); + Ok(()) + })?; + + Ok((out, None)) + } +} + +/// Internal function called by [LFilter::lfilter] for situation a.len() > 1. +fn linear_filter<'a, T, S, D>( + b: ArrayView1<'a, T>, + a: ArrayView1<'a, T>, + x: ArrayBase, + axis: Option, + zi: Option>, +) -> Result> +where + D: Dimension, + T: 'a, + S: Data + 'a, +{ + todo!() +} + +#[cfg(test)] +mod test { + use super::*; + use alloc::vec; + use approx::assert_relative_eq; + use ndarray::{array, ArrayBase, Dim, Ix, OwnedRepr, ViewRepr}; + + // Tests that have a = [1.] with zi = None on input x with dim = 1. + #[test] + fn one_dim_fir_no_zi() { + { + // Tests for b.sum() > 1. + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![5., 14., 24., 36., 38., 47., 61.]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_eq!(r, e); + }) + } + { + // Tests for b[i] < 0 for some i, such that b.sum() = 1. + let b = array![0.7, -0.3, 0.6]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + let expected = array![0.7, 1.1, 2.1, 3.1, 2.7, 5., 4.5]; + + let Ok((result, None)) = Array1::lfilter((&b).into(), (&a).into(), x, None, None) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn one_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + None, + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 2]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(0), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = Array::<_, Dim<[Ix; 3]>>::lfilter( + (&b).into(), + (&a).into(), + x, + Some(1), + Some((&zi).into()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } + + #[test] + fn invalid_axis() { + let b = array![5., 4., 1., 2.]; + let a = array![1.]; + let x = array![1., 2., 3., 4., 3., 5., 6.]; + + let result = ArrayView1::lfilter((&b).into(), (&a).into(), (&x).into(), Some(2), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(1), None); + assert!(result.is_err()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(0), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x.clone(), Some(-1), None); + assert!(result.is_ok()); + + let result = Array1::lfilter((&b).into(), (&a).into(), x, Some(-2), None); + assert!(result.is_err()); + } + + #[test] + fn dyn_dim_fir_with_zi() { + { + // Case which does not falls into zi.shape() != expected_shape branch + let b = array![0.5, 0.4]; + let a = array![1.]; + let x = array![ + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + [-4., -3., -1., -2., 1., 2., -3., 4., 3., 5., 6., 7., -8., 1.], + ]; + let zi = array![[-1.6], [1.4]]; + let expected = array![ + [-3.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7], + [-0.6, -3.1, -1.7, -1.4, -0.3, 1.4, -0.7, 0.8, 3.1, 3.7, 5., 5.9, -1.2, -2.7] + ]; + let expected_zi = array![[0.4], [0.4]]; + + // Test static dim input + let Ok((result, Some(r_zi))) = + lfilter((&b).into(), (&a).into(), x.view(), None, Some((&zi).into())) + else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(&expected).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(&expected_zi).for_each(|(r, &e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + + // Test dyn input + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + None, + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]]; + let zi = array![[0.4], [0.45], [0.05]]; + let expected = array![ + [5.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.85, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85], + ]; + let expected_zi = array![ + [4.25, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.], + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(0), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + { + // Case which does falls into zi.shape() != expected_shape branch for 3D input + let b = array![5., 0.4, 1., -2.]; + let a = array![1.]; + let x = array![ + [[0.2, 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]], + [[1., 2., 3., 4., 3., 5., 6.], [8., 0., 1., 0., 3., 7., 6.]] + ]; + let zi = array![[[0.4], [0.45], [0.05]], [[0.6], [0.15], [0.25]]]; + let expected = array![ + [ + [1.4, 10.4, 15.4, 20.4, 15.4, 25.4, 30.4], + [40.53, 1.25, 6.65, 2.05, 16.65, 37.45, 32.85] + ], + [ + [5.6, 10.6, 15.6, 20.6, 15.6, 25.6, 30.6], + [40.55, 0.95, 6.35, 1.75, 16.35, 37.15, 32.55] + ] + ]; + let expected_zi = array![ + [ + [3.45, 2.05, 3.45, 4.05, 4.25, 7.85, 8.45], + [7.6, -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ], + [ + [4.45, 2.25, 3.65, 4.25, 4.45, 8.05, 8.65], + [6., -4., -5., -8., -3., -3., -6.], + [-16., 0., -2., 0., -6., -14., -12.] + ] + ]; + + let Ok((result, Some(r_zi))) = lfilter( + (&b).into(), + (&a).into(), + x.into_dyn(), + Some(1), + Some(zi.into_dyn().view()), + ) else { + panic!("Should not have errored") + }; + + assert_eq!(result.len(), expected.len()); + result.into_iter().zip(expected).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }); + assert_eq!(r_zi.len(), expected_zi.len()); + r_zi.into_iter().zip(expected_zi).for_each(|(r, e)| { + assert_relative_eq!(r, e, max_relative = 1e-6); + }) + } + } +} diff --git a/sci-rs/src/signal/filter/lfilter_zi.rs b/sci-rs/src/signal/filter/lfilter_zi.rs index 6c92df5a..48296030 100644 --- a/sci-rs/src/signal/filter/lfilter_zi.rs +++ b/sci-rs/src/signal/filter/lfilter_zi.rs @@ -9,20 +9,17 @@ use alloc::vec; #[cfg(feature = "alloc")] use alloc::vec::Vec; +/// Construct initial conditions for [lfilter][super::lfilter::LFilter] for step response +/// steady-state. /// -/// Construct initial conditions for lfilter for step response steady-state. -/// -/// Compute an initial state `zi` for the `lfilter` function that corresponds -/// to the steady state of the step response. +/// Compute an initial state `zi` for the `lfilter` function that corresponds to the steady state +/// of the step response. /// /// A typical use of this function is to set the initial state so that the /// output of the filter starts at the same value as the first element of /// the signal to be filtered. /// -/// /// -/// -/// #[inline] pub fn lfilter_zi_dyn(b: &[F], a: &[F]) -> Vec where @@ -38,7 +35,7 @@ where .expect("There must be at least one nonzero `a` coefficient.") .0; - // Mormalize to a[0] == 1 + // Normalize to a[0] == 1 let mut a = a.iter().skip(ai0).cloned().collect::>(); let mut b = b.to_vec(); let a0 = a[0]; diff --git a/sci-rs/src/signal/filter/mod.rs b/sci-rs/src/signal/filter/mod.rs index 805c65b7..ec0c5f2c 100644 --- a/sci-rs/src/signal/filter/mod.rs +++ b/sci-rs/src/signal/filter/mod.rs @@ -13,7 +13,9 @@ pub use kalmanfilt::kalman::kalman_filter; /// pub use gaussfilt as gaussian_filter; -/// Digital IIR/FIR filter design +/// Digital IIR/FIR filter design +/// Functions located in the [`Filter design` section of +/// `scipy.signal`](https://docs.scipy.org/doc/scipy/reference/signal.html#filter-design). pub mod design; mod ext; @@ -22,6 +24,13 @@ mod sosfilt; pub use ext::*; pub use sosfilt::*; +#[cfg(feature = "alloc")] +mod arraytools; +#[cfg(feature = "alloc")] +use arraytools::*; + +#[cfg(feature = "alloc")] +mod lfilter; #[cfg(feature = "alloc")] mod lfilter_zi; #[cfg(feature = "alloc")] @@ -31,6 +40,8 @@ mod sosfilt_zi; #[cfg(feature = "alloc")] mod sosfiltfilt; +#[cfg(feature = "alloc")] +pub use lfilter::*; #[cfg(feature = "alloc")] pub use lfilter_zi::*; #[cfg(feature = "alloc")] diff --git a/sci-rs/src/signal/mod.rs b/sci-rs/src/signal/mod.rs index 8aede037..5dd9a505 100644 --- a/sci-rs/src/signal/mod.rs +++ b/sci-rs/src/signal/mod.rs @@ -1,13 +1,32 @@ -/// Digital Filtering +/// Digital Filtering +/// Contains functions from [Filtering section of +/// `scipy.signal`](https://docs.scipy.org/doc/scipy/reference/signal.html#filtering). pub mod filter; -/// Signal Generation +/// Signal Generation +/// Contains functions from the [Waveforms section of +/// `scipy.signal`](). pub mod wave; -/// Convolution +/// Convolution +/// Contains functions from the [Convolution section of +/// `scipy.signal`](). #[cfg(feature = "std")] pub mod convolve; -/// Signal Resampling +/// Window functions +/// This contains all window functions in the +/// [`scipy.signal.windows`](https://docs.scipy.org/doc/scipy/reference/signal.windows.html#module-scipy.signal.windows) +/// namespace. +/// The convenience function +/// [`get_windows`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html#scipy.signal.get_window) +/// in the [scipy.signal](https://docs.scipy.org/doc/scipy/reference/signal.html#window-functions) +/// namespace is located here. +pub mod windows; + +/// Signal Resampling +/// This contains only the +/// [`resample`](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.resample.html#scipy.signal.resample) +/// function from `scipy.signal`. #[cfg(feature = "std")] pub mod resample; diff --git a/sci-rs/src/signal/resample.rs b/sci-rs/src/signal/resample.rs index 9296c1d6..a9b6ac94 100644 --- a/sci-rs/src/signal/resample.rs +++ b/sci-rs/src/signal/resample.rs @@ -87,22 +87,18 @@ mod tests { // Resample each to length 100 // Check that each resampled vector is length 100 - let mut rng = rand::thread_rng(); + let mut rng = rand::rng(); for i in (0..100) { - let len = rng.gen_range((10..50)); - let x = (0..len) - .map(|_| rng.gen_range((-100.0..100.))) - .collect::>(); + let len = rng.random_range(10..50); + let x: Vec<_> = (0..len).map(|_| rng.random_range(-100.0..100.)).collect(); let y = resample(&x, 100); assert_eq!(y.len(), 100); } for i in (0..50) { - let len = rng.gen_range((200..10000)); - let target_len = rng.gen_range((50..50000)); - let x = (0..len) - .map(|_| rng.gen_range((-100.0..100.))) - .collect::>(); + let len = rng.random_range(200..10000); + let target_len = rng.random_range(50..50000); + let x: Vec<_> = (0..len).map(|_| rng.random_range(-100.0..100.)).collect(); let y = resample(&x, target_len); assert_eq!(y.len(), target_len); } diff --git a/sci-rs/src/signal/wave/mod.rs b/sci-rs/src/signal/wave/mod.rs index 87372fcc..90bd9ea0 100644 --- a/sci-rs/src/signal/wave/mod.rs +++ b/sci-rs/src/signal/wave/mod.rs @@ -1,7 +1,6 @@ use nalgebra::RealField; use ndarray::{Array, ArrayBase, Data, Dimension, RawData}; -/// """ /// Return a periodic square-wave waveform. /// /// The square wave has a period ``2*pi``, has value +1 from 0 to @@ -30,15 +29,18 @@ use ndarray::{Array, ArrayBase, Data, Dimension, RawData}; /// -------- /// A 5 Hz waveform sampled at 500 Hz for 1 second: /// +/// ```custom,{class=language-python} /// >>> import numpy as np /// >>> from scipy import signal /// >>> import matplotlib.pyplot as plt /// >>> t = np.linspace(0, 1, 500, endpoint=False) /// >>> plt.plot(t, signal.square(2 * np.pi * 5 * t)) /// >>> plt.ylim(-2, 2) +/// ``` /// /// A pulse-width modulated sine wave: /// +/// ```custom,{class=language-python} /// >>> plt.figure() /// >>> sig = np.sin(2 * np.pi * t) /// >>> pwm = signal.square(2 * np.pi * 30 * t, duty=(sig + 1)/2) @@ -47,8 +49,7 @@ use ndarray::{Array, ArrayBase, Data, Dimension, RawData}; /// >>> plt.subplot(2, 1, 2) /// >>> plt.plot(t, pwm) /// >>> plt.ylim(-1.5, 1.5) -/// -/// """ +/// ``` pub fn square(t: &ArrayBase, duty: F) -> Array where F: RealField, diff --git a/sci-rs/src/signal/windows/blackman.rs b/sci-rs/src/signal/windows/blackman.rs new file mode 100644 index 00000000..5f6f2718 --- /dev/null +++ b/sci-rs/src/signal/windows/blackman.rs @@ -0,0 +1,120 @@ +use super::GeneralCosine; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `Blackman` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct Blackman { + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Blackman { + /// Returns a Blackman struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, sym: bool) -> Self { + Blackman { m, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for Blackman +where + W: Real + Float + RealField, +{ + /// Return a window of type: Blackman. + /// + /// The Blackman window is a taper formed by using the first three terms of a summation of + /// cosines. It was designed to have close to the minimal leakage possible. It is close to + /// optimal, only slightly worse than a Kaiser window. + /// + /// # Parameters + /// `self`: [Blackman] + /// + /// # Returns + /// `w`: `vec` + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # Example + /// ``` + /// use sci_rs::signal::windows::{Blackman, GetWindow}; + /// + /// let nx = 8; + /// let b = Blackman::new(nx, true); + /// ``` + /// + /// # References + /// + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + GeneralCosine::new( + self.m, + [0.42, 0.50, 0.08] + .into_iter() + .map(|n| W::from(n).unwrap()) + .collect(), + self.sym, + ) + .get_window() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn blackman_20() { + // Created with + // >>> from scipy.signal.windows import blackman + // >>> blackman(20) + let expected = vec![ + -1.38777878e-17, + 1.02226199e-02, + 4.50685843e-02, + 1.14390287e-01, + 2.26899356e-01, + 3.82380768e-01, + 5.66665187e-01, + 7.52034438e-01, + 9.03492728e-01, + 9.88846031e-01, + 9.88846031e-01, + 9.03492728e-01, + 7.52034438e-01, + 5.66665187e-01, + 3.82380768e-01, + 2.26899356e-01, + 1.14390287e-01, + 4.50685843e-02, + 1.02226199e-02, + -1.38777878e-17, + ]; + assert_vec_eq(expected, Blackman::new(20, true).get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/boxcar.rs b/sci-rs/src/signal/windows/boxcar.rs new file mode 100644 index 00000000..5c5faaaa --- /dev/null +++ b/sci-rs/src/signal/windows/boxcar.rs @@ -0,0 +1,83 @@ +use super::{extend, len_guard}; +use num_traits::real::Real; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `Boxcar` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct Boxcar { + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Whether the window is symmetric. (Has no effect for boxcar.) + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Boxcar { + /// Returns a Boxcar struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, sym: bool) -> Self { + Boxcar { m, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for Boxcar +where + W: Real, +{ + /// Return a window of type: boxcar or rectangular window. + /// + /// Also known as a rectangular window or Dirichlet window, this is equivalent to no window at + /// all. + /// + /// # Parameters + /// `self`: [Boxcar] + /// + /// # Example + /// ``` + /// use sci_rs::signal::windows::{Boxcar, GetWindow}; + /// + /// let nx = 5; + /// let boxcar = Boxcar::new(nx, false); + /// let window: Vec = boxcar.get_window(); + /// assert_eq!(vec![1.; nx], window); + /// ``` + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + if len_guard(self.m) { + return Vec::::new(); + } + let (m, needs_trunc) = extend(self.m, self.sym); + + if !needs_trunc { + vec![W::one(); m] + } else { + vec![W::one(); m - 1] + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn case_a() { + let nx = 5; + let boxcar = Boxcar::new(nx, false); + let window: Vec = boxcar.get_window(); + assert_eq!(vec![1.; nx], window); + } +} diff --git a/sci-rs/src/signal/windows/general_cosine.rs b/sci-rs/src/signal/windows/general_cosine.rs new file mode 100644 index 00000000..3ac09dcc --- /dev/null +++ b/sci-rs/src/signal/windows/general_cosine.rs @@ -0,0 +1,213 @@ +use super::{extend, len_guard, truncate}; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `GeneralCosine` for use in [GetWindow]. +#[cfg(feature = "alloc")] +#[derive(Debug, Clone, PartialEq)] +pub struct GeneralCosine +where + F: Real, +{ + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Sequence of weighting coefficients. + pub a: Vec, // Is there a better type, such as impl Into? + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl GeneralCosine +where + F: Real, +{ + /// Returns a GeneralCosine struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `a`: `Vec` + /// Sequence of weighting coefficients. This uses the convention of being centered on the + /// origin, so these will typically all be positive numbers, not alternating sign. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, a: Vec, sym: bool) -> Self { + GeneralCosine { m, a, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for GeneralCosine +where + F: Real, + W: Real + Float + RealField, +{ + /// Return a window of type: GeneralCosine. + /// + /// The general cosine window is a generic weighted sum of cosine terms. + /// + /// # Parameters + /// `self`: [GeneralCosine] + /// + /// # Returns + /// `w`: `vec` + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # Example + /// We can create a `flat-top` window named "HFT90D" as following: + /// ``` + /// use sci_rs::signal::windows::{GeneralCosine, GetWindow}; + /// + /// let hfd90 = [1., 1.942604, 1.340318, 0.440811, 0.043097].into(); + /// let window: Vec = GeneralCosine::new(30, hfd90, true).get_window(); + /// ``` + /// + /// This is equivalent to the Python code: + /// ```custom,{class=language-python} + /// from scipy.signal.windows import general_cosine + /// HFT90D = [1, 1.942604, 1.340318, 0.440811, 0.043097] + /// window = general_cosine(30, HFT90D, sym=False) + /// ``` + /// + /// # References + /// + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + if len_guard(self.m) { + return Vec::::new(); + } + let (m, needs_trunc) = extend(self.m, self.sym); + + let linspace = (0..m).map(|i| W::from(i).unwrap()); + let fac = linspace.map(|i| W::two_pi() * i / W::from(m - 1).unwrap() - W::pi()); + let w: Vec<_> = self + .a + .iter() + .enumerate() + .map(|(k, a)| { + fac.clone() + .map(move |f| Float::cos(f * W::from(k).unwrap()) * W::from(*a).unwrap()) + }) + .fold( + vec![W::from(0).unwrap(); fac.clone().count()], + // Would this have made more sense if using ndarray::Array1? No, Array1 has no pop. + |acc, x| acc.iter().zip(x).map(|(&a, b)| a + b).collect(), + ) + .into_iter() + .collect(); + + truncate(w, needs_trunc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn general_cosine_scipy_eg() { + // Created with + // >>> from scipy.signal.windows import general_cosine + // >>> HFT90D = [1, 1.942604, 1.340318, 0.440811, 0.043097] + // >>> window = general_cosine(30, HFT90D, sym=False) + let expected = vec![ + -1.04083409e-16, + -3.49808964e-03, + -1.85322176e-02, + -5.60667246e-02, + -1.25488810e-01, + -2.22198500e-01, + -3.14696393e-01, + -3.38497088e-01, + -2.04818447e-01, + 1.72651725e-01, + 8.38783500e-01, + 1.76097559e+00, + 2.81469639e+00, + 3.80321808e+00, + 4.51005597e+00, + 4.76683000e+00, + 4.51005597e+00, + 3.80321808e+00, + 2.81469639e+00, + 1.76097559e+00, + 8.38783500e-01, + 1.72651725e-01, + -2.04818447e-01, + -3.38497088e-01, + -3.14696393e-01, + -2.22198500e-01, + -1.25488810e-01, + -5.60667246e-02, + -1.85322176e-02, + -3.49808964e-03, + ]; + + let hfd90 = [1., 1.942604, 1.340318, 0.440811, 0.043097].into(); + let gc = GeneralCosine::new(30, hfd90, false); + assert_vec_eq(expected, gc.get_window()); + } + + #[test] + fn constant() { + let x = 0.42; + let n = 6; + let expected = vec![x; n]; + + let actual: Vec = GeneralCosine::new(n, vec![x], true).get_window(); + assert_eq!(expected, actual); + } + + #[test] + fn case_b() { + // Created with + // >>> from scipy.signal.windows import general_cosine + // >>> n = 20 + // >>> a = [0.42, 0.50] + // >>> window = general_cosine(30, a, sym=False) + let n = 20; + let a = vec![0.42, 0.50]; + let expected = vec![ + -0.08, + -0.05552826, + 0.0154915, + 0.12610737, + 0.2654915, + 0.42, + 0.5745085, + 0.71389263, + 0.8245085, + 0.89552826, + 0.92, + 0.89552826, + 0.8245085, + 0.71389263, + 0.5745085, + 0.42, + 0.2654915, + 0.12610737, + 0.0154915, + -0.05552826, + ]; + + assert_vec_eq(expected, GeneralCosine::new(n, a, false).get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/general_gaussian.rs b/sci-rs/src/signal/windows/general_gaussian.rs new file mode 100644 index 00000000..9f19801a --- /dev/null +++ b/sci-rs/src/signal/windows/general_gaussian.rs @@ -0,0 +1,148 @@ +use super::{extend, len_guard, truncate}; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `GeneralCosine` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct GeneralGaussian +where + F: Real, +{ + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Shape parameter. + pub p: F, + /// The standard deviation, σ. + pub sigma: F, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl GeneralGaussian +where + F: Real, +{ + /// Returns a GeneralGaussian struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `p` : float + /// Shape parameter. p = 1 is identical to `gaussian`, p = 0.5 is + /// the same shape as the Laplace distribution. + /// * `sig` : float + /// The standard deviation, σ. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, p: F, sigma: F, sym: bool) -> Self { + GeneralGaussian { m, p, sigma, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for GeneralGaussian +where + F: Real, + W: Real, +{ + /// Return a window with a generalized gaussian shape. + /// + /// # Parameters + /// `self`: [GeneralGaussian] + /// + /// # Returns + /// `w`: `vec` + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # Notes + /// The generalized Gaussian window is defined as + /// $$w(n) = e^{ -\frac{1}{2}\left|\frac{n}{\sigma}\right|^{2p} }$$ + /// the half-power point is at + /// $$(2 \log(2))^{1/(2 p)} \sigma$$ + /// + /// # Example + /// ``` + /// use sci_rs::signal::windows::{GeneralGaussian, GetWindow}; + /// let window: Vec = GeneralGaussian::new(51, 1.5, 7., true).get_window(); + /// ``` + /// + /// This is equivalent to the Python code: + /// ```custom,{class=language-python} + /// from scipy import signal + /// window = signal.windows.general_gaussian(51, p=1.5, sig=7) + /// ``` + /// + /// # References + /// + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + if len_guard(self.m) { + return Vec::::new(); + } + let (m, needs_trunc) = extend(self.m, self.sym); + + let n = (0..m).map(|v| { + W::from(v).unwrap() - (W::from(m).unwrap() - W::from(1).unwrap()) / W::from(2).unwrap() + }); + let sig = W::from(self.sigma).unwrap(); + let two_p = W::from(self.p).unwrap() * W::from(2).unwrap(); + let w = n + .into_iter() + .map(|v| { + (v / sig) + .abs() + .powf(two_p) + .mul(W::from(-0.5).unwrap()) + .exp() + }) + .collect(); + + truncate(w, needs_trunc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn general_gaussian_case_a() { + let gc = GeneralGaussian::new(20, 0.8, 4.1, true); + let expected = vec![ + 0.14688425, 0.2008071, 0.2687279, 0.35164027, 0.44932387, 0.55972797, 0.67829536, + 0.79725628, 0.90478285, 0.98289486, 0.98289486, 0.90478285, 0.79725628, 0.67829536, + 0.55972797, 0.44932387, 0.35164027, 0.2687279, 0.2008071, 0.14688425, + ]; + + assert_vec_eq(expected, gc.get_window()); + } + + #[test] + fn general_gaussian_case_b() { + let gc = GeneralGaussian::new(20, 0.8, 4.1, false); + let expected = vec![ + 0.12465962, 0.17219048, 0.23293383, 0.30828863, 0.3987132, 0.5031538, 0.61839303, + 0.73835825, 0.85338105, 0.94904249, 1., 0.94904249, 0.85338105, 0.73835825, 0.61839303, + 0.5031538, 0.3987132, 0.30828863, 0.23293383, 0.17219048, + ]; + + assert_vec_eq(expected, gc.get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/general_hamming.rs b/sci-rs/src/signal/windows/general_hamming.rs new file mode 100644 index 00000000..0ad2b356 --- /dev/null +++ b/sci-rs/src/signal/windows/general_hamming.rs @@ -0,0 +1,192 @@ +use super::GeneralCosine; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Collection of arguments for window `GeneralHamming` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct GeneralHamming +where + F: Real, +{ + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// The window coefficient, α. + pub alpha: F, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl GeneralHamming +where + F: Real, +{ + /// Returns a GeneralHamming struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `alpha` : float + /// The window coefficient, α. + /// * `sym`: bool + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, alpha: F, sym: bool) -> Self { + GeneralHamming { m, alpha, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for GeneralHamming +where + F: Real, + W: Real + Float + RealField, +{ + /// Return a generalized Hamming window. + /// + /// The generalized Hamming window is constructed by multiplying a rectangular + /// window by one period of a cosine function [[1][1]]. + /// + /// # Parameters + /// * `M` : usize + /// Number of points in the output window. If zero, an empty array is returned. An + /// exception is thrown when it is negative. + /// * `alpha` : float + /// The window coefficient, α. + /// * `sym` : bool, optional + /// When True (default), generates a symmetric window, for use in filter design. + /// When False, generates a periodic window, for use in spectral analysis. + /// + /// # Returns + /// `w` : ndarray + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # See Also + /// hamming, hann + /// + /// # Notes + /// The generalized Hamming window is defined as + /// $$w(n) = \alpha - \left(1 - \alpha\right) + /// \cos\left(\frac{2\pi{n}}{M-1}\right) \qquad 0 \leq n \leq M-1$$ + /// Both the common Hamming window and Hann window are special cases of the generalized Hamming + /// window with :math:`\alpha` = 0.54 and :math:`\alpha` = 0.5, respectively [[2][2]]. + /// + /// # References + /// + /// [[1]] DSPRelated, "Generalized Hamming Window Family", + /// + /// [[2]] Wikipedia, "Window function", + /// [[3]] Riccardo Piantanida ESA, "Sentinel-1 Level 1 Detailed Algorithm Definition", + /// + /// [[4]] Matthieu Bourbigot ESA, "Sentinel-1 Product Definition", + /// + /// [[5]] Scipy, + /// + /// + /// # Examples + /// The Sentinel-1A/B Instrument Processing Facility uses generalized Hamming + /// windows in the processing of spaceborne Synthetic Aperture Radar (SAR) + /// data [[3][3]]. The facility uses various values for the :math:`\alpha` + /// parameter based on operating mode of the SAR instrument. Some common + /// :math:`\alpha` values include 0.75, 0.7 and 0.52 [[4][4]]. As an example, we + /// plot these different windows. + /// + /// ```custom,{class=language-python} + /// >>> import numpy as np + /// >>> from scipy.signal.windows import general_hamming + /// >>> from scipy.fft import fft, fftshift + /// >>> import matplotlib.pyplot as plt + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> fig1, spatial_plot = plt.subplots() + /// >>> spatial_plot.set_title("Generalized Hamming Windows") + /// >>> spatial_plot.set_ylabel("Amplitude") + /// >>> spatial_plot.set_xlabel("Sample") + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> fig2, freq_plot = plt.subplots() + /// >>> freq_plot.set_title("Frequency Responses") + /// >>> freq_plot.set_ylabel("Normalized magnitude [dB]") + /// >>> freq_plot.set_xlabel("Normalized frequency [cycles per sample]") + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> for alpha in [0.75, 0.7, 0.52]: + /// ... window = general_hamming(41, alpha) + /// ... spatial_plot.plot(window, label="{:.2f}".format(alpha)) + /// ... A = fft(window, 2048) / (len(window)/2.0) + /// ... freq = np.linspace(-0.5, 0.5, len(A)) + /// ... response = 20 * np.log10(np.abs(fftshift(A / abs(A).max()))) + /// ... freq_plot.plot(freq, response, label="{:.2f}".format(alpha)) + /// >>> freq_plot.legend(loc="upper right") + /// >>> spatial_plot.legend(loc="upper right") + /// ``` + /// + /// The equivalent is: + /// ``` + /// use sci_rs::signal::windows::{GetWindow, GeneralHamming}; + /// let window = GeneralHamming::new(41, 0.75, true); + /// ``` + /// + /// [1]: #references + /// [2]: #references + /// [3]: #references + /// [4]: #references + /// [5]: #references + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + GeneralCosine::new(self.m, [self.alpha, F::one() - self.alpha].into(), self.sym) + .get_window() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn general_hamming_case_a() { + // from scipy.signal.windows import general_hamming + // general_hamming(15, 0.65) + let gh = GeneralHamming::new(15, 0.65, true); + + let expected = [ + 0.3, 0.3346609, 0.43177857, 0.57211767, 0.72788233, 0.86822143, 0.9653391, 1., + 0.9653391, 0.86822143, 0.72788233, 0.57211767, 0.43177857, 0.3346609, 0.3, + ] + .into(); + assert_vec_eq(expected, gh.get_window()); + } + + #[test] + fn general_hamming_case_b() { + // from scipy.signal.windows import general_hamming + // general_hamming(15, 0.65) + let gh = GeneralHamming::new(15, 0.65, false); + + let expected = [ + 0.3, 0.33025909, 0.41580429, 0.54184405, 0.68658496, 0.825, 0.93315595, 0.99235166, + 0.99235166, 0.93315595, 0.825, 0.68658496, 0.54184405, 0.41580429, 0.33025909, + ] + .into(); + assert_vec_eq(expected, gh.get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/hamming.rs b/sci-rs/src/signal/windows/hamming.rs new file mode 100644 index 00000000..90016f5c --- /dev/null +++ b/sci-rs/src/signal/windows/hamming.rs @@ -0,0 +1,171 @@ +use super::GeneralHamming; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Collection of arguments for window `Hamming` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct Hamming { + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Hamming { + /// Returns a Hamming struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `sym`: bool + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, sym: bool) -> Self { + Hamming { m, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for Hamming +where + W: Real + Float + RealField, +{ + /// Return a Hamming window. + /// + /// The Hamming window is a taper formed by using a raised cosine with non-zero endpoints, + /// optimized to minimize the nearest side lobe. + /// + /// # Parameters + /// * `M` : int + /// Number of points in the output window. If zero, an empty array is returned. An + /// exception is thrown when it is negative. + /// * `sym` : bool, optional + /// When True (default), generates a symmetric window, for use in filter + /// design. + /// When False, generates a periodic window, for use in spectral analysis. + /// + /// # Returns + /// `w` : ndarray + /// The window, with the maximum value normalized to 1 (though the value 1 + /// does not appear if `M` is even and `sym` is True). + /// + /// # Notes + /// The Hamming window is defined as + /// + /// $$w(n) = 0.54 - 0.46 \cos\left(\frac{2\pi{n}}{M-1}\right) \qquad 0 \leq n \leq M-1$$ + /// + /// The Hamming was named for R. W. Hamming, an associate of J. W. Tukey and is described in + /// Blackman and Tukey. It was recommended for smoothing the truncated autocovariance function + /// in the time domain. Most references to the Hamming window come from the signal processing + /// literature, where it is used as one of many windowing functions for smoothing values. It + /// is also known as an apodization (which means "removing the foot", i.e. smoothing + /// discontinuities at the beginning and end of the sampled signal) or tapering function. + /// + /// # References + /// [[1]] Blackman, R.B. and Tukey, J.W., (1958) The measurement of power spectra, Dover + /// Publications, New York. + /// [[2]] E.R. Kanasewich, "Time Sequence Analysis in Geophysics", The University of Alberta + /// Press, 1975, pp. 109-110. + /// [[3]] Wikipedia, "Window function", + /// [[4]] W.H. Press, B.P. Flannery, S.A. Teukolsky, and W.T. Vetterling, "Numerical Recipes", + /// Cambridge University Press, 1986, page 425. + /// [[5]] Scipy, + /// + /// + /// Examples + /// -------- + /// Plot the window and its frequency response: + /// + /// ```custom,{class=language-python} + /// >>> import numpy as np + /// >>> from scipy import signal + /// >>> from scipy.fft import fft, fftshift + /// >>> import matplotlib.pyplot as plt + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> window = signal.windows.hamming(51) + /// >>> plt.plot(window) + /// >>> plt.title("Hamming window") + /// >>> plt.ylabel("Amplitude") + /// >>> plt.xlabel("Sample") + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> plt.figure() + /// >>> A = fft(window, 2048) / (len(window)/2.0) + /// >>> freq = np.linspace(-0.5, 0.5, len(A)) + /// >>> response = 20 * np.log10(np.abs(fftshift(A / abs(A).max()))) + /// >>> plt.plot(freq, response) + /// >>> plt.axis([-0.5, 0.5, -120, 0]) + /// >>> plt.title("Frequency response of the Hamming window") + /// >>> plt.ylabel("Normalized magnitude [dB]") + /// >>> plt.xlabel("Normalized frequency [cycles per sample]") + /// ``` + /// + /// The equivalent is: + /// ``` + /// use sci_rs::signal::windows::{GetWindow, Hamming}; + /// let window: Vec = Hamming::new(51, true).get_window(); + /// ``` + /// + /// [1]: #references + /// [2]: #references + /// [3]: #references + /// [4]: #references + /// [5]: #references + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + GeneralHamming::::new(self.m, W::from(0.54).unwrap(), self.sym).get_window() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn hamming_case_a() { + // from scipy.signal.windows import hamming + // hamming(17) + let h = Hamming::new(17, true); + let expected = [ + 0.08, 0.11501542, 0.21473088, 0.36396562, 0.54, 0.71603438, 0.86526912, 0.96498458, 1., + 0.96498458, 0.86526912, 0.71603438, 0.54, 0.36396562, 0.21473088, 0.11501542, 0.08, + ] + .into(); + + assert_vec_eq(expected, h.get_window()); + } + + #[test] + fn hamming_case_b() { + // from scipy.signal.windows import hamming + // hamming(17, false) + let h = Hamming::new(17, false); + let expected = [ + 0.08, 0.11106277, 0.2000559, 0.33496036, 0.49755655, 0.66588498, 0.81721193, + 0.93109988, 0.99216763, 0.99216763, 0.93109988, 0.81721193, 0.66588498, 0.49755655, + 0.33496036, 0.2000559, 0.11106277, + ] + .into(); + + assert_vec_eq(expected, h.get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/kaiser.rs b/sci-rs/src/signal/windows/kaiser.rs new file mode 100644 index 00000000..e468fdcd --- /dev/null +++ b/sci-rs/src/signal/windows/kaiser.rs @@ -0,0 +1,229 @@ +use super::{extend, len_guard, truncate}; +use crate::special::Bessel; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; + +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `Kaiser` for use in [GetWindow]. +#[cfg(feature = "alloc")] +#[derive(Debug, Clone, PartialEq)] +pub struct Kaiser +where + F: Real, +{ + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Shape parameter. + pub beta: F, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Kaiser +where + F: Real, +{ + /// Returns a Kaiser struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `beta` : float + /// Shape parameter, determines trade-off between main-lobe width and side lobe level. As + /// beta gets large, the window narrows. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, beta: F, sym: bool) -> Self { + Kaiser { m, beta, sym } + } +} + +impl GetWindow for Kaiser +where + F: Real, + W: Real + Bessel, +{ + /// Return a [Kaiser] window. + /// + /// The Kaiser window is a taper formed by using a Bessel function. + /// + /// Parameters + /// ---------- + /// * `M` : int + /// Number of points in the output window. If zero, an empty array is returned. An + /// exception is thrown when it is negative. + /// * `beta` : float + /// Shape parameter, determines trade-off between main-lobe width and side lobe level. As + /// beta gets large, the window narrows. + /// * `sym` : bool, optional + /// When True (default), generates a symmetric window, for use in filter design. + /// When False, generates a periodic window, for use in spectral analysis. + /// + /// Returns + /// ------- + /// `w` : ndarray + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// Notes + /// ----- + /// The Kaiser window is defined as + /// + /// $$w(n) = I_0\left( \beta \sqrt{1-\frac{4n^2}{(M-1)^2}} \right)/I_0(\beta)$$ + /// + /// with + /// + /// $$\quad -\frac{M-1}{2} \leq n \leq \frac{M-1}{2}$$, + /// + /// where $I_0$ is the modified zeroth-order Bessel function. + /// + /// The Kaiser was named for Jim Kaiser, who discovered a simple approximation to the DPSS + /// window based on Bessel functions. + /// The Kaiser window is a very good approximation to the Digital Prolate Spheroidal Sequence, + /// or Slepian window, which is the transform which maximizes the energy in the main lobe of + /// the window relative to total energy. + /// + /// The Kaiser can approximate other windows by varying the beta parameter. [(Some literature + /// uses alpha = beta/pi.)]([4]) + /// + /// ==== ======================= + /// beta Window shape + /// ==== ======================= + /// 0 Rectangular + /// 5 Similar to a Hamming + /// 6 Similar to a Hann + /// 8.6 Similar to a Blackman + /// ==== ======================= + /// + /// A beta value of 14 is probably a good starting point. Note that as beta gets large, the + /// window narrows, and so the number of samples needs to be large enough to sample the + /// increasingly narrow spike, otherwise NaNs will be returned. + /// + /// Most references to the Kaiser window come from the signal processing literature, where it + /// is used as one of many windowing functions for smoothing values. It is also known as an + /// apodization (which means "removing the foot", i.e. smoothing discontinuities at the + /// beginning and end of the sampled signal) or tapering function. + /// + /// # References + /// [[1]] J. F. Kaiser, "Digital Filters" - Ch 7 in "Systems analysis by digital computer", + /// Editors: F.F. Kuo and J.F. Kaiser, p 218-285. John Wiley and Sons, New York, (1966). + /// [[2]] E.R. Kanasewich, "Time Sequence Analysis in Geophysics", The University of Alberta + /// Press, 1975, pp. 177-178. + /// [[3]] Wikipedia, "Window function", + /// [[4]] F. J. Harris, "On the use of windows for harmonic analysis with the discrete Fourier + /// transform," Proceedings of the IEEE, vol. 66, no. 1, pp. 51-83, Jan. 1978. + /// :doi:`10.1109/PROC.1978.10837`. + /// [[5]] + /// [Scipy]() + /// + /// # Examples + /// Plot the window and its frequency response: + /// + /// ```custom,{class=language-python} + /// >>> import numpy as np + /// >>> from scipy import signal + /// >>> from scipy.fft import fft, fftshift + /// >>> import matplotlib.pyplot as plt + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> window = signal.windows.kaiser(51, beta=14) + /// >>> plt.plot(window) + /// >>> plt.title(r"Kaiser window ($\beta$=14)") + /// >>> plt.ylabel("Amplitude") + /// >>> plt.xlabel("Sample") + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> plt.figure() + /// >>> A = fft(window, 2048) / (len(window)/2.0) + /// >>> freq = np.linspace(-0.5, 0.5, len(A)) + /// >>> response = 20 * np.log10(np.abs(fftshift(A / abs(A).max()))) + /// >>> plt.plot(freq, response) + /// >>> plt.axis([-0.5, 0.5, -120, 0]) + /// >>> plt.title(r"Frequency response of the Kaiser window ($\beta$=14)") + /// >>> plt.ylabel("Normalized magnitude [dB]") + /// >>> plt.xlabel("Normalized frequency [cycles per sample]") + /// ``` + /// + /// The equivalent is: + /// ``` + /// use sci_rs::signal::windows::{GetWindow, Kaiser}; + /// let window: Vec = Kaiser::new(51, 14., true).get_window(); + /// ``` + /// + /// Due to current implementation limitations, note that the following might not possible: + /// ``` + /// use sci_rs::signal::windows::{GetWindow, Kaiser}; + /// let window: Vec = Kaiser::new(51, 14., true).get_window(); + /// println!("window = {:?}", window); + /// ``` + /// + /// [1]: #references + /// [2]: #references + /// [3]: #references + /// [4]: #references + /// [5]: #references + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + if len_guard(self.m) { + return Vec::::new(); + } + let (m, needs_trunc) = extend(self.m, self.sym); + let n = (0..m); + let alpha = W::from(self.m - 1).unwrap() / W::from(2).unwrap(); + let beta = W::from(self.beta).unwrap(); + // let w: Vec = n + // .map(|ni| W::from(ni).unwrap() - alpha) + // .map(|ni| (ni / alpha).powf(W::from(2).unwrap())) + // .map(|ni| (W::one() - ni).sqrt()) + // .map(|ni| i0(beta * ni) / i0(beta)) + // .collect(); + let w: Vec = n + .map(|ni| { + (beta + * (W::one() + - ((W::from(ni).unwrap() - alpha) / alpha).powf(W::from(2).unwrap())) + .sqrt()) + .i0() + / (beta.i0()) + }) + .collect(); + truncate(w, needs_trunc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn kaiser_17_8_true() { + // from scipy.signal.windows import kaiser + // kaiser(17, beta = 0.8) + let expected = vec![ + 0.85725436, 0.88970403, 0.9183205, 0.94289618, 0.96325245, 0.97924114, 0.99074569, + 0.99768219, 1., 0.99768219, 0.99074569, 0.97924114, 0.96325245, 0.94289618, 0.9183205, + 0.88970403, 0.85725436, + ]; + let k = Kaiser::new(17, 0.8, true); + + assert_vec_eq(expected, k.get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/mod.rs b/sci-rs/src/signal/windows/mod.rs new file mode 100644 index 00000000..5183d043 --- /dev/null +++ b/sci-rs/src/signal/windows/mod.rs @@ -0,0 +1,334 @@ +use crate::special; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Corresponding window representation for tuple-structs of [Window] variants. +#[cfg(feature = "alloc")] +pub trait GetWindow +where + W: Real, +{ + /// Returns a window of given length and type. + /// + /// # Parameters + /// `self`: + /// The type of window to construct, typically consists of at least the following + /// arguments: + /// * `Nx`: usize + /// The number of samples in the window + /// * `fftbins/~sym`: bool + /// If fftbins=true/sym=false, create a “periodic” window, ready to use with ifftshift and + /// be multiplied by the result of an FFT. This is the default behaviour in scipy. + /// If fftbins=false/sym=true, create a "symmetric" window, for use in filter design. + /// * `*args`: + /// Other arguments relevant to the window type. + /// + /// # Reference + /// + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec; +} + +/// Private function for windows implementing [GetWindow] +/// Handle small or incorrect window lengths. +#[inline(always)] +fn len_guard(m: usize) -> bool { + m <= 1 +} + +/// Private function for windows implementing [GetWindow] +/// Extend window by 1 sample if needed for DFT-even symmetry. +#[inline(always)] +fn extend(m: usize, sym: bool) -> (usize, bool) { + if !sym { + (m + 1, true) + } else { + (m, false) + } +} + +/// Private function for windows implementing [GetWindow] +/// Truncate window by 1 sample if needed for DFT-even symmetry. +#[inline(always)] +fn truncate(mut w: Vec, needed: bool) -> Vec { + if needed { + w.pop(); + } + w +} + +mod blackman; +mod boxcar; +mod general_cosine; +mod general_gaussian; +mod general_hamming; +mod hamming; +mod kaiser; +mod nuttall; +mod triangle; +pub use blackman::Blackman; +pub use boxcar::Boxcar; +pub use general_cosine::GeneralCosine; +pub use general_gaussian::GeneralGaussian; +pub use general_hamming::GeneralHamming; +pub use hamming::Hamming; +pub use kaiser::Kaiser; +pub use nuttall::Nuttall; +pub use triangle::Triangle; + +/// This collects all structs that implement the [GetWindow] trait. +/// This allows for running `.get_window()` on the struct, which can then be, for example, used in +/// Firwin. +// Ordering is as in accordance with +// https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.get_window.html. +#[derive(Debug, Clone, PartialEq)] +// Is it possible for the enums which wraps the structs to only require the generic that the struct +// implements GetWindow? +pub enum Window +where + F: Real, +{ + /// [Boxcar] window, also known as a rectangular window or Dirichlet window; This is equivalent + /// to no window at all. + Boxcar(Boxcar), + /// [Triangle] window. + Triangle(Triangle), + /// [Blackman] window. + Blackman(Blackman), + /// [Hamming] window. + Hamming(Hamming), + // Hann, + // Bartlett, + // Flattop, + // Parzen, + // Bohman, + // BlackmanHarris, + /// [Nuttall] window. + Nuttall(Nuttall), + // BartHann, + // Cosine, + // Exponential, + // Tukey, + // Taylor, + // Lanczos, + /// [Kaiser] window. + // Needs Beta + Kaiser(Kaiser), + // KaiserBesselDerived, // Needs Beta + // Gaussian, // Needs Standard Deviation + /// [GeneralCosine] window, a generic weighted sum of cosine term windows. + // Needs Weighting Coefficients + GeneralCosine(GeneralCosine), + /// [GeneralGaussian] window. + // Needs Power, Width + GeneralGaussian(GeneralGaussian), + /// [GeneralHamming] window. + // Needs Window Coefficients. + GeneralHamming(GeneralHamming), + // Dpss, // Needs Normalized Half-Bandwidth. + // Chebwin, // Needs Attenuation. +} + +impl GetWindow for Window +where + F: Real, + W: Real + Float + RealField + special::Bessel, +{ + fn get_window(&self) -> Vec { + match &self { + Window::Boxcar(x) => x.get_window(), + Window::Triangle(x) => x.get_window(), + Window::Blackman(x) => x.get_window(), + Window::Hamming(x) => x.get_window(), + Window::Nuttall(x) => x.get_window(), + Window::Kaiser(x) => x.get_window(), + Window::GeneralCosine(x) => x.get_window(), + Window::GeneralGaussian(x) => x.get_window(), + Window::GeneralHamming(x) => x.get_window(), + } + } +} + +/// This provides a set of enum variants that for use in [get_window]. +#[derive(Debug, Clone, PartialEq)] // Derive eq? +pub enum GetWindowBuilder<'a, F> +where + F: Real, +{ + /// [Boxcar] window, also known as a rectangular window or Dirichlet window; This is equivalent + /// to no window at all. + Boxcar, + /// [Triangle] window. + Triangle, + /// [Blackman] window. + Blackman, + /// [Hamming] window. + Hamming, + // Hann, + // Bartlett, + // Flattop, + // Parzen, + // Bohman, + // BlackmanHarris, + /// [Nuttall] window. + Nuttall, + // BartHann, + // Cosine, + // Exponential, + // Tukey, + // Taylor, + // Lanczos, + /// [Kaiser] window. + Kaiser { + /// Shape parameter `β`, please refer to [Kaiser]. + beta: F, + }, + // KaiserBesselDerived, // Needs Beta + // Gaussian, // Needs Standard Deviation + /// [GeneralCosine] window: Generic weighted sum of cosine term windows. + GeneralCosine { + /// Weighting Coefficients `a`, please refer to [GeneralCosine]. + weights: &'a [F], + }, + /// [GeneralGaussian] window: Generalized Gaussian Shape. + GeneralGaussian { + /// Shape parameter. + p: F, + /// The standard deviation, σ. + width: F, + }, + /// [GeneralHamming] window. + // Needs Window Coefficients. + GeneralHamming { + /// Window coefficient, ɑ + coefficient: F, + }, + // Dpss, // Needs Normalized Half-Bandwidth. + // Chebwin, // Needs Attenuation. +} + +/// Return a window of a given length and type. +/// +/// Parameters +/// ---------- +/// * `window`: [GetWindowBuilder] +/// The type of window to create. See below for more details. +/// * `Nx`: usize +/// The number of samples in the window. +/// * `fftbins`: bool, optional +/// If True (default), create a "periodic" window, ready to use with `ifftshift` and be +/// multiplied by the result of an FFT (see also :func:`~scipy.fft.fftfreq`). +/// If False, create a "symmetric" window, for use in filter design. +/// +/// Returns +/// ------- +/// * `get_window` : ndarray +/// Returns a window of length `Nx` and type `window` +/// +/// Notes +/// ----- +/// Window types: +/// * [Boxcar] +/// * [Triangle] +/// * [Blackman] +/// * [Hamming] +// Hann, +// Bartlett, +// Flattop, +// Parzen, +// Bohman, +// BlackmanHarris, +/// * [Nuttall] +// BartHann, +// Cosine, +// Exponential, +// Tukey, +// Taylor, +// Lanczos, +/// * [Kaiser] // Needs Beta +// KaiserBesselDerived, // Needs Beta +// Gaussian, // Needs Standard Deviation +/// * [GeneralCosine] +/// * [GeneralGaussian] // Needs Power, Width +/// * [GeneralHamming] // Needs Window Coefficients. +// Dpss, // Needs Normalized Half-Bandwidth. +// Chebwin, // Needs Attenuation. +/// +/// Examples +/// ----- +/// ``` +/// use approx:: assert_abs_diff_eq; +/// use sci_rs::signal::filter::design::{firwin_dyn, FilterBandType}; +/// use sci_rs::signal::windows::{get_window, GetWindow, GetWindowBuilder}; +/// +/// let window_struct = get_window(GetWindowBuilder::::Hamming, 3, None); +/// let window: Vec = window_struct.get_window(); +/// let expected = vec![0.08, 0.77, 0.77]; +/// +/// fn assert_vec_eq(a: Vec, b: Vec) { +/// for (a, b) in a.into_iter().zip(b) { +/// assert_abs_diff_eq!(a, b, epsilon = 1e-6); +/// } +/// } +/// +/// assert_vec_eq(window, expected); +/// ``` +/// +/// +/// # References +/// +pub fn get_window(window: GetWindowBuilder<'_, F>, nx: usize, fftbins: Option) -> Window +where + F: Real, +{ + match window { + GetWindowBuilder::Boxcar => Window::Boxcar(Boxcar { + m: nx, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::Triangle => Window::Triangle(Triangle { + m: nx, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::Blackman => Window::Blackman(Blackman { + m: nx, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::Hamming => Window::Hamming(Hamming { + m: nx, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::Nuttall => Window::Nuttall(Nuttall { + m: nx, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::Kaiser { beta } => Window::Kaiser(Kaiser { + m: nx, + beta, + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::GeneralCosine { weights } => Window::GeneralCosine(GeneralCosine { + m: nx, + a: weights.into(), + sym: !fftbins.unwrap_or(true), + }), + GetWindowBuilder::GeneralGaussian { p, width } => { + Window::GeneralGaussian(GeneralGaussian { + m: nx, + p, + sigma: width, + sym: !fftbins.unwrap_or(true), + }) + } + GetWindowBuilder::GeneralHamming { coefficient } => { + Window::GeneralHamming(GeneralHamming { + m: nx, + alpha: coefficient, + sym: !fftbins.unwrap_or(true), + }) + } + } +} diff --git a/sci-rs/src/signal/windows/nuttall.rs b/sci-rs/src/signal/windows/nuttall.rs new file mode 100644 index 00000000..d3d37f5b --- /dev/null +++ b/sci-rs/src/signal/windows/nuttall.rs @@ -0,0 +1,192 @@ +use super::GeneralCosine; +use nalgebra::RealField; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::vec::Vec; + +/// Collection of arguments for window `Nuttall` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct Nuttall { + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Nuttall { + /// Returns a Nuttall struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `sym`: bool + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, sym: bool) -> Self { + Nuttall { m, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for Nuttall +where + W: Real + Float + RealField, +{ + /// Return a minimum 4-term Blackman-Harris window according to Nuttall. + /// + /// This variation is called "Nuttall4c" by Heinzel. [[2]] + /// + /// # Parameters + /// * `M` : int + /// Number of points in the output window. If zero, an empty array is returned. An + /// exception is thrown when it is negative. + /// * `sym` : bool, optional + /// When True (default), generates a symmetric window, for use in filter + /// design. + /// When False, generates a periodic window, for use in spectral analysis. + /// + /// # Returns + /// `w` : ndarray + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # References + /// [[1]] A. Nuttall, "Some windows with very good sidelobe behavior," IEEE Transactions on + /// Acoustics, Speech, and Signal Processing, vol. 29, no. 1, pp. 84-91, Feb 1981. + /// :doi:`10.1109/TASSP.1981.1163506`. + /// [[2]] Heinzel G. et al., "Spectrum and spectral density estimation by the Discrete Fourier + /// transform (DFT), including a comprehensive list of window functions and some new flat-top + /// windows", February 15, 2002 + /// [[3]] Scipy, + /// + /// + /// Examples + /// -------- + /// Plot the window and its frequency response: + /// + /// ```custom,{class=language-python} + /// >>> import numpy as np + /// >>> from scipy import signal + /// >>> from scipy.fft import fft, fftshift + /// >>> import matplotlib.pyplot as plt + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> window = signal.windows.nuttall(51) + /// >>> plt.plot(window) + /// >>> plt.title("Nuttall window") + /// >>> plt.ylabel("Amplitude") + /// >>> plt.xlabel("Sample") + /// ``` + /// + /// ```custom,{class=language-python} + /// >>> plt.figure() + /// >>> A = fft(window, 2048) / (len(window)/2.0) + /// >>> freq = np.linspace(-0.5, 0.5, len(A)) + /// >>> response = 20 * np.log10(np.abs(fftshift(A / abs(A).max()))) + /// >>> plt.plot(freq, response) + /// >>> plt.axis([-0.5, 0.5, -120, 0]) + /// >>> plt.title("Frequency response of the Hamming window") + /// >>> plt.ylabel("Normalized magnitude [dB]") + /// >>> plt.xlabel("Normalized frequency [cycles per sample]") + /// ``` + /// + /// The equivalent is: + /// ``` + /// use sci_rs::signal::windows::{GetWindow, Nuttall}; + /// let window: Vec = Nuttall::new(51, true).get_window(); + /// ``` + /// + /// [1]: #references + /// [2]: #references + /// [3]: #references + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + GeneralCosine::::new( + self.m, + [0.3635819, 0.4891775, 0.1365995, 0.0106411] + .map(|f| W::from(f).unwrap()) + .into_iter() + .collect(), + self.sym, + ) + .get_window() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_abs_diff_eq; + + #[test] + fn nuttall_case_a() { + // from scipy.signal.windows import nuttall + // nuttall(17) + let h = Nuttall::new(17, true); + let expected = [ + 3.62800000e-04, + 4.15908007e-03, + 2.52055665e-02, + 8.96224370e-02, + 2.26982400e-01, + 4.44360497e-01, + 7.01958233e-01, + 9.16185585e-01, + 1.00000000e+00, + 9.16185585e-01, + 7.01958233e-01, + 4.44360497e-01, + 2.26982400e-01, + 8.96224370e-02, + 2.52055665e-02, + 4.15908007e-03, + 3.62800000e-04, + ] + .into(); + + assert_vec_eq(expected, h.get_window()); + } + + #[test] + fn nuttall_case_b() { + // from scipy.signal.windows import nuttall + // nuttall(17, false) + let h = Nuttall::new(17, false); + let expected = [ + 3.62800000e-04, + 3.64256817e-03, + 2.10918726e-02, + 7.36770505e-02, + 1.87084736e-01, + 3.73448574e-01, + 6.11072447e-01, + 8.39394793e-01, + 9.80852709e-01, + 9.80852709e-01, + 8.39394793e-01, + 6.11072447e-01, + 3.73448574e-01, + 1.87084736e-01, + 7.36770505e-02, + 2.10918726e-02, + 3.64256817e-03, + ] + .into(); + + assert_vec_eq(expected, h.get_window()); + } + + #[track_caller] + fn assert_vec_eq(a: Vec, b: Vec) { + for (a, b) in a.into_iter().zip(b) { + assert_abs_diff_eq!(a, b, epsilon = 1e-6); + } + } +} diff --git a/sci-rs/src/signal/windows/triangle.rs b/sci-rs/src/signal/windows/triangle.rs new file mode 100644 index 00000000..7a5c080b --- /dev/null +++ b/sci-rs/src/signal/windows/triangle.rs @@ -0,0 +1,198 @@ +use super::{extend, len_guard, truncate}; +use num_traits::{real::Real, Float}; + +#[cfg(feature = "alloc")] +use super::GetWindow; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; + +/// Collection of arguments for window `Triangle` for use in [GetWindow]. +#[derive(Debug, Clone, PartialEq)] +pub struct Triangle { + /// Number of points in the output window. If zero, an empty array is returned in [GetWindow]. + pub m: usize, + /// Whether the window is symmetric. + /// + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub sym: bool, +} + +impl Triangle { + /// Returns a Triangle struct. + /// + /// # Parameters + /// * `m`: + /// Number of points in the output window. If zero, an empty array is returned. + /// * `sym`: + /// When true, generates a symmetric window, for use in filter design. + /// When false, generates a periodic window, for use in spectral analysis. + pub fn new(m: usize, sym: bool) -> Self { + Triangle { m, sym } + } +} + +#[cfg(feature = "alloc")] +impl GetWindow for Triangle +where + W: Real, +{ + /// Return a window of type: Triangle. + /// + /// This is the not the same as the Bartlett window. + /// + /// # Parameters + /// `self`: [Triangle] + /// + /// # Returns + /// `w`: `vec` + /// The window, with the maximum value normalized to 1 (though the value 1 does not appear + /// if `M` is even and `sym` is True). + /// + /// # Example + /// ``` + /// use sci_rs::signal::windows::{Triangle, GetWindow}; + /// + /// let nx = 8; + /// let tri = Triangle::new(nx, true); + /// assert_eq!(vec![0.125, 0.375, 0.625, 0.875, 0.875, 0.625, 0.375, 0.125], tri.get_window()); + /// + /// let nx = 9; + /// let tri = Triangle::new(nx, false); + /// assert_eq!(vec![0.1, 0.3, 0.5, 0.7, 0.9, 0.9, 0.7, 0.5, 0.3], tri.get_window()); + /// ``` + /// + /// # References + /// + #[cfg(feature = "alloc")] + fn get_window(&self) -> Vec { + if len_guard(self.m) { + return Vec::::new(); + } + let (m, needs_trunc) = extend(self.m, self.sym); + + let mut n = (1..=(m.div_ceil(2))).map(|x| W::from(x).unwrap()); + let m_f: W = W::from(m).unwrap(); + let w: Vec = match m % 2 { + 0 => { + let mut w: Vec = n + .map(|n| (W::from(2).unwrap() * n - W::one()) / m_f) + .collect(); + w.extend(w.clone().iter().rev()); + w + } + 1 => { + let mut w: Vec = n + .map(|n| W::from(2).unwrap() * n / (m_f + W::one())) + .collect(); + w.extend(w.clone().iter().rev().skip(1)); + w + } + _ => panic!(), + }; + + truncate(w, needs_trunc) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use approx::assert_relative_eq; + + #[test] + #[cfg(feature = "alloc")] + fn case_even_true() { + // from scipy.signal.windows import triangle + // triangle(n) + + let upper = 1_000; + for i in 0..upper { + let nx = 2 * i; + let tri = Triangle::new(nx, true); + let expected: Vec = (0..nx) + .chain((0..nx).rev()) + .filter(|n| n % 2 == 1) + .map(|n| n as f64 / nx as f64) + .collect(); + let result: Vec = tri.get_window(); + + assert_eq!(expected, result); + // for (&e, t) in expected.iter().zip(tri.get_window::()) { + // assert_relative_eq!(e, t, max_relative = 1e-7) + // } + } + } + + #[test] + #[cfg(feature = "alloc")] + fn case_even_false() { + // from scipy.signal import get_window + // get_window('triangle', 4) + + let upper = 1_000; + for i in 0..upper { + let nx = 2 * i; + let tri = Triangle::new(nx, false); + let expected: Vec = (1..=(nx / 2) + 1) + .chain((1..(nx / 2) + 1).rev()) + .take(nx) + .map(|n| n as f64 / ((nx / 2) + 1) as f64) + .collect(); + + let result: Vec = tri.get_window(); + assert_eq!(expected, result); + // for (&e, t) in expected.iter().zip(tri.get_window::()) { + // assert_relative_eq!(e, t, max_relative = 1e-7) + // } + } + } + + #[test] + #[cfg(feature = "alloc")] + fn case_odd_true() { + // from scipy.signal.windows import triangle + // triangle(nx) + + let upper = 1_000; + for i in 1..upper { + let nx = 2 * i + 1; + let tri = Triangle::new(nx, true); + let expected: Vec<_> = (1..=nx.div_ceil(2)) + .chain((1..=nx.div_ceil(2)).rev().skip(1)) + .map(|n| (n as f64) / ((nx + 1) as f64 / 2.)) + .collect(); + + let result: Vec = tri.get_window(); + assert_eq!(expected, result); + // for (&e, t) in expected.iter().zip(tri.get_window::()) { + // assert_relative_eq!(e, t); + // } + } + } + + #[test] + #[cfg(feature = "alloc")] + fn case_odd_false() { + // from scipy.signal import get_window + // get_window('triangle', 5) + + let upper = 1_000; + for i in 1..upper { + let nx = 2 * i + 1; + let tri = Triangle::new(nx, false); + let expected: Vec<_> = (0..=nx) + .chain((0..=nx).rev()) + .filter(|n| n % 2 == 1) + .take(nx) + .map(|n| n as f64 / (nx as f64 + 1.)) + .collect(); + + let result: Vec = tri.get_window(); + assert_eq!(expected, result); + // for (&e, t) in expected.iter().zip(tri.get_window::()) { + // assert_relative_eq!(e, t); + // } + } + } +}