Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 22 additions & 10 deletions src/dspeed/processors/time_point_thresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@
def time_point_thresh(
w_in: np.ndarray, a_threshold: float, t_start: int, walk_forward: int, t_out: float
) -> None:
"""Find the index where the waveform value crosses above the threshold, walking
either forward or backward from the starting index. Find only crossings where the
waveform is rising through the threshold when moving forward in time (polarity check).
Return the waveform index just before the threshold crossing (i.e. below the threshold
when searching forward and above the threshold when searching backward).
"""Find the index where the waveform value crosses the threshold, walking
either forward or backward from the starting index. Find crossings where the
waveform crosses through the threshold in either direction (rising or falling).
Return the waveform index just before the threshold crossing.

Parameters
----------
Expand Down Expand Up @@ -74,12 +73,18 @@ def time_point_thresh(

if int(walk_forward) == 1:
for i in range(int(t_start), len(w_in) - 1, 1):
if w_in[i] <= a_threshold < w_in[i + 1]:
# Check for crossing in either direction (rising or falling)
if (w_in[i] <= a_threshold < w_in[i + 1]) or (
w_in[i] >= a_threshold > w_in[i + 1]
):
t_out[0] = i
return
else:
for i in range(int(t_start), 0, -1):
if w_in[i - 1] < a_threshold <= w_in[i]:
# Check for crossing in either direction (rising or falling)
if (w_in[i - 1] < a_threshold <= w_in[i]) or (
w_in[i - 1] > a_threshold >= w_in[i]
):
t_out[0] = i
return

Expand Down Expand Up @@ -179,7 +184,8 @@ def interpolated_time_point_thresh(
"""Find the time where the waveform value crosses the threshold

Search performed walking either forward or backward from the starting
index. Use interpolation to estimate a time between samples. Interpolation
index. Detects crossings in both directions (rising or falling).
Use interpolation to estimate a time between samples. Interpolation
mode selected with `mode_in`.

Parameters
Expand Down Expand Up @@ -250,12 +256,18 @@ def interpolated_time_point_thresh(
i_cross = -1
if walk_forward > 0:
for i in range(int(t_start), len(w_in) - 1, 1):
if w_in[i] <= a_threshold < w_in[i + 1]:
# Check for crossing in either direction (rising or falling)
if (w_in[i] <= a_threshold < w_in[i + 1]) or (
w_in[i] >= a_threshold > w_in[i + 1]
):
i_cross = i
break
else:
for i in range(int(t_start), 1, -1):
if w_in[i - 1] < a_threshold <= w_in[i]:
# Check for crossing in either direction (rising or falling)
if (w_in[i - 1] < a_threshold <= w_in[i]) or (
w_in[i - 1] > a_threshold >= w_in[i]
):
i_cross = i - 1
break

Expand Down
73 changes: 65 additions & 8 deletions tests/processors/test_time_point_thresh.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,43 @@ def test_time_point_thresh(compare_numba_vs_python):
w_in = np.concatenate([np.arange(-1, 5, 1), np.arange(-1, 5, 1)], dtype="float")
assert compare_numba_vs_python(time_point_thresh, w_in, 3, 0, 1) == 4.0

# -------- Differentiation tests for time_point_thresh --------
# These tests use waveforms where time_point_thresh (with polarity check)
# produces different results than time_point_thresh_nopol (without polarity check).
# -------- Tests for both rising and falling crossings --------
# These tests verify that time_point_thresh detects crossings in both directions.

# Test differentiation: waveform falls through threshold without rising back
# Test: waveform falls through threshold
# [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0
# time_point_thresh looks for w_in[i] <= threshold < w_in[i+1] (rising through)
# This never happens here since waveform is falling, so returns nan
# time_point_thresh now detects the falling crossing at index 2 (w_in[2]=3 >= 2.5 > w_in[3]=2)
w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0])
assert np.isnan(compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 0, 1))
assert compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 0, 1) == 2.0

# Test differentiation: waveform that starts below threshold and rises
# Test: waveform that starts below threshold and rises
# [0, 1, 2, 3, 4, 5] - threshold 2.5 walking forward from 0
# time_point_thresh finds the rising crossing at index 2 (w_in[2]=2 <= 2.5 < w_in[3]=3)
w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
assert compare_numba_vs_python(time_point_thresh, w_rising, 2.5, 0, 1) == 2.0

# Test: negative polarity waveform (inverted falling = rising in negative direction)
# Walk forward through a negative-going pulse
w_neg_falling = np.array([0.0, -1.0, -2.0, -3.0, -4.0, -5.0])
assert compare_numba_vs_python(time_point_thresh, w_neg_falling, -2.5, 0, 1) == 2.0

# Test: negative polarity waveform rising back up
# Walk forward through a negative waveform that rises back through threshold
w_neg_rising = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0])
assert compare_numba_vs_python(time_point_thresh, w_neg_rising, -2.5, 0, 1) == 2.0

# Test walk backward with falling waveform
# [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking backward from 6
# Should find the falling crossing at index 3 (w_in[2]=3 > 2.5 >= w_in[3]=2)
w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0])
assert compare_numba_vs_python(time_point_thresh, w_falling, 2.5, 6, 0) == 3.0

# Test walk backward with rising waveform
# [0, 1, 2, 3, 4, 5] - threshold 2.5 walking backward from 5
# Should find the rising crossing at index 3 (w_in[2]=2 < 2.5 <= w_in[3]=3)
w_rising = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
assert compare_numba_vs_python(time_point_thresh, w_rising, 2.5, 5, 0) == 3.0


def test_time_point_thresh_nopol(compare_numba_vs_python):
"""Testing function for the time_point_thresh_nopol processor."""
Expand Down Expand Up @@ -324,6 +344,43 @@ def test_interpolated_time_point_thresh(compare_numba_vs_python):
== 4.5
)

# -------- Tests for falling waveforms with interpolation --------
# Test linear interpolation with falling waveform
# [5, 4, 3, 2, 1, 0, -1] - threshold 2.5 walking forward from 0
# Crossing between index 2 (value=3) and index 3 (value=2)
# Linear interpolation: 2 + (2.5 - 3) / (2 - 3) = 2 + (-0.5) / (-1) = 2 + 0.5 = 2.5
w_falling = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0, -1.0])
assert (
compare_numba_vs_python(
interpolated_time_point_thresh, w_falling, 2.5, 0, 1, 108
)
== 2.5
)

# Test linear interpolation with negative polarity waveform
# [0, -1, -2, -3, -4, -5] - threshold -2.5 walking forward from 0
# Crossing between index 2 (value=-2) and index 3 (value=-3)
# Linear interpolation: 2 + (-2.5 - (-2)) / (-3 - (-2)) = 2 + (-0.5) / (-1) = 2.5
w_neg_falling = np.array([0.0, -1.0, -2.0, -3.0, -4.0, -5.0])
assert (
compare_numba_vs_python(
interpolated_time_point_thresh, w_neg_falling, -2.5, 0, 1, 108
)
== 2.5
)

# Test linear interpolation with negative polarity waveform rising back
# [-5, -4, -3, -2, -1, 0] - threshold -2.5 walking forward from 0
# Crossing between index 2 (value=-3) and index 3 (value=-2)
# Linear interpolation: 2 + (-2.5 - (-3)) / (-2 - (-3)) = 2 + 0.5 / 1 = 2.5
w_neg_rising = np.array([-5.0, -4.0, -3.0, -2.0, -1.0, 0.0])
assert (
compare_numba_vs_python(
interpolated_time_point_thresh, w_neg_rising, -2.5, 0, 1, 108
)
== 2.5
)


def test_bi_level_zero_crossing_time_points(compare_numba_vs_python):
# Test exceptions and initial checks
Expand Down
Loading