diff --git a/src/dspeed/processors/__init__.py b/src/dspeed/processors/__init__.py index 8ae1b2a..3ec9f08 100644 --- a/src/dspeed/processors/__init__.py +++ b/src/dspeed/processors/__init__.py @@ -64,6 +64,9 @@ # Mapping from function to name of module in which it is defined # To add a new function to processors, it must be added here! _modules = { + "mean": "arithmetic", + "mean_below_threshold": "arithmetic", + "sum": "arithmetic", "bl_subtract": "bl_subtract", "convolve_damped_oscillator": "convolutions", "convolve_exp": "convolutions", @@ -131,6 +134,7 @@ "saturation": "saturation", "soft_pileup_corr": "soft_pileup_corr", "soft_pileup_corr_bl": "soft_pileup_corr", + "sort": "sort", "svm_predict": "svm", "tf_model": "tf_model", "time_over_threshold": "time_over_threshold", @@ -138,6 +142,7 @@ "interpolated_time_point_thresh": "time_point_thresh", "multi_time_point_thresh": "time_point_thresh", "time_point_thresh": "time_point_thresh", + "time_point_thresh_nopol": "time_point_thresh", "asym_trap_filter": "trap_filters", "trap_filter": "trap_filters", "trap_norm": "trap_filters", diff --git a/src/dspeed/processors/arithmetic.py b/src/dspeed/processors/arithmetic.py new file mode 100644 index 0000000..5fcb683 --- /dev/null +++ b/src/dspeed/processors/arithmetic.py @@ -0,0 +1,170 @@ +from __future__ import annotations + +import numpy as np +from numba import guvectorize + +from ..utils import numba_defaults_kwargs as nb_kwargs + + +@guvectorize( + [ + "void(float32[:], float32[:])", + "void(float64[:], float64[:])", + ], + "(n)->()", + **nb_kwargs, +) +def sum(w_in: np.ndarray, result: float) -> None: + """Sum the waveform values from index a to b. + + Parameters + ---------- + w_in + the input waveform + result + the sum of all values in w_in. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + wf_sum: + function: sum + module: dspeed.processors + args: + - waveform + - wf_sum + unit: + - ADC + """ + result[0] = np.nan + + if np.isnan(w_in).any(): + return + + start = 0 + end = len(w_in) - 1 + + if start < 0: + start = 0 + if end > len(w_in) - 1: + end = len(w_in) - 1 + if start > end: + return + + total = 0.0 + for i in range(start, end + 1): + total += w_in[i] + + result[0] = total + + +@guvectorize( + [ + "void(float32[:], float32[:])", + "void(float64[:], float64[:])", + ], + "(n)->()", + **nb_kwargs, +) +def mean(w_in: np.ndarray, result: float) -> None: + """Calculate the mean of waveform values. + + Parameters + ---------- + w_in + the input waveform. + result + the mean of all values in w_in. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + wf_mean: + function: mean + module: dspeed.processors + args: + - waveform + - wf_mean + unit: + - ADC + """ + result[0] = np.nan + + if np.isnan(w_in).any(): + return + + start = 0 + end = len(w_in) - 1 + + if start < 0: + start = 0 + if end > len(w_in) - 1: + end = len(w_in) - 1 + if start > end: + return + + total = 0.0 + for i in range(start, end + 1): + total += w_in[i] + + result[0] = total / (end - start + 1) + + +@guvectorize( + [ + "void(float32[:], float32, float32[:])", + "void(float64[:], float64, float64[:])", + ], + "(n),()->()", + **nb_kwargs, +) +def mean_below_threshold(w_in: np.ndarray, threshold: float, result: float) -> None: + """Calculate the mean of waveform values that are below a threshold. + + Parameters + ---------- + w_in + the input waveform. + threshold + the threshold value. Only waveform values below this threshold + are included in the mean calculation. + result + the mean of all values in w_in that are below the threshold. + Returns NaN if no values are below the threshold. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + wf_mean_below_threshold: + function: mean_below_threshold + module: dspeed.processors + args: + - waveform + - 100 + - wf_mean_below_threshold + unit: + - ADC + """ + result[0] = np.nan + + if np.isnan(w_in).any() or np.isnan(threshold): + return + + total = 0.0 + count = 0 + + for i in range(len(w_in)): + if w_in[i] < threshold: + total += w_in[i] + count += 1 + + if count == 0: + return + + result[0] = total / count diff --git a/src/dspeed/processors/sort.py b/src/dspeed/processors/sort.py new file mode 100644 index 0000000..802694c --- /dev/null +++ b/src/dspeed/processors/sort.py @@ -0,0 +1,41 @@ +from __future__ import annotations + +import numpy as np +from numba import guvectorize + +from ..utils import numba_defaults_kwargs as nb_kwargs + + +@guvectorize( + ["void(float32[:], float32[:])", "void(float64[:], float64[:])"], + "(n)->(n)", + **nb_kwargs, +) +def sort(w_in: np.ndarray, w_out: np.ndarray) -> None: + """Return a sorted array using :func:`numpy.sort`. + + Parameters + ---------- + w_in + the input waveform. + w_out + the output sorted waveform. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + wf_sorted: + function: sort + module: dspeed.processors + args: + - waveform + - wf_sorted + """ + w_out[:] = np.nan + + if np.isnan(w_in).any(): + return + + w_out[:] = np.sort(w_in) diff --git a/src/dspeed/processors/time_point_thresh.py b/src/dspeed/processors/time_point_thresh.py index c6067a3..e9ddb18 100644 --- a/src/dspeed/processors/time_point_thresh.py +++ b/src/dspeed/processors/time_point_thresh.py @@ -18,8 +18,11 @@ def time_point_thresh( w_in: np.ndarray, a_threshold: float, t_start: int, walk_forward: int, t_out: float ) -> None: - """Find the index where the waveform value crosses the threshold, walking - either forward or backward from the starting index. + """Find the index where the waveform value crosses above the threshold, walking + either forward or backward from the starting index. Find only crossings where the + waveform is rising through the threshold when moving forward in time (polarity check). + Return the waveform index just before the threshold crossing (i.e. below the threshold + when searching forward and above the threshold when searching backward). Parameters ---------- @@ -81,6 +84,82 @@ def time_point_thresh( return +@guvectorize( + [ + "void(float32[:], float32, float32, float32, float32[:])", + "void(float64[:], float64, float64, float64, float64[:])", + ], + "(n),(),(),()->()", + **nb_kwargs, +) +def time_point_thresh_nopol( + w_in: np.ndarray, a_threshold: float, t_start: int, walk_forward: int, t_out: float +) -> None: + """Find the index where the waveform value crosses below a threshold, walking + either forward or backward from the starting index, without polarity check. + I.e., find the first crossing in the specified direction, regardless of whether the waveform is rising or falling. + Return the waveform index just above the threshold crossing. + + Parameters + ---------- + w_in + the input waveform. + a_threshold + the threshold value. + t_start + the starting index. + walk_forward + the backward (``0``) or forward (``1``) search direction. + t_out + the index where the waveform value crosses the threshold. + + YAML Configuration Example + -------------------------- + + .. code-block:: yaml + + tp_0: + function: time_point_thresh + module: dspeed.processors + args: + - wf_atrap + - bl_std + - tp_start + - 0 + - tp_0 + unit: ns + """ + t_out[0] = np.nan + + if ( + np.isnan(w_in).any() + or np.isnan(a_threshold) + or np.isnan(t_start) + or np.isnan(walk_forward) + ): + return + + if np.floor(t_start) != t_start: + raise DSPFatal("The starting index must be an integer") + + if np.floor(walk_forward) != walk_forward: + raise DSPFatal("The search direction must be an integer") + + if int(t_start) < 0 or int(t_start) >= len(w_in): + raise DSPFatal("The starting index is out of range") + + if int(walk_forward) == 1: + for i in range(int(t_start), len(w_in) - 2, 1): + if w_in[i + 1] <= a_threshold: + t_out[0] = i + return + else: + for i in range(int(t_start), 0, -1): + if w_in[i - 1] < a_threshold: + t_out[0] = i + return + + @guvectorize( [ "void(float32[:], float32, float32, int64, char, float32[:])", diff --git a/tests/processors/test_arithmetic.py b/tests/processors/test_arithmetic.py new file mode 100644 index 0000000..6806477 --- /dev/null +++ b/tests/processors/test_arithmetic.py @@ -0,0 +1,117 @@ +import numpy as np +import pytest + +from dspeed.processors import mean, mean_below_threshold, sum + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_sum_basic(compare_numba_vs_python): + """Test basic sum functionality.""" + # Test sum of entire array + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result = compare_numba_vs_python(sum, w_in) + assert np.isclose(result, 15.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_sum_with_nan_input(compare_numba_vs_python): + """Test sum returns nan if input contains nan.""" + w_in = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + result = compare_numba_vs_python(sum, w_in) + assert np.isnan(result) + + +def test_sum_single_element(compare_numba_vs_python): + """Test sum with single element array.""" + w_in = np.array([42.0]) + result = compare_numba_vs_python(sum, w_in) + assert np.isclose(result, 42.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_basic(compare_numba_vs_python): + """Test basic mean functionality.""" + # Mean of entire array: 15/5 = 3 + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result = compare_numba_vs_python(mean, w_in) + assert np.isclose(result, 3.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_with_nan_input(compare_numba_vs_python): + """Test mean returns nan if input contains nan.""" + w_in = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + result = compare_numba_vs_python(mean, w_in) + assert np.isnan(result) + + +def test_mean_single_element(compare_numba_vs_python): + """Test mean with single element array.""" + w_in = np.array([42.0]) + result = compare_numba_vs_python(mean, w_in) + assert np.isclose(result, 42.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_below_threshold_basic(compare_numba_vs_python): + """Test basic mean_below_threshold functionality.""" + # Values below 4.0: 1, 2, 3. Mean = (1 + 2 + 3) / 3 = 2 + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 4.0) + assert np.isclose(result, 2.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_below_threshold_all_above(compare_numba_vs_python): + """Test mean_below_threshold when all values are above threshold.""" + # All values >= 10.0, should return NaN + w_in = np.array([10.0, 20.0, 30.0, 40.0, 50.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 10.0) + assert np.isnan(result) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_below_threshold_all_below(compare_numba_vs_python): + """Test mean_below_threshold when all values are below threshold.""" + # All values < 100.0. Mean = 15 / 5 = 3 + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 100.0) + assert np.isclose(result, 3.0) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_below_threshold_with_nan_input(compare_numba_vs_python): + """Test mean_below_threshold returns nan if input contains nan.""" + w_in = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 4.0) + assert np.isnan(result) + + +@pytest.mark.filterwarnings("ignore:invalid value encountered:RuntimeWarning") +def test_mean_below_threshold_with_nan_threshold(compare_numba_vs_python): + """Test mean_below_threshold returns nan if threshold is nan.""" + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, np.nan) + assert np.isnan(result) + + +def test_mean_below_threshold_negative_values(compare_numba_vs_python): + """Test mean_below_threshold with negative values.""" + # Values below 0.0: -2, -1. Mean = (-2 + -1) / 2 = -1.5 + w_in = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 0.0) + assert np.isclose(result, -1.5) + + +def test_mean_below_threshold_single_element_below(compare_numba_vs_python): + """Test mean_below_threshold with single element below threshold.""" + w_in = np.array([42.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 50.0) + assert np.isclose(result, 42.0) + + +def test_mean_below_threshold_single_element_above(compare_numba_vs_python): + """Test mean_below_threshold with single element above threshold.""" + w_in = np.array([42.0]) + result = compare_numba_vs_python(mean_below_threshold, w_in, 30.0) + assert np.isnan(result) diff --git a/tests/processors/test_sort.py b/tests/processors/test_sort.py new file mode 100644 index 0000000..dcecbb2 --- /dev/null +++ b/tests/processors/test_sort.py @@ -0,0 +1,31 @@ +import numpy as np + +from dspeed.processors import sort + + +def test_sort(compare_numba_vs_python): + """Testing function for the sort processor.""" + + # test basic sorting functionality + w_in = np.array([5.0, 2.0, 8.0, 1.0, 9.0, 3.0]) + w_out_expected = np.array([1.0, 2.0, 3.0, 5.0, 8.0, 9.0]) + assert np.allclose(compare_numba_vs_python(sort, w_in), w_out_expected) + + # test with negative values + w_in = np.array([3.0, -1.0, 2.0, -5.0, 0.0]) + w_out_expected = np.array([-5.0, -1.0, 0.0, 2.0, 3.0]) + assert np.allclose(compare_numba_vs_python(sort, w_in), w_out_expected) + + # test with already sorted array + w_in = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + w_out_expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + assert np.allclose(compare_numba_vs_python(sort, w_in), w_out_expected) + + # test with reverse sorted array + w_in = np.array([5.0, 4.0, 3.0, 2.0, 1.0]) + w_out_expected = np.array([1.0, 2.0, 3.0, 4.0, 5.0]) + assert np.allclose(compare_numba_vs_python(sort, w_in), w_out_expected) + + # test that nan in w_in produces all nans in output + w_in = np.array([1.0, 2.0, np.nan, 4.0, 5.0]) + assert np.all(np.isnan(compare_numba_vs_python(sort, w_in))) diff --git a/tests/processors/test_time_point_thresh.py b/tests/processors/test_time_point_thresh.py index 0c783f5..4cfa1f9 100644 --- a/tests/processors/test_time_point_thresh.py +++ b/tests/processors/test_time_point_thresh.py @@ -7,6 +7,7 @@ interpolated_time_point_thresh, rc_cr2, time_point_thresh, + time_point_thresh_nopol, ) @@ -85,6 +86,141 @@ def test_time_point_thresh(compare_numba_vs_python): w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") assert compare_numba_vs_python(time_point_thresh, w_in, 3, 0, 1) == 4.0 + # -------- Differentiation tests for time_point_thresh -------- + # These tests use waveforms where time_point_thresh (with polarity check) + # produces different results than time_point_thresh_nopol (without polarity check). + + # Test differentiation: waveform falls through threshold without rising back + # [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0 + # time_point_thresh looks for w_in[i] <= threshold < w_in[i+1] (rising through) + # This never happens here since waveform is falling, so returns nan + w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0]) + assert np.isnan(compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 0, 1)) + + # Test differentiation: waveform that starts below threshold and rises + # [0, 1, 2, 3, 4, 5] - threshold 2.5 walking forward from 0 + # time_point_thresh finds the rising crossing at index 2 (w_in[2]=2 <= 2.5 < w_in[3]=3) + w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) + assert compare_numba_vs_python(time_point_thresh, w_rising, 2.5, 0, 1) == 2.0 + + +def test_time_point_thresh_nopol(compare_numba_vs_python): + """Testing function for the time_point_thresh_nopol processor.""" + + # test for nan if w_in has a nan + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + w_in[4] = np.nan + assert np.isnan( + compare_numba_vs_python( + time_point_thresh_nopol, + w_in, + 1, + 11, + 0, + ) + ) + + # test for nan if nan is passed to a_threshold + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + assert np.isnan( + compare_numba_vs_python( + time_point_thresh_nopol, + w_in, + np.nan, + 11, + 0, + ) + ) + + # test for nan if nan is passed to t_start + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + assert np.isnan( + compare_numba_vs_python( + time_point_thresh_nopol, + w_in, + 1, + np.nan, + 0, + ) + ) + + # test for nan if nan is passed to walk_forward + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + assert np.isnan( + compare_numba_vs_python( + time_point_thresh_nopol, + w_in, + 1, + 11, + np.nan, + ) + ) + + # test for error if t_start non integer + with pytest.raises(DSPFatal): + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + time_point_thresh_nopol(w_in, 1, 10.5, 0, np.array([0.0])) + + # test for error if walk_forward non integer + with pytest.raises(DSPFatal): + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + time_point_thresh_nopol(w_in, 1, 11, 0.5, np.array([0.0])) + + # test for error if t_start out of range + with pytest.raises(DSPFatal): + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + time_point_thresh_nopol(w_in, 1, 12, 0, np.array([0.0])) + + # test walk backward - finds first point where w_in[i-1] < threshold + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + # w_in = [-1, 0, 1, 2, 3, 4, -1, 0, 1, 2, 3, 4], threshold=1, start=11 + # Walking backward from index 11, find first i where w_in[i-1] < 1 + # i=11: w_in[10]=3 < 1? No + # i=10: w_in[9]=2 < 1? No + # i=9: w_in[8]=1 < 1? No + # i=8: w_in[7]=0 < 1? Yes -> return 8 + assert compare_numba_vs_python(time_point_thresh_nopol, w_in, 1, 11, 0) == 8.0 + + # test walk forward - finds the index before the first point where w_in[i+1] <= threshold + w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") + # w_in = [-1, 0, 1, 2, 3, 4, -1, 0, 1, 2, 3, 4], threshold=3, start=0 + # Walking forward from index 0, find first i where w_in[i+1] <= 3 + # i=0: w_in[1]=0 <= 3? Yes -> return 0 + assert compare_numba_vs_python(time_point_thresh_nopol, w_in, 3, 0, 1) == 0.0 + + # -------- Differentiation tests for time_point_thresh_nopol -------- + # These tests use waveforms where time_point_thresh_nopol (without polarity check) + # produces different results than time_point_thresh (with polarity check). + + # Test differentiation: waveform falls through threshold without rising back + # [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0 + # time_point_thresh_nopol looks for w_in[i+1] <= threshold (index before first at or below) + # i=0: w_in[1]=4 <= 2.5? No + # i=1: w_in[2]=3 <= 2.5? No + # i=2: w_in[3]=2 <= 2.5? Yes -> return 2 + w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0]) + assert compare_numba_vs_python(time_point_thresh_nopol, w_falling, 2.5, 0, 1) == 2.0 + + # Test differentiation: waveform that starts below threshold and rises + # [0, 1, 2, 3, 4, 5] - threshold 2.5 walking forward from 0 + # time_point_thresh_nopol finds first i where w_in[i+1] <= 2.5 + # i=0: w_in[1]=1 <= 2.5? Yes -> return 0 + w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) + assert compare_numba_vs_python(time_point_thresh_nopol, w_rising, 2.5, 0, 1) == 0.0 + + # Test walk forward with no crossing - all values above threshold + # time_point_thresh_nopol returns nan when no point is at or below threshold + w_above = np.array([5.0, 6.0, 7.0, 8.0, 9.0]) + assert np.isnan( + compare_numba_vs_python(time_point_thresh_nopol, w_above, 2.5, 0, 1) + ) + + # Test walk backward with no crossing - all values above threshold + w_above = np.array([5.0, 6.0, 7.0, 8.0, 9.0]) + assert np.isnan( + compare_numba_vs_python(time_point_thresh_nopol, w_above, 2.5, 4, 0) + ) + def test_interpolated_time_point_thresh(compare_numba_vs_python): """Testing function for the interpolated_time_point_thresh processor."""