From 554af517c47ecc2d406f4ec1b19ae9315201aec6 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 20 Nov 2023 12:10:57 -0500 Subject: [PATCH] Fixing issue with event trigger average --- pynapple/process/perievent.py | 11 +++++++++-- tests/test_spike_trigger_average.py | 20 +++++++++++++++++++- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index 6d935e70..8492cdb7 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-30 22:59:00 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-19 19:13:24 +# @Last Modified time: 2023-11-20 12:08:15 import numpy as np from scipy.linalg import hankel @@ -121,7 +121,7 @@ def compute_event_trigger_average( group : TsGroup The group of Ts/Tsd objects that hold the trigger time. feature : Tsd - The 1-dimensional feature to average + The 1-dimensional feature to average. Can be a TsdFrame with one column only. binsize : float The bin size. Default is second. If different, specify with the parameter time_units ('s' [default], 'ms', 'us'). @@ -147,6 +147,13 @@ def compute_event_trigger_average( if type(group) is not nap.TsGroup: raise RuntimeError("Unknown format for group") + if isinstance(feature, nap.TsdFrame): + if feature.shape[1] == 1: + feature = feature[:, 0] + + if type(feature) is not nap.Tsd: + raise RuntimeError("Feature should be a Tsd or a TsdFrame with one column") + binsize = nap.TsIndex.format_timestamps( np.array([binsize], dtype=np.float64), time_units )[0] diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index 5a268a46..ee741f84 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-08-29 17:27:02 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 16:33:39 +# @Last Modified time: 2023-11-20 12:07:53 #!/usr/bin/env python """Tests of spike trigger average for `pynapple` package.""" @@ -37,6 +37,12 @@ def test_compute_spike_trigger_average(): assert sta.shape == output.shape np.testing.assert_array_almost_equal(sta, output) + feature = nap.TsdFrame( + t=feature.index.values, d=feature.values[:,None], time_support=ep + ) + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) + np.testing.assert_array_almost_equal(sta, output) + def test_compute_spike_trigger_average_raise_error(): ep = nap.IntervalSet(0, 101) @@ -51,6 +57,18 @@ def test_compute_spike_trigger_average_raise_error(): nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep) assert str(e_info.value) == "Unknown format for group" + feature = nap.TsdFrame( + t=np.arange(0, 101, 0.01), d=np.random.rand(int(101 / 0.01), 3), time_support=ep + ) + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep + ) + with pytest.raises(Exception) as e_info: + nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep) + assert str(e_info.value) == "Feature should be a Tsd or a TsdFrame with one column" + + + def test_compute_spike_trigger_average_time_units(): ep = nap.IntervalSet(0, 100)