Skip to content

Commit 20f2254

Browse files
authored
Merge pull request #336 from pynapple-org/dev
Dev
2 parents deca7e2 + 53c3221 commit 20f2254

12 files changed

+1260
-850
lines changed

docs/HISTORY.md

+6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@ and [Edoardo Balzani](https://www.simonsfoundation.org/people/edoardo-balzani/)
1414
of the Flatiron institute.
1515

1616

17+
0.7.1 (2024-09-24)
18+
------------------
19+
20+
- Fixing nan issue when computing 1d tuning curve (See issue #334).
21+
- Refactor tuning curves and correlogram tests.
22+
- Adding validators decorators for tuning curves and correlogram modules.
1723

1824
0.7.0 (2024-09-16)
1925
------------------

docs/index.md

+1-3
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@ PYthon Neural Analysis Package.
1212
pynapple is a light-weight python library for neurophysiological data analysis. The goal is to offer a versatile set of tools to study typical data in the field, i.e. time series (spike times, behavioral events, etc.) and time intervals (trials, brain states, etc.). It also provides users with generic functions for neuroscience such as tuning curves and cross-correlograms.
1313

1414
- Free software: MIT License
15-
- __Documentation__: <https://pynapple-org.github.io/pynapple>
16-
- __Notebooks and tutorials__ : <https://pynapple-org.github.io/pynapple/generated/gallery/>
17-
<!-- - __Collaborative repository__: <https://github.com/PeyracheLab/pynacollada> -->
15+
- __Documentation__: <https://pynapple.org>
1816

1917

2018
> **Note**

pynapple/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.7.0"
1+
__version__ = "0.7.1"
22
from .core import (
33
IntervalSet,
44
Ts,

pynapple/core/time_series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def size(self):
139139
return self.values.size
140140

141141
def __array__(self, dtype=None):
142-
return self.values.astype(dtype)
142+
return np.asarray(self.values, dtype=dtype)
143143

144144
def __array_ufunc__(self, ufunc, method, *args, **kwargs):
145145
# print("In __array_ufunc__")

pynapple/process/correlograms.py

+86-40
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
1-
"""Cross-correlograms """
1+
"""
2+
This module holds the functions to compute discrete cross-correlogram
3+
for timestamps data (i.e. spike times).
24
5+
| Function | Description |
6+
|------|------|
7+
| `nap.compute_autocorrelogram` | Autocorrelograms from a TsGroup object |
8+
| `nap.compute_crosscorrelogram` | Crosscorrelogram from a TsGroup object |
9+
| `nap.compute_eventcorrelogram` | Crosscorrelogram between a TsGroup object and a Ts object |
10+
11+
"""
12+
13+
import inspect
14+
from functools import wraps
315
from itertools import combinations, product
16+
from numbers import Number
417

518
import numpy as np
619
import pandas as pd
@@ -9,9 +22,53 @@
922
from .. import core as nap
1023

1124

12-
#########################################################
13-
# CORRELATION
14-
#########################################################
25+
def _validate_correlograms_inputs(func):
26+
@wraps(func)
27+
def wrapper(*args, **kwargs):
28+
# Validate each positional argument
29+
sig = inspect.signature(func)
30+
kwargs = sig.bind_partial(*args, **kwargs).arguments
31+
32+
# Only TypeError here
33+
if getattr(func, "__name__") == "compute_crosscorrelogram" and isinstance(
34+
kwargs["group"], (tuple, list)
35+
):
36+
if (
37+
not all([isinstance(g, nap.TsGroup) for g in kwargs["group"]])
38+
or len(kwargs["group"]) != 2
39+
):
40+
raise TypeError(
41+
"Invalid type. Parameter group must be of type TsGroup or a tuple/list of (TsGroup, TsGroup)."
42+
)
43+
else:
44+
if not isinstance(kwargs["group"], nap.TsGroup):
45+
msg = "Invalid type. Parameter group must be of type TsGroup"
46+
if getattr(func, "__name__") == "compute_crosscorrelogram":
47+
msg = msg + " or a tuple/list of (TsGroup, TsGroup)."
48+
raise TypeError(msg)
49+
50+
parameters_type = {
51+
"binsize": Number,
52+
"windowsize": Number,
53+
"ep": nap.IntervalSet,
54+
"norm": bool,
55+
"time_units": str,
56+
"reverse": bool,
57+
"event": (nap.Ts, nap.Tsd),
58+
}
59+
for param, param_type in parameters_type.items():
60+
if param in kwargs:
61+
if not isinstance(kwargs[param], param_type):
62+
raise TypeError(
63+
f"Invalid type. Parameter {param} must be of type {param_type}."
64+
)
65+
66+
# Call the original function with validated inputs
67+
return func(**kwargs)
68+
69+
return wrapper
70+
71+
1572
@jit(nopython=True)
1673
def _cross_correlogram(t1, t2, binsize, windowsize):
1774
"""
@@ -81,6 +138,7 @@ def _cross_correlogram(t1, t2, binsize, windowsize):
81138
return C, B
82139

83140

141+
@_validate_correlograms_inputs
84142
def compute_autocorrelogram(
85143
group, binsize, windowsize, ep=None, norm=True, time_units="s"
86144
):
@@ -118,13 +176,10 @@ def compute_autocorrelogram(
118176
RuntimeError
119177
group must be TsGroup
120178
"""
121-
if type(group) is nap.TsGroup:
122-
if isinstance(ep, nap.IntervalSet):
123-
newgroup = group.restrict(ep)
124-
else:
125-
newgroup = group
179+
if isinstance(ep, nap.IntervalSet):
180+
newgroup = group.restrict(ep)
126181
else:
127-
raise RuntimeError("Unknown format for group")
182+
newgroup = group
128183

129184
autocorrs = {}
130185

@@ -152,6 +207,7 @@ def compute_autocorrelogram(
152207
return autocorrs.astype("float")
153208

154209

210+
@_validate_correlograms_inputs
155211
def compute_crosscorrelogram(
156212
group, binsize, windowsize, ep=None, norm=True, time_units="s", reverse=False
157213
):
@@ -207,7 +263,24 @@ def compute_crosscorrelogram(
207263
np.array([windowsize], dtype=np.float64), time_units
208264
)[0]
209265

210-
if isinstance(group, nap.TsGroup):
266+
if isinstance(group, tuple):
267+
if isinstance(ep, nap.IntervalSet):
268+
newgroup = [group[i].restrict(ep) for i in range(2)]
269+
else:
270+
newgroup = group
271+
272+
pairs = product(list(newgroup[0].keys()), list(newgroup[1].keys()))
273+
274+
for i, j in pairs:
275+
spk1 = newgroup[0][i].index
276+
spk2 = newgroup[1][j].index
277+
auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize)
278+
if norm:
279+
auc /= newgroup[1][j].rate
280+
crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float")
281+
282+
crosscorrs = pd.DataFrame.from_dict(crosscorrs)
283+
else:
211284
if isinstance(ep, nap.IntervalSet):
212285
newgroup = group.restrict(ep)
213286
else:
@@ -232,34 +305,10 @@ def compute_crosscorrelogram(
232305
)
233306
crosscorrs = crosscorrs / freq2
234307

235-
elif (
236-
isinstance(group, (tuple, list))
237-
and len(group) == 2
238-
and all(map(lambda g: isinstance(g, nap.TsGroup), group))
239-
):
240-
if isinstance(ep, nap.IntervalSet):
241-
newgroup = [group[i].restrict(ep) for i in range(2)]
242-
else:
243-
newgroup = group
244-
245-
pairs = product(list(newgroup[0].keys()), list(newgroup[1].keys()))
246-
247-
for i, j in pairs:
248-
spk1 = newgroup[0][i].index
249-
spk2 = newgroup[1][j].index
250-
auc, times = _cross_correlogram(spk1, spk2, binsize, windowsize)
251-
if norm:
252-
auc /= newgroup[1][j].rate
253-
crosscorrs[(i, j)] = pd.Series(index=times, data=auc, dtype="float")
254-
255-
crosscorrs = pd.DataFrame.from_dict(crosscorrs)
256-
257-
else:
258-
raise RuntimeError("Unknown format for group")
259-
260308
return crosscorrs.astype("float")
261309

262310

311+
@_validate_correlograms_inputs
263312
def compute_eventcorrelogram(
264313
group, event, binsize, windowsize, ep=None, norm=True, time_units="s"
265314
):
@@ -306,10 +355,7 @@ def compute_eventcorrelogram(
306355
else:
307356
tsd1 = event.restrict(ep).index
308357

309-
if type(group) is nap.TsGroup:
310-
newgroup = group.restrict(ep)
311-
else:
312-
raise RuntimeError("Unknown format for group")
358+
newgroup = group.restrict(ep)
313359

314360
crosscorrs = {}
315361

0 commit comments

Comments
 (0)