Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sci-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ num-traits = { version = "0.2.15", default-features = false }
itertools = { version = "0.13.0", default-features = false }
nalgebra = { version = "0.33.2", default-features = false }
ndarray = { version = "0.16.1", default-features = false }
ndarray-conv = { version = "0.4.1" }
lstsq = { version = "0.6.0", default-features = false }
rustfft = { version = "6.2.0", optional = true }
kalmanfilt = { version = "0.3.0", default-features = false }
Expand Down
208 changes: 114 additions & 94 deletions sci-rs/src/signal/convolve.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
use nalgebra::Complex;
use num_traits::{Float, FromPrimitive, Signed, Zero};
use rustfft::{FftNum, FftPlanner};
use ndarray::{
Array, ArrayBase, ArrayView, Data, Dim, IntoDimension, Ix, RemoveAxis, SliceArg, SliceInfo,
SliceInfoElem,
};
use ndarray_conv::{ConvFFTExt, ConvMode};
use num_traits::NumAssign;
use rustfft::FftNum;

/// Convolution mode determines behavior near edges and output size
pub enum ConvolveMode {
Expand All @@ -12,92 +16,77 @@ pub enum ConvolveMode {
Same,
}

/// Performs FFT-based convolution on two slices of floating point values.
/// Convolve two N-dimensional arrays using the fourier method.
///
/// According to Python docs, this is generally much faster than direct convolution
/// for large arrays (n > ~500), but can be slower when only a few output values are needed.
/// We only implement the FFT version in Rust for now.
///
/// # Arguments
/// - `in1`: First input signal
/// - `in2`: Second input signal
/// - `mode`: Convolution mode (currently only Full is supported)
/// - `in1`: First input signal by reference. Can be `[std::vec::Vec]` or `[ndarray::Array]`.
/// - `in2`: Second input signal by reference. (Same type and dimensions as `in1`.)
/// - `mode`: [ConvolveMode]
///
/// # Returns
/// A Vec containing the discrete linear convolution of `in1` with `in2`.
/// For Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn fftconvolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
// Determine the size of the FFT (next power of 2 for zero-padding)
let n1 = in1.len();
let n2 = in2.len();
let n = n1 + n2 - 1;
let fft_size = n.next_power_of_two();

// Prepare input buffers as Complex<F> with zero-padding to fft_size
let mut padded_in1 = vec![Complex::zero(); fft_size];
let mut padded_in2 = vec![Complex::zero(); fft_size];

// Copy input data into zero-padded buffers
padded_in1.iter_mut().zip(in1.iter()).for_each(|(p, &v)| {
*p = Complex::new(v, F::zero());
});
padded_in2.iter_mut().zip(in2.iter()).for_each(|(p, &v)| {
*p = Complex::new(v, F::zero());
});

// Perform the FFT
let mut planner = FftPlanner::new();
let fft = planner.plan_fft_forward(fft_size);
fft.process(&mut padded_in1);
fft.process(&mut padded_in2);

// Multiply element-wise in the frequency domain
let mut result_freq: Vec<Complex<F>> = padded_in1
.iter()
.zip(&padded_in2)
.map(|(a, b)| a * b)
.collect();

// Perform the inverse FFT
let ifft = planner.plan_fft_inverse(fft_size);
ifft.process(&mut result_freq);

// Take only the real part, normalize, and truncate to the original output size (n)
let fft_size = F::from(fft_size).unwrap();
let full_convolution = result_freq
.iter()
.take(n)
.map(|x| x.re / fft_size)
.collect();

// Extract the appropriate slice based on the mode
/// An `[Array]` containing the discrete linear convolution of `in1` with `in2`.
/// For [ConvolveMode::Full] mode, the output length will be `in1.shape() "+" in2.shape() "-" 1`.
/// For [ConvolveMode::Valid] mode, the output length will be `max(in1.shape(), + in2.shape())`.
/// For [ConvolveMode::Same] mode, the output length will be `in1.shape()`.
pub fn fftconvolve<'a, T, S, const N: usize>(
in1: ArrayBase<S, Dim<[Ix; N]>>,
in2: ArrayBase<S, Dim<[Ix; N]>>,
mode: ConvolveMode,
) -> Array<T, Dim<[Ix; N]>>
where
T: NumAssign + FftNum,
S: Data<Elem = T> + 'a,
[Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
Dim<[Ix; N]>: RemoveAxis,
SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
{
match mode {
ConvolveMode::Full => full_convolution,
ConvolveMode::Full => {
todo!()
}
ConvolveMode::Valid => {
if n1 >= n2 {
full_convolution[(n2 - 1)..(n1)].to_vec()
} else {
Vec::new()
}
todo!()
}
ConvolveMode::Same => {
let start = (n2 - 1) / 2;
let end = start + n1;
full_convolution[start..end].to_vec()
in1.conv_fft(&in2, ConvMode::Same, ndarray_conv::PaddingMode::Zeros)
.unwrap() // TODO: Result type from core
}
}
}

/// Compute the convolution of two signals using FFT.
///
/// # Arguments
/// * `in1` - First input array
/// * `in2` - Second input array
/// - `in1`: First input signal by reference. Can be `[ndarray::Array]`.
/// - `in2`: Second input signal by reference. (Same type and dimensions as `in1`.)
/// - `mode`: [ConvolveMode]
///
/// # Returns
/// A Vec containing the convolution of `in1` with `in2`.
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
/// An `[Array]` containing the discrete linear convolution of `in1` with `in2`.
/// For [ConvolveMode::Full] mode, the output length will be `in1.shape() "+" in2.shape() "-" 1`.
/// For [ConvolveMode::Valid] mode, the output length will be `max(in1.shape(), + in2.shape())`.
/// For [ConvolveMode::Same] mode, the output length will be `in1.shape()`.
///
/// # Note
/// Automatic choice between convolution through direct summation or via FFT has yet to be done
#[inline]
pub fn convolve<'a, T, S, const N: usize>(
in1: ArrayBase<S, Dim<[Ix; N]>>,
in2: ArrayBase<S, Dim<[Ix; N]>>,
mode: ConvolveMode,
) -> Array<T, Dim<[Ix; N]>>
where
T: NumAssign + FftNum,
S: Data<Elem = T> + 'a,
[Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
Dim<[Ix; N]>: RemoveAxis,
SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
{
fftconvolve(in1, in2, mode)
}

Expand All @@ -111,60 +100,90 @@ pub fn convolve<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) ->
/// * `in2` - Second input array
///
/// # Returns
/// A Vec containing the cross-correlation of `in1` with `in2`.
/// An array containing the cross-correlation of `in1` with `in2`.
/// With Full mode, the output length will be `in1.len() + in2.len() - 1`.
pub fn correlate<F: Float + FftNum>(in1: &[F], in2: &[F], mode: ConvolveMode) -> Vec<F> {
// For correlation, we need to reverse in2
let mut in2_rev = in2.to_vec();
in2_rev.reverse();
fftconvolve(in1, &in2_rev, mode)
pub fn correlate<'a, T, S, const N: usize>(
in1: ArrayBase<S, Dim<[Ix; N]>>,
in2: ArrayBase<S, Dim<[Ix; N]>>,
mode: ConvolveMode,
) -> Array<T, Dim<[Ix; N]>>
where
T: NumAssign + FftNum,
S: Data<Elem = T> + 'a,
[Ix; N]: IntoDimension<Dim = Dim<[Ix; N]>>,
Dim<[Ix; N]>: RemoveAxis,
SliceInfo<[SliceInfoElem; N], Dim<[Ix; N]>, Dim<[Ix; N]>>:
SliceArg<Dim<[Ix; N]>, OutDim = Dim<[Ix; N]>>,
{
in1.conv_fft(
&in2.t(),
match mode {
ConvolveMode::Full => ConvMode::Full,
ConvolveMode::Valid => ConvMode::Valid,
ConvolveMode::Same => ConvMode::Same,
},
ndarray_conv::PaddingMode::Zeros,
)
.unwrap() // TODO: Result type from core
}

#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use ndarray::{array, Array1, ArrayView1};

#[test]
fn test_convolve() {
let in1 = vec![1.0, 2.0, 3.0];
let in2 = vec![4.0, 5.0, 6.0];
let result = convolve(&in1, &in2, ConvolveMode::Full);
let expected = vec![4.0, 13.0, 28.0, 27.0, 18.0];
fn test_convolve_full() {
let in1 = array![1.0, 2.0, 3.0];
let in2 = array![4.0, 5.0, 6.0];
let result: Array1<_> = convolve(in1, in2, ConvolveMode::Full);
let expected: Array1<_> = vec![4.0, 13.0, 28.0, 27.0, 18.0].into();

for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_correlate() {
let in1 = vec![1.0, 2.0, 3.0];
let in2 = vec![4.0, 5.0, 6.0];
let result = correlate(&in1, &in2, ConvolveMode::Full);
let expected = vec![6.0, 17.0, 32.0, 23.0, 12.0];
fn test_correlate_full() {
let in1 = array![1.0, 2.0, 3.0];
let in2 = array![4.0, 5.0, 6.0];

{
let in1_ref: ArrayView1<_> = (&in1).into();
let in2_ref: ArrayView1<_> = (&in2).into();
let result: Array<f64, Dim<[Ix; 1]>> = correlate(in1_ref, in2_ref, ConvolveMode::Full);
let expected: Array<f64, Dim<[Ix; 1]>> = vec![6.0, 17.0, 32.0, 23.0, 12.0].into();
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

let result: Array1<_> = correlate(in1, in2, ConvolveMode::Full);
let expected: Array1<_> = array![6.0, 17.0, 32.0, 23.0, 12.0];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_convolve_valid() {
let in1 = vec![1.0, 2.0, 3.0, 4.0];
let in2 = vec![1.0, 2.0];
let result = convolve(&in1, &in2, ConvolveMode::Valid);
let expected = vec![4.0, 7.0, 10.0];
let in1 = array![1.0, 2.0, 5.0, 7.0];
let in2 = array![1.4, 2.2];
let result: Array1<_> = convolve(in1, in2, ConvolveMode::Valid);
let expected: Array1<_> = array![5.0, 11.4, 20.8];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
}

#[test]
fn test_convolve_same() {
let in1 = vec![1.0, 2.0, 3.0, 4.0];
let in2 = vec![1.0, 2.0, 1.0];
let result = convolve(&in1, &in2, ConvolveMode::Same);
let expected = vec![4.0, 8.0, 12.0, 11.0];
let in1 = array![1.0, 2.0, 3.0, 4.0];
let in2 = array![1.0, 2.0, 1.5];
let result: Array1<_> = convolve(in1, in2, ConvolveMode::Same);
let expected: Array1<_> = array![4.0, 8.5, 13.0, 12.5];
for (a, b) in result.iter().zip(expected.iter()) {
assert_relative_eq!(a, b, epsilon = 1e-10);
}
Expand All @@ -179,9 +198,10 @@ mod tests {
// Generate 1000 random samples from standard normal distribution
let mut rng = thread_rng();
let sig: Vec<f64> = Standard.sample_iter(&mut rng).take(1000).collect();
let sig_ref: ArrayView1<_> = (&sig).into();

// Compute autocorrelation using correlate directly
let autocorr = correlate(&sig, &sig, ConvolveMode::Full);
let autocorr = correlate(sig_ref, sig_ref, ConvolveMode::Full);

// Basic sanity checks
assert_eq!(autocorr.len(), 1999); // Full convolution length should be 2N-1
Expand Down
Loading