Skip to content

Commit

Permalink
Merge pull request #209 from pynapple-org/dev
Browse files Browse the repository at this point in the history
Fixing issue with event trigger average
  • Loading branch information
gviejo authored Nov 20, 2023
2 parents a3530d1 + 554af51 commit a8d9f4d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
11 changes: 9 additions & 2 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-01-30 22:59:00
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 19:13:24
# @Last Modified time: 2023-11-20 12:08:15

import numpy as np
from scipy.linalg import hankel
Expand Down Expand Up @@ -121,7 +121,7 @@ def compute_event_trigger_average(
group : TsGroup
The group of Ts/Tsd objects that hold the trigger time.
feature : Tsd
The 1-dimensional feature to average
The 1-dimensional feature to average. Can be a TsdFrame with one column only.
binsize : float
The bin size. Default is second.
If different, specify with the parameter time_units ('s' [default], 'ms', 'us').
Expand All @@ -147,6 +147,13 @@ def compute_event_trigger_average(
if type(group) is not nap.TsGroup:
raise RuntimeError("Unknown format for group")

if isinstance(feature, nap.TsdFrame):
if feature.shape[1] == 1:
feature = feature[:, 0]

if type(feature) is not nap.Tsd:
raise RuntimeError("Feature should be a Tsd or a TsdFrame with one column")

binsize = nap.TsIndex.format_timestamps(
np.array([binsize], dtype=np.float64), time_units
)[0]
Expand Down
20 changes: 19 additions & 1 deletion tests/test_spike_trigger_average.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-08-29 17:27:02
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-09-18 16:33:39
# @Last Modified time: 2023-11-20 12:07:53
#!/usr/bin/env python

"""Tests of spike trigger average for `pynapple` package."""
Expand Down Expand Up @@ -37,6 +37,12 @@ def test_compute_spike_trigger_average():
assert sta.shape == output.shape
np.testing.assert_array_almost_equal(sta, output)

feature = nap.TsdFrame(
t=feature.index.values, d=feature.values[:,None], time_support=ep
)
sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep)
np.testing.assert_array_almost_equal(sta, output)


def test_compute_spike_trigger_average_raise_error():
ep = nap.IntervalSet(0, 101)
Expand All @@ -51,6 +57,18 @@ def test_compute_spike_trigger_average_raise_error():
nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep)
assert str(e_info.value) == "Unknown format for group"

feature = nap.TsdFrame(
t=np.arange(0, 101, 0.01), d=np.random.rand(int(101 / 0.01), 3), time_support=ep
)
spikes = nap.TsGroup(
{0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep
)
with pytest.raises(Exception) as e_info:
nap.compute_event_trigger_average(spikes, feature, 0.1, (0.5, 0.5), ep)
assert str(e_info.value) == "Feature should be a Tsd or a TsdFrame with one column"




def test_compute_spike_trigger_average_time_units():
ep = nap.IntervalSet(0, 100)
Expand Down

0 comments on commit a8d9f4d

Please sign in to comment.