1
- """Cross-correlograms """
1
+ """
2
+ This module holds the functions to compute discrete cross-correlogram
3
+ for timestamps data (i.e. spike times).
2
4
5
+ | Function | Description |
6
+ |------|------|
7
+ | `nap.compute_autocorrelogram` | Autocorrelograms from a TsGroup object |
8
+ | `nap.compute_crosscorrelogram` | Crosscorrelogram from a TsGroup object |
9
+ | `nap.compute_eventcorrelogram` | Crosscorrelogram between a TsGroup object and a Ts object |
10
+
11
+ """
12
+
13
+ import inspect
14
+ from functools import wraps
3
15
from itertools import combinations , product
16
+ from numbers import Number
4
17
5
18
import numpy as np
6
19
import pandas as pd
9
22
from .. import core as nap
10
23
11
24
12
- #########################################################
13
- # CORRELATION
14
- #########################################################
25
+ def _validate_correlograms_inputs (func ):
26
+ @wraps (func )
27
+ def wrapper (* args , ** kwargs ):
28
+ # Validate each positional argument
29
+ sig = inspect .signature (func )
30
+ kwargs = sig .bind_partial (* args , ** kwargs ).arguments
31
+
32
+ # Only TypeError here
33
+ if getattr (func , "__name__" ) == "compute_crosscorrelogram" and isinstance (
34
+ kwargs ["group" ], (tuple , list )
35
+ ):
36
+ if (
37
+ not all ([isinstance (g , nap .TsGroup ) for g in kwargs ["group" ]])
38
+ or len (kwargs ["group" ]) != 2
39
+ ):
40
+ raise TypeError (
41
+ "Invalid type. Parameter group must be of type TsGroup or a tuple/list of (TsGroup, TsGroup)."
42
+ )
43
+ else :
44
+ if not isinstance (kwargs ["group" ], nap .TsGroup ):
45
+ msg = "Invalid type. Parameter group must be of type TsGroup"
46
+ if getattr (func , "__name__" ) == "compute_crosscorrelogram" :
47
+ msg = msg + " or a tuple/list of (TsGroup, TsGroup)."
48
+ raise TypeError (msg )
49
+
50
+ parameters_type = {
51
+ "binsize" : Number ,
52
+ "windowsize" : Number ,
53
+ "ep" : nap .IntervalSet ,
54
+ "norm" : bool ,
55
+ "time_units" : str ,
56
+ "reverse" : bool ,
57
+ "event" : (nap .Ts , nap .Tsd ),
58
+ }
59
+ for param , param_type in parameters_type .items ():
60
+ if param in kwargs :
61
+ if not isinstance (kwargs [param ], param_type ):
62
+ raise TypeError (
63
+ f"Invalid type. Parameter { param } must be of type { param_type } ."
64
+ )
65
+
66
+ # Call the original function with validated inputs
67
+ return func (** kwargs )
68
+
69
+ return wrapper
70
+
71
+
15
72
@jit (nopython = True )
16
73
def _cross_correlogram (t1 , t2 , binsize , windowsize ):
17
74
"""
@@ -81,6 +138,7 @@ def _cross_correlogram(t1, t2, binsize, windowsize):
81
138
return C , B
82
139
83
140
141
+ @_validate_correlograms_inputs
84
142
def compute_autocorrelogram (
85
143
group , binsize , windowsize , ep = None , norm = True , time_units = "s"
86
144
):
@@ -118,13 +176,10 @@ def compute_autocorrelogram(
118
176
RuntimeError
119
177
group must be TsGroup
120
178
"""
121
- if type (group ) is nap .TsGroup :
122
- if isinstance (ep , nap .IntervalSet ):
123
- newgroup = group .restrict (ep )
124
- else :
125
- newgroup = group
179
+ if isinstance (ep , nap .IntervalSet ):
180
+ newgroup = group .restrict (ep )
126
181
else :
127
- raise RuntimeError ( "Unknown format for group" )
182
+ newgroup = group
128
183
129
184
autocorrs = {}
130
185
@@ -152,6 +207,7 @@ def compute_autocorrelogram(
152
207
return autocorrs .astype ("float" )
153
208
154
209
210
+ @_validate_correlograms_inputs
155
211
def compute_crosscorrelogram (
156
212
group , binsize , windowsize , ep = None , norm = True , time_units = "s" , reverse = False
157
213
):
@@ -207,7 +263,24 @@ def compute_crosscorrelogram(
207
263
np .array ([windowsize ], dtype = np .float64 ), time_units
208
264
)[0 ]
209
265
210
- if isinstance (group , nap .TsGroup ):
266
+ if isinstance (group , tuple ):
267
+ if isinstance (ep , nap .IntervalSet ):
268
+ newgroup = [group [i ].restrict (ep ) for i in range (2 )]
269
+ else :
270
+ newgroup = group
271
+
272
+ pairs = product (list (newgroup [0 ].keys ()), list (newgroup [1 ].keys ()))
273
+
274
+ for i , j in pairs :
275
+ spk1 = newgroup [0 ][i ].index
276
+ spk2 = newgroup [1 ][j ].index
277
+ auc , times = _cross_correlogram (spk1 , spk2 , binsize , windowsize )
278
+ if norm :
279
+ auc /= newgroup [1 ][j ].rate
280
+ crosscorrs [(i , j )] = pd .Series (index = times , data = auc , dtype = "float" )
281
+
282
+ crosscorrs = pd .DataFrame .from_dict (crosscorrs )
283
+ else :
211
284
if isinstance (ep , nap .IntervalSet ):
212
285
newgroup = group .restrict (ep )
213
286
else :
@@ -232,34 +305,10 @@ def compute_crosscorrelogram(
232
305
)
233
306
crosscorrs = crosscorrs / freq2
234
307
235
- elif (
236
- isinstance (group , (tuple , list ))
237
- and len (group ) == 2
238
- and all (map (lambda g : isinstance (g , nap .TsGroup ), group ))
239
- ):
240
- if isinstance (ep , nap .IntervalSet ):
241
- newgroup = [group [i ].restrict (ep ) for i in range (2 )]
242
- else :
243
- newgroup = group
244
-
245
- pairs = product (list (newgroup [0 ].keys ()), list (newgroup [1 ].keys ()))
246
-
247
- for i , j in pairs :
248
- spk1 = newgroup [0 ][i ].index
249
- spk2 = newgroup [1 ][j ].index
250
- auc , times = _cross_correlogram (spk1 , spk2 , binsize , windowsize )
251
- if norm :
252
- auc /= newgroup [1 ][j ].rate
253
- crosscorrs [(i , j )] = pd .Series (index = times , data = auc , dtype = "float" )
254
-
255
- crosscorrs = pd .DataFrame .from_dict (crosscorrs )
256
-
257
- else :
258
- raise RuntimeError ("Unknown format for group" )
259
-
260
308
return crosscorrs .astype ("float" )
261
309
262
310
311
+ @_validate_correlograms_inputs
263
312
def compute_eventcorrelogram (
264
313
group , event , binsize , windowsize , ep = None , norm = True , time_units = "s"
265
314
):
@@ -306,10 +355,7 @@ def compute_eventcorrelogram(
306
355
else :
307
356
tsd1 = event .restrict (ep ).index
308
357
309
- if type (group ) is nap .TsGroup :
310
- newgroup = group .restrict (ep )
311
- else :
312
- raise RuntimeError ("Unknown format for group" )
358
+ newgroup = group .restrict (ep )
313
359
314
360
crosscorrs = {}
315
361
0 commit comments