diff --git a/sci-rs/Cargo.toml b/sci-rs/Cargo.toml index 697e7933..a6d10fd4 100644 --- a/sci-rs/Cargo.toml +++ b/sci-rs/Cargo.toml @@ -32,6 +32,7 @@ num-traits = { version = "0.2.15", default-features = false } itertools = { version = "0.13.0", default-features = false } nalgebra = { version = "0.33.2", default-features = false } ndarray = { version = "0.16.1", default-features = false } +ndarray-conv = { version = "0.4.1" } lstsq = { version = "0.6.0", default-features = false } rustfft = { version = "6.2.0", optional = true } kalmanfilt = { version = "0.3.0", default-features = false } diff --git a/sci-rs/src/signal/convolve.rs b/sci-rs/src/signal/convolve.rs index b527805e..2306181a 100644 --- a/sci-rs/src/signal/convolve.rs +++ b/sci-rs/src/signal/convolve.rs @@ -1,6 +1,10 @@ -use nalgebra::Complex; -use num_traits::{Float, FromPrimitive, Signed, Zero}; -use rustfft::{FftNum, FftPlanner}; +use ndarray::{ + Array, ArrayBase, ArrayView, Data, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo, + SliceInfoElem, +}; +use ndarray_conv::{ConvFFTExt, ConvMode}; +use num_traits::NumAssign; +use rustfft::FftNum; /// Convolution mode determines behavior near edges and output size pub enum ConvolveMode { @@ -12,78 +16,44 @@ pub enum ConvolveMode { Same, } -/// Performs FFT-based convolution on two slices of floating point values. +/// Convolve two N-dimensional arrays using the fourier method. /// /// According to Python docs, this is generally much faster than direct convolution /// for large arrays (n > ~500), but can be slower when only a few output values are needed. -/// We only implement the FFT version in Rust for now. /// /// # Arguments -/// - `in1`: First input signal -/// - `in2`: Second input signal -/// - `mode`: Convolution mode (currently only Full is supported) +/// - `in1`: First input signal by reference. Can be `[std::vec::Vec]` or `[ndarray::Array]`. +/// - `in2`: Second input signal by reference. (Same type and dimensions as `in1`.) +/// - `mode`: [ConvolveMode] /// /// # Returns -/// A Vec containing the discrete linear convolution of `in1` with `in2`. -/// For Full mode, the output length will be `in1.len() + in2.len() - 1`. -pub fn fftconvolve(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { - // Determine the size of the FFT (next power of 2 for zero-padding) - let n1 = in1.len(); - let n2 = in2.len(); - let n = n1 + n2 - 1; - let fft_size = n.next_power_of_two(); - - // Prepare input buffers as Complex with zero-padding to fft_size - let mut padded_in1 = vec![Complex::zero(); fft_size]; - let mut padded_in2 = vec![Complex::zero(); fft_size]; - - // Copy input data into zero-padded buffers - padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| { - *p = Complex::new(v, F::zero()); - }); - padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| { - *p = Complex::new(v, F::zero()); - }); - - // Perform the FFT - let mut planner = FftPlanner::new(); - let fft = planner.plan_fft_forward(fft_size); - fft.process(&mut padded_in1); - fft.process(&mut padded_in2); - - // Multiply element-wise in the frequency domain - let mut result_freq: Vec> = padded_in1 - .iter() - .zip(&padded_in2) - .map(|(a, b)| a * b) - .collect(); - - // Perform the inverse FFT - let ifft = planner.plan_fft_inverse(fft_size); - ifft.process(&mut result_freq); - - // Take only the real part, normalize, and truncate to the original output size (n) - let fft_size = F::from(fft_size).unwrap(); - let full_convolution = result_freq - .iter() - .take(n) - .map(|x| x.re / fft_size) - .collect(); - - // Extract the appropriate slice based on the mode +/// An `[Array]` containing the discrete linear convolution of `in1` with `in2`. +/// For [ConvolveMode::Full] mode, the output length will be `in1.shape() "+" in2.shape() "-" 1`. +/// For [ConvolveMode::Valid] mode, the output length will be `max(in1.shape(), + in2.shape())`. +/// For [ConvolveMode::Same] mode, the output length will be `in1.shape()`. +pub fn fftconvolve<'a, T, S, const N: usize>( + in1: ArrayBase>, + in2: ArrayBase>, + mode: ConvolveMode, +) -> Array> +where + T: NumAssign + FftNum, + S: Data + 'a, + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: + SliceArg, OutDim = Dim<[Ix; N]>>, +{ match mode { - ConvolveMode::Full => full_convolution, + ConvolveMode::Full => { + todo!() + } ConvolveMode::Valid => { - if n1 >= n2 { - full_convolution[(n2 - 1)..(n1)].to_vec() - } else { - Vec::new() - } + todo!() } ConvolveMode::Same => { - let start = (n2 - 1) / 2; - let end = start + n1; - full_convolution[start..end].to_vec() + in1.conv_fft(&in2, ConvMode::Same, ndarray_conv::PaddingMode::Zeros) + .unwrap() // TODO: Result type from core } } } @@ -91,13 +61,32 @@ pub fn fftconvolve(in1: &[F], in2: &[F], mode: ConvolveMode) /// Compute the convolution of two signals using FFT. /// /// # Arguments -/// * `in1` - First input array -/// * `in2` - Second input array +/// - `in1`: First input signal by reference. Can be `[ndarray::Array]`. +/// - `in2`: Second input signal by reference. (Same type and dimensions as `in1`.) +/// - `mode`: [ConvolveMode] /// /// # Returns -/// A Vec containing the convolution of `in1` with `in2`. -/// With Full mode, the output length will be `in1.len() + in2.len() - 1`. -pub fn convolve(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { +/// An `[Array]` containing the discrete linear convolution of `in1` with `in2`. +/// For [ConvolveMode::Full] mode, the output length will be `in1.shape() "+" in2.shape() "-" 1`. +/// For [ConvolveMode::Valid] mode, the output length will be `max(in1.shape(), + in2.shape())`. +/// For [ConvolveMode::Same] mode, the output length will be `in1.shape()`. +/// +/// # Note +/// Automatic choice between convolution through direct summation or via FFT has yet to be done +#[inline] +pub fn convolve<'a, T, S, const N: usize>( + in1: ArrayBase>, + in2: ArrayBase>, + mode: ConvolveMode, +) -> Array> +where + T: NumAssign + FftNum, + S: Data + 'a, + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: + SliceArg, OutDim = Dim<[Ix; N]>>, +{ fftconvolve(in1, in2, mode) } @@ -111,26 +100,45 @@ pub fn convolve(in1: &[F], in2: &[F], mode: ConvolveMode) -> /// * `in2` - Second input array /// /// # Returns -/// A Vec containing the cross-correlation of `in1` with `in2`. +/// An array containing the cross-correlation of `in1` with `in2`. /// With Full mode, the output length will be `in1.len() + in2.len() - 1`. -pub fn correlate(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec { - // For correlation, we need to reverse in2 - let mut in2_rev = in2.to_vec(); - in2_rev.reverse(); - fftconvolve(in1, &in2_rev, mode) +pub fn correlate<'a, T, S, const N: usize>( + in1: ArrayBase>, + in2: ArrayBase>, + mode: ConvolveMode, +) -> Array> +where + T: NumAssign + FftNum, + S: Data + 'a, + [Ix; N]: IntoDimension>, + Dim<[Ix; N]>: RemoveAxis, + SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>: + SliceArg, OutDim = Dim<[Ix; N]>>, +{ + in1.conv_fft( + &in2.t(), + match mode { + ConvolveMode::Full => ConvMode::Full, + ConvolveMode::Valid => ConvMode::Valid, + ConvolveMode::Same => ConvMode::Same, + }, + ndarray_conv::PaddingMode::Zeros, + ) + .unwrap() // TODO: Result type from core } #[cfg(test)] mod tests { use super::*; use approx::assert_relative_eq; + use ndarray::{array, Array1, ArrayView1}; #[test] - fn test_convolve() { - let in1 = vec![1.0, 2.0, 3.0]; - let in2 = vec![4.0, 5.0, 6.0]; - let result = convolve(&in1, &in2, ConvolveMode::Full); - let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0]; + fn test_convolve_full() { + let in1 = array![1.0, 2.0, 3.0]; + let in2 = array![4.0, 5.0, 6.0]; + let result: Array1<_> = convolve(in1, in2, ConvolveMode::Full); + let expected: Array1<_> = vec![4.0, 13.0, 28.0, 27.0, 18.0].into(); for (a, b) in result.iter().zip(expected.iter()) { assert_relative_eq!(a, b, epsilon = 1e-10); @@ -138,11 +146,22 @@ mod tests { } #[test] - fn test_correlate() { - let in1 = vec![1.0, 2.0, 3.0]; - let in2 = vec![4.0, 5.0, 6.0]; - let result = correlate(&in1, &in2, ConvolveMode::Full); - let expected = vec![6.0, 17.0, 32.0, 23.0, 12.0]; + fn test_correlate_full() { + let in1 = array![1.0, 2.0, 3.0]; + let in2 = array![4.0, 5.0, 6.0]; + + { + let in1_ref: ArrayView1<_> = (&in1).into(); + let in2_ref: ArrayView1<_> = (&in2).into(); + let result: Array> = correlate(in1_ref, in2_ref, ConvolveMode::Full); + let expected: Array> = vec![6.0, 17.0, 32.0, 23.0, 12.0].into(); + for (a, b) in result.iter().zip(expected.iter()) { + assert_relative_eq!(a, b, epsilon = 1e-10); + } + } + + let result: Array1<_> = correlate(in1, in2, ConvolveMode::Full); + let expected: Array1<_> = array![6.0, 17.0, 32.0, 23.0, 12.0]; for (a, b) in result.iter().zip(expected.iter()) { assert_relative_eq!(a, b, epsilon = 1e-10); } @@ -150,10 +169,10 @@ mod tests { #[test] fn test_convolve_valid() { - let in1 = vec![1.0, 2.0, 3.0, 4.0]; - let in2 = vec![1.0, 2.0]; - let result = convolve(&in1, &in2, ConvolveMode::Valid); - let expected = vec![4.0, 7.0, 10.0]; + let in1 = array![1.0, 2.0, 5.0, 7.0]; + let in2 = array![1.4, 2.2]; + let result: Array1<_> = convolve(in1, in2, ConvolveMode::Valid); + let expected: Array1<_> = array![5.0, 11.4, 20.8]; for (a, b) in result.iter().zip(expected.iter()) { assert_relative_eq!(a, b, epsilon = 1e-10); } @@ -161,10 +180,10 @@ mod tests { #[test] fn test_convolve_same() { - let in1 = vec![1.0, 2.0, 3.0, 4.0]; - let in2 = vec![1.0, 2.0, 1.0]; - let result = convolve(&in1, &in2, ConvolveMode::Same); - let expected = vec![4.0, 8.0, 12.0, 11.0]; + let in1 = array![1.0, 2.0, 3.0, 4.0]; + let in2 = array![1.0, 2.0, 1.5]; + let result: Array1<_> = convolve(in1, in2, ConvolveMode::Same); + let expected: Array1<_> = array![4.0, 8.5, 13.0, 12.5]; for (a, b) in result.iter().zip(expected.iter()) { assert_relative_eq!(a, b, epsilon = 1e-10); } @@ -179,9 +198,10 @@ mod tests { // Generate 1000 random samples from standard normal distribution let mut rng = thread_rng(); let sig: Vec = Standard.sample_iter(&mut rng).take(1000).collect(); + let sig_ref: ArrayView1<_> = (&sig).into(); // Compute autocorrelation using correlate directly - let autocorr = correlate(&sig, &sig, ConvolveMode::Full); + let autocorr = correlate(sig_ref, sig_ref, ConvolveMode::Full); // Basic sanity checks assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1