diff --git a/src/dspeed/processors/time_point_thresh.py b/src/dspeed/processors/time_point_thresh.py index e9ddb18..1481cbd 100644 --- a/src/dspeed/processors/time_point_thresh.py +++ b/src/dspeed/processors/time_point_thresh.py @@ -18,11 +18,10 @@ def time_point_thresh( w_in: np.ndarray, a_threshold: float, t_start: int, walk_forward: int, t_out: float ) -> None: - """Find the index where the waveform value crosses above the threshold, walking - either forward or backward from the starting index. Find only crossings where the - waveform is rising through the threshold when moving forward in time (polarity check). - Return the waveform index just before the threshold crossing (i.e. below the threshold - when searching forward and above the threshold when searching backward). + """Find the index where the waveform value crosses the threshold, walking + either forward or backward from the starting index. Find crossings where the + waveform crosses through the threshold in either direction (rising or falling). + Return the waveform index just before the threshold crossing. Parameters ---------- @@ -74,12 +73,18 @@ def time_point_thresh( if int(walk_forward) == 1: for i in range(int(t_start), len(w_in) - 1, 1): - if w_in[i] <= a_threshold < w_in[i + 1]: + # Check for crossing in either direction (rising or falling) + if (w_in[i] <= a_threshold < w_in[i + 1]) or ( + w_in[i] >= a_threshold > w_in[i + 1] + ): t_out[0] = i return else: for i in range(int(t_start), 0, -1): - if w_in[i - 1] < a_threshold <= w_in[i]: + # Check for crossing in either direction (rising or falling) + if (w_in[i - 1] < a_threshold <= w_in[i]) or ( + w_in[i - 1] > a_threshold >= w_in[i] + ): t_out[0] = i return @@ -179,7 +184,8 @@ def interpolated_time_point_thresh( """Find the time where the waveform value crosses the threshold Search performed walking either forward or backward from the starting - index. Use interpolation to estimate a time between samples. Interpolation + index. Detects crossings in both directions (rising or falling). + Use interpolation to estimate a time between samples. Interpolation mode selected with `mode_in`. Parameters @@ -250,12 +256,18 @@ def interpolated_time_point_thresh( i_cross = -1 if walk_forward > 0: for i in range(int(t_start), len(w_in) - 1, 1): - if w_in[i] <= a_threshold < w_in[i + 1]: + # Check for crossing in either direction (rising or falling) + if (w_in[i] <= a_threshold < w_in[i + 1]) or ( + w_in[i] >= a_threshold > w_in[i + 1] + ): i_cross = i break else: for i in range(int(t_start), 1, -1): - if w_in[i - 1] < a_threshold <= w_in[i]: + # Check for crossing in either direction (rising or falling) + if (w_in[i - 1] < a_threshold <= w_in[i]) or ( + w_in[i - 1] > a_threshold >= w_in[i] + ): i_cross = i - 1 break diff --git a/tests/processors/test_time_point_thresh.py b/tests/processors/test_time_point_thresh.py index 4cfa1f9..393ccff 100644 --- a/tests/processors/test_time_point_thresh.py +++ b/tests/processors/test_time_point_thresh.py @@ -86,23 +86,43 @@ def test_time_point_thresh(compare_numba_vs_python): w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float") assert compare_numba_vs_python(time_point_thresh, w_in, 3, 0, 1) == 4.0 - # -------- Differentiation tests for time_point_thresh -------- - # These tests use waveforms where time_point_thresh (with polarity check) - # produces different results than time_point_thresh_nopol (without polarity check). + # -------- Tests for both rising and falling crossings -------- + # These tests verify that time_point_thresh detects crossings in both directions. - # Test differentiation: waveform falls through threshold without rising back + # Test: waveform falls through threshold # [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0 - # time_point_thresh looks for w_in[i] <= threshold < w_in[i+1] (rising through) - # This never happens here since waveform is falling, so returns nan + # time_point_thresh now detects the falling crossing at index 2 (w_in[2]=3 >= 2.5 > w_in[3]=2) w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0]) - assert np.isnan(compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 0, 1)) + assert compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 0, 1) == 2.0 - # Test differentiation: waveform that starts below threshold and rises + # Test: waveform that starts below threshold and rises # [0, 1, 2, 3, 4, 5] - threshold 2.5 walking forward from 0 # time_point_thresh finds the rising crossing at index 2 (w_in[2]=2 <= 2.5 < w_in[3]=3) w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) assert compare_numba_vs_python(time_point_thresh, w_rising, 2.5, 0, 1) == 2.0 + # Test: negative polarity waveform (inverted falling = rising in negative direction) + # Walk forward through a negative-going pulse + w_neg_falling = np.array([0.0, -1.0, -2.0, -3.0, -4.0, -5.0]) + assert compare_numba_vs_python(time_point_thresh, w_neg_falling, -2.5, 0, 1) == 2.0 + + # Test: negative polarity waveform rising back up + # Walk forward through a negative waveform that rises back through threshold + w_neg_rising = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0]) + assert compare_numba_vs_python(time_point_thresh, w_neg_rising, -2.5, 0, 1) == 2.0 + + # Test walk backward with falling waveform + # [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking backward from 6 + # Should find the falling crossing at index 3 (w_in[2]=3 > 2.5 >= w_in[3]=2) + w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0]) + assert compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 6, 0) == 3.0 + + # Test walk backward with rising waveform + # [0, 1, 2, 3, 4, 5] - threshold 2.5 walking backward from 5 + # Should find the rising crossing at index 3 (w_in[2]=2 < 2.5 <= w_in[3]=3) + w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0]) + assert compare_numba_vs_python(time_point_thresh, w_rising, 2.5, 5, 0) == 3.0 + def test_time_point_thresh_nopol(compare_numba_vs_python): """Testing function for the time_point_thresh_nopol processor.""" @@ -324,6 +344,43 @@ def test_interpolated_time_point_thresh(compare_numba_vs_python): == 4.5 ) + # -------- Tests for falling waveforms with interpolation -------- + # Test linear interpolation with falling waveform + # [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0 + # Crossing between index 2 (value=3) and index 3 (value=2) + # Linear interpolation: 2 + (2.5 - 3) / (2 - 3) = 2 + (-0.5) / (-1) = 2 + 0.5 = 2.5 + w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0]) + assert ( + compare_numba_vs_python( + interpolated_time_point_thresh, w_falling, 2.5, 0, 1, 108 + ) + == 2.5 + ) + + # Test linear interpolation with negative polarity waveform + # [0, -1, -2, -3, -4, -5] - threshold -2.5 walking forward from 0 + # Crossing between index 2 (value=-2) and index 3 (value=-3) + # Linear interpolation: 2 + (-2.5 - (-2)) / (-3 - (-2)) = 2 + (-0.5) / (-1) = 2.5 + w_neg_falling = np.array([0.0, -1.0, -2.0, -3.0, -4.0, -5.0]) + assert ( + compare_numba_vs_python( + interpolated_time_point_thresh, w_neg_falling, -2.5, 0, 1, 108 + ) + == 2.5 + ) + + # Test linear interpolation with negative polarity waveform rising back + # [-5, -4, -3, -2, -1, 0] - threshold -2.5 walking forward from 0 + # Crossing between index 2 (value=-3) and index 3 (value=-2) + # Linear interpolation: 2 + (-2.5 - (-3)) / (-2 - (-3)) = 2 + 0.5 / 1 = 2.5 + w_neg_rising = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0]) + assert ( + compare_numba_vs_python( + interpolated_time_point_thresh, w_neg_rising, -2.5, 0, 1, 108 + ) + == 2.5 + ) + def test_bi_level_zero_crossing_time_points(compare_numba_vs_python): # Test exceptions and initial checks