From 5d6e6ce437124818dbcd2f2f6d8329c61125c800 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 21 Nov 2023 16:11:42 -0500 Subject: [PATCH 1/7] Testing convolution and smooth method --- pynapple/core/jitted_functions.py | 33 +++++++++++++++++++++++- pynapple/core/time_series.py | 43 ++++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 2 deletions(-) diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index 310f4e1e..a9af4c75 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-19 18:27:43 +# @Last Modified time: 2023-11-20 19:23:47 import numpy as np from numba import jit @@ -773,6 +773,37 @@ def jitremove_nan(time_array, index_nan): ends = time_array[ix_end] return (starts, ends) +@jit(nopython=True) +def jitconvolve(time_array, data_array, starts, ends, array): + time_array, data_array, countin = jitrestrict_with_count( + time_array, data_array, starts, ends + ) + + m = starts.shape[0] + f = data_array.shape[1:] + n = time_array.shape[0] + new_data_array = np.zeros((n, *f), dtype=np.float64) + wsize = array.shape[0] + + k = 0 # epochs count + t = 0 # time points count + i = 0 # window position + + while k < m: + maxt = t + countin[k] + wn = 1 + i = 0 + while t < maxt: + new_data_array[t] = np.sum(array[0:i+wn]*data_array[t-wn+1:t+1]) + + if wn < wsize: + wn += 1 + + t += 1 + + k += 1 + + return new_data_array @jit(nopython=True) def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index e1626245..eeddfc86 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-19 18:59:08 +# @Last Modified time: 2023-11-20 19:35:20 """ @@ -33,6 +33,7 @@ import pandas as pd from numpy.lib.mixins import NDArrayOperatorsMixin from tabulate import tabulate +from scipy import signal from .interval_set import IntervalSet from .jitted_functions import ( @@ -46,6 +47,7 @@ jittsrestrict_with_count, jitvaluefrom, jitvaluefromtensor, + jitconvolve, ) from .time_index import TsIndex @@ -830,6 +832,45 @@ def dropna(self, update_time_support=True): else: return self + def convolve(self, array, ep = None): + """Things to assume : constant sampling rate + + Parameters + ---------- + array : np.ndarray + One dimensional input array + """ + assert isinstance(array, np.ndarray) + if ep is None: + ep = self.time_support + time_array = self.index.values + data_array = self.values + starts = ep.start.values + ends = ep.end.values + + new_data_array = jitconvolve(time_array, data_array, starts, ends, array) + + return self.__class__(time_array, new_data_array, time_support=self.time_support) + + def smooth(self, std, size): + """Smooth with a gaussan kernel + + Parameters + ---------- + std : TYPE + Description + size : TYPE + Description + + Returns + ------- + TYPE + Description + """ + window = signal.windows.gaussian(size, std=std) + window = window/window.sum() + return self.convolve(window) + class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd): """ From 6743124ec4d544040563e67c5f53bd26aea1fe7e Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 21 Nov 2023 18:53:59 -0500 Subject: [PATCH 2/7] Adding fastpltlib demo --- .../examples/tutorial_pynapple_fastplotlib.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) create mode 100644 docs/examples/tutorial_pynapple_fastplotlib.py diff --git a/docs/examples/tutorial_pynapple_fastplotlib.py b/docs/examples/tutorial_pynapple_fastplotlib.py new file mode 100644 index 00000000..250326fb --- /dev/null +++ b/docs/examples/tutorial_pynapple_fastplotlib.py @@ -0,0 +1,83 @@ +# -*- coding: utf-8 -*- +""" +Fastplotlib +=========== + +Working with calcium data. + +For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. + +The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. + +See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. + +This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo. + +""" +# %% +# !!! warning +# This tutorial uses seaborn and matplotlib for displaying the figure +# +# You can install all with `pip install matplotlib seaborn tqdm` +# +# mkdocs_gallery_thumbnail_number = 1 +# +# Now, import the necessary libraries: + +import pynapple as nap +import numpy as np +import fastplotlib as fpl +#from PyQt6 import QtWidgets +import imageio.v3 as iio +import sys +# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' + +nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb") + +units = nwb['units'].getby_category("location")['adn'] + +tmp = units.to_tsd() + +tmp = np.vstack((tmp.index.values, tmp.values)).T + +fplot = fpl.Plot() + +fplot.add_scatter(tmp) + +fplot.graphics[0].cmap = "jet" + +fplot.graphics[0].cmap.values = tmp[:, 1] + +fplot.show() + + +sys.exit() +# %% +# *** + +#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi") +#frames = frames[:,:,:,0] +frames = np.random.randn(10, 100, 100) + +# %% +# *** +app = QtWidgets.QApplication([]) + +# %% +# *** +iw = fpl.ImageWidget(frames, cmap="gnuplot2") + +iw.show() + +imageVar2 = iw.widget.grab(iw.widget.rect()) #returns QPixMap +imageVar2.save("../_static/fastplotlib_demo.png") #again any file name/path and image type possible here + +iw.close() + +app.exec() + + + + + + From 3192d3b30ca806dec0698fd3bc6403a2edb74bb5 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 5 Dec 2023 20:01:24 -0500 Subject: [PATCH 3/7] Adding first version of convolution --- pynapple/core/jitted_functions.py | 93 ++++++++++++++----------------- pynapple/core/time_series.py | 36 ++++++++++-- 2 files changed, 73 insertions(+), 56 deletions(-) diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index a9af4c75..5d1fdf3f 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,9 +2,9 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-20 19:23:47 +# @Last Modified time: 2023-12-05 19:45:00 import numpy as np -from numba import jit +from numba import jit, njit, prange @jit(nopython=True) @@ -773,58 +773,49 @@ def jitremove_nan(time_array, index_nan): ends = time_array[ix_end] return (starts, ends) -@jit(nopython=True) -def jitconvolve(time_array, data_array, starts, ends, array): - time_array, data_array, countin = jitrestrict_with_count( - time_array, data_array, starts, ends - ) - - m = starts.shape[0] - f = data_array.shape[1:] - n = time_array.shape[0] - new_data_array = np.zeros((n, *f), dtype=np.float64) - wsize = array.shape[0] - - k = 0 # epochs count - t = 0 # time points count - i = 0 # window position - - while k < m: - maxt = t + countin[k] - wn = 1 - i = 0 - while t < maxt: - new_data_array[t] = np.sum(array[0:i+wn]*data_array[t-wn+1:t+1]) - - if wn < wsize: - wn += 1 - - t += 1 - - k += 1 +@jit(nopython=True) +def jitconvolve(d, a): + return np.convolve(d, a) + +@njit(parallel=True) +def pjitconvolve(data_array, array, trim='both'): + t,c = data_array.shape + k = array.shape[0] + new_data_array = np.zeros((t,c)) + + if trim=='both': + cut = ((1-k%2)+(k-1)//2, t+k-1-((k-1)//2)) + elif trim=='left': + cut = (k-1,t+k-1) + elif trim=='right': + cut = (0,t) + + for i in prange(c): + new_data_array[:,i] = jitconvolve(data_array[:,i], array)[cut[0]:cut[1]] + return new_data_array -@jit(nopython=True) -def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): - y = y.astype(np.float64) - X = X.astype(np.float64) - n, d = X.shape - W = np.ones(n) - iXtWX = np.linalg.inv(np.dot(X.T * W, X)) - XtWY = np.dot(X.T * W, y) - B = np.dot(iXtWX, XtWY) - - for _ in range(niter): - B_ = B - L = np.exp(X.dot(B)) # Link function - Z = L.reshape((-1, 1)) * X # partial derivatives - delta = np.dot(np.linalg.inv(np.dot(Z.T * W, Z)), np.dot(Z.T * W, y)) - B = B + delta - tol = np.sum(np.abs((B - B_) / B_)) - if tol < tolerance: - return B - return B +# @jit(nopython=True) +# def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): +# y = y.astype(np.float64) +# X = X.astype(np.float64) +# n, d = X.shape +# W = np.ones(n) +# iXtWX = np.linalg.inv(np.dot(X.T * W, X)) +# XtWY = np.dot(X.T * W, y) +# B = np.dot(iXtWX, XtWY) + +# for _ in range(niter): +# B_ = B +# L = np.exp(X.dot(B)) # Link function +# Z = L.reshape((-1, 1)) * X # partial derivatives +# delta = np.dot(np.linalg.inv(np.dot(Z.T * W, Z)), np.dot(Z.T * W, y)) +# B = B + delta +# tol = np.sum(np.abs((B - B_) / B_)) +# if tol < tolerance: +# return B +# return B # @jit(nopython=True) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index eeddfc86..e78b72ab 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-20 19:35:20 +# @Last Modified time: 2023-12-05 19:49:28 """ @@ -48,6 +48,7 @@ jitvaluefrom, jitvaluefromtensor, jitconvolve, + pjitconvolve ) from .time_index import TsIndex @@ -832,7 +833,7 @@ def dropna(self, update_time_support=True): else: return self - def convolve(self, array, ep = None): + def convolve(self, array, ep = None, trim='both'): """Things to assume : constant sampling rate Parameters @@ -843,17 +844,42 @@ def convolve(self, array, ep = None): assert isinstance(array, np.ndarray) if ep is None: ep = self.time_support + time_array = self.index.values data_array = self.values starts = ep.start.values ends = ep.end.values - new_data_array = jitconvolve(time_array, data_array, starts, ends, array) + if data_array.ndim == 1: + new_data_array = np.zeros(data_array.shape) + k = array.shape[0] + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + + t = idx_e - idx_s + if trim=='both': + cut = ((1-k%2)+(k-1)//2, t+k-1-((k-1)//2)) + elif trim=='left': + cut = (k-1,t+k-1) + elif trim=='right': + cut = (0,t) + # scipy is actually faster for Tsd + new_data_array[idx_s:idx_e] = signal.convolve(data_array[idx_s:idx_e], array)[cut[0]:cut[1]] + + return self.__class__(t=time_array, d=new_data_array, time_support=ep) + else: + new_data_array = np.zeros(data_array.shape) + for s, e in zip(starts, ends): + idx_s = np.searchsorted(time_array, s) + idx_e = np.searchsorted(time_array, e, side="right") + new_data_array[idx_s : idx_e] = pjitconvolve(data_array[idx_s:idx_e], array, trim=trim) + + return self.__class__(t=time_array, d=new_data_array, time_support=ep) - return self.__class__(time_array, new_data_array, time_support=self.time_support) def smooth(self, std, size): - """Smooth with a gaussan kernel + """Smooth with a gaussian kernel Parameters ---------- From 4f96d6a341d6d5fc90129c2823b55ede858fffb0 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Wed, 6 Dec 2023 16:39:48 -0500 Subject: [PATCH 4/7] Adding working convolution and smoothing --- ...b.py => _tutorial_pynapple_fastplotlib.py} | 0 pynapple/core/jitted_functions.py | 37 +++++---- pynapple/core/time_series.py | 81 ++++++++++++------- 3 files changed, 74 insertions(+), 44 deletions(-) rename docs/examples/{tutorial_pynapple_fastplotlib.py => _tutorial_pynapple_fastplotlib.py} (100%) diff --git a/docs/examples/tutorial_pynapple_fastplotlib.py b/docs/examples/_tutorial_pynapple_fastplotlib.py similarity index 100% rename from docs/examples/tutorial_pynapple_fastplotlib.py rename to docs/examples/_tutorial_pynapple_fastplotlib.py diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index 5d1fdf3f..3534d5c0 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-05 19:45:00 +# @Last Modified time: 2023-12-06 16:03:51 import numpy as np from numba import jit, njit, prange @@ -778,24 +778,31 @@ def jitremove_nan(time_array, index_nan): def jitconvolve(d, a): return np.convolve(d, a) + @njit(parallel=True) -def pjitconvolve(data_array, array, trim='both'): - t,c = data_array.shape +def pjitconvolve(data_array, array, trim="both"): + shape = data_array.shape + t = shape[0] k = array.shape[0] - new_data_array = np.zeros((t,c)) - - if trim=='both': - cut = ((1-k%2)+(k-1)//2, t+k-1-((k-1)//2)) - elif trim=='left': - cut = (k-1,t+k-1) - elif trim=='right': - cut = (0,t) - - for i in prange(c): - new_data_array[:,i] = jitconvolve(data_array[:,i], array)[cut[0]:cut[1]] - + + data_array = data_array.reshape(t, -1) + new_data_array = np.zeros(shape) + + if trim == "both": + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + elif trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) + + for i in prange(data_array.shape[1]): + new_data_array[:, i] = jitconvolve(data_array[:, i], array)[cut[0] : cut[1]] + + new_data_array = new_data_array.reshape(shape) + return new_data_array + # @jit(nopython=True) # def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): # y = y.astype(np.float64) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index e78b72ab..be1a8d09 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-05 19:49:28 +# @Last Modified time: 2023-12-06 16:32:49 """ @@ -32,13 +32,13 @@ import numpy as np import pandas as pd from numpy.lib.mixins import NDArrayOperatorsMixin -from tabulate import tabulate from scipy import signal +from tabulate import tabulate from .interval_set import IntervalSet from .jitted_functions import ( jitbin, - jitbin_array, + jitbin_array, jitcount, jitremove_nan, jitrestrict, @@ -47,8 +47,7 @@ jittsrestrict_with_count, jitvaluefrom, jitvaluefromtensor, - jitconvolve, - pjitconvolve + pjitconvolve, ) from .time_index import TsIndex @@ -833,19 +832,38 @@ def dropna(self, update_time_support=True): else: return self - def convolve(self, array, ep = None, trim='both'): - """Things to assume : constant sampling rate - + def convolve(self, array, ep=None, trim="both"): + """Return the discrete linear convolution of the time series with a one dimensional sequence. + + A parameter ep can control the epochs for which the convolution will apply. Otherwise the convolution is made over the time support. + + This function assume a constant sampling rate of the time series. + + The only mode supported is full. The returned object is trimmed to match the size of the original object. The parameter trim controls which side the trimming operates. Default is 'both'. + + See the numpy documentation here : https://numpy.org/doc/stable/reference/generated/numpy.convolve.html + Parameters ---------- array : np.ndarray One dimensional input array + ep : None, optional + The epochs to apply the convolution + trim : str, optional + The side on which to trim the output of the convolution ('left', 'right', 'both' [default]) + + Returns + ------- + Tsd, TsdFrame or TsdTensor + The convolved time series """ assert isinstance(array, np.ndarray) + assert array.ndim == 1, "Input should be a one dimensional array." + if ep is None: ep = self.time_support - time_array = self.index.values + time_array = self.index.values data_array = self.values starts = ep.start.values ends = ep.end.values @@ -857,44 +875,49 @@ def convolve(self, array, ep = None, trim='both'): idx_s = np.searchsorted(time_array, s) idx_e = np.searchsorted(time_array, e, side="right") - t = idx_e - idx_s - if trim=='both': - cut = ((1-k%2)+(k-1)//2, t+k-1-((k-1)//2)) - elif trim=='left': - cut = (k-1,t+k-1) - elif trim=='right': - cut = (0,t) + t = idx_e - idx_s + if trim == "both": + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) + elif trim == "left": + cut = (k - 1, t + k - 1) + elif trim == "right": + cut = (0, t) # scipy is actually faster for Tsd - new_data_array[idx_s:idx_e] = signal.convolve(data_array[idx_s:idx_e], array)[cut[0]:cut[1]] + new_data_array[idx_s:idx_e] = signal.convolve( + data_array[idx_s:idx_e], array + )[cut[0] : cut[1]] - return self.__class__(t=time_array, d=new_data_array, time_support=ep) + return self.__class__(t=time_array, d=new_data_array, time_support=ep) else: new_data_array = np.zeros(data_array.shape) for s, e in zip(starts, ends): idx_s = np.searchsorted(time_array, s) idx_e = np.searchsorted(time_array, e, side="right") - new_data_array[idx_s : idx_e] = pjitconvolve(data_array[idx_s:idx_e], array, trim=trim) + new_data_array[idx_s:idx_e] = pjitconvolve( + data_array[idx_s:idx_e], array, trim=trim + ) return self.__class__(t=time_array, d=new_data_array, time_support=ep) - def smooth(self, std, size): - """Smooth with a gaussian kernel - + """Smooth a time series with a gaussian kernel. std is the standard deviation and size is the number of point of the window. + + See the scipy documentation : https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.gaussian.html + Parameters ---------- - std : TYPE - Description - size : TYPE + std : int + Standard deviation + size : int Description - + Returns ------- - TYPE - Description + Tsd, TsdFrame, TsdTensor + Time series convolved with a gaussian kernel """ window = signal.windows.gaussian(size, std=std) - window = window/window.sum() + window = window / window.sum() return self.convolve(window) From cdcbc154661a48af1807d03b29575356b5661aa1 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Thu, 7 Dec 2023 14:03:48 -0500 Subject: [PATCH 5/7] Adding tests for convolution and smoothing --- pynapple/core/jitted_functions.py | 4 +- pynapple/core/time_series.py | 13 ++++-- tests/test_time_series.py | 76 ++++++++++++++++++++++++++++++- 3 files changed, 85 insertions(+), 8 deletions(-) diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index 3534d5c0..fa719822 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-06 16:03:51 +# @Last Modified time: 2023-12-06 17:51:35 import numpy as np from numba import jit, njit, prange @@ -786,7 +786,7 @@ def pjitconvolve(data_array, array, trim="both"): k = array.shape[0] data_array = data_array.reshape(t, -1) - new_data_array = np.zeros(shape) + new_data_array = np.zeros(data_array.shape) if trim == "both": cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index be1a8d09..72a4543e 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-27 18:33:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-06 16:32:49 +# @Last Modified time: 2023-12-07 13:58:06 """ @@ -857,8 +857,9 @@ def convolve(self, array, ep=None, trim="both"): Tsd, TsdFrame or TsdTensor The convolved time series """ - assert isinstance(array, np.ndarray) + assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." assert array.ndim == 1, "Input should be a one dimensional array." + assert trim in ['both', 'left', 'right'], "Unknow argument. trim should be 'both', 'left' or 'right'." if ep is None: ep = self.time_support @@ -876,12 +877,12 @@ def convolve(self, array, ep=None, trim="both"): idx_e = np.searchsorted(time_array, e, side="right") t = idx_e - idx_s - if trim == "both": - cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) - elif trim == "left": + if trim == "left": cut = (k - 1, t + k - 1) elif trim == "right": cut = (0, t) + else: + cut = ((1 - k % 2) + (k - 1) // 2, t + k - 1 - ((k - 1) // 2)) # scipy is actually faster for Tsd new_data_array[idx_s:idx_e] = signal.convolve( data_array[idx_s:idx_e], array @@ -916,6 +917,8 @@ def smooth(self, std, size): Tsd, TsdFrame, TsdTensor Time series convolved with a gaussian kernel """ + assert isinstance(std, int), "std should be type int" + assert isinstance(size, int), "size should be type int" window = signal.windows.gaussian(size, std=std) window = window / window.sum() return self.convolve(window) diff --git a/tests/test_time_series.py b/tests/test_time_series.py index 449cbfa7..5d267c06 100755 --- a/tests/test_time_series.py +++ b/tests/test_time_series.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-04-01 09:57:55 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-19 18:48:57 +# @Last Modified time: 2023-12-07 14:01:09 #!/usr/bin/env python """Tests of time series for `pynapple` package.""" @@ -421,6 +421,80 @@ def test_dropna(self, tsd): assert len(new_tsd) == 0 assert len(new_tsd.time_support) == 0 + def test_convolve(self, tsd): + array = np.random.randn(10) + if not isinstance(tsd, nap.Ts): + tsd2 = tsd.convolve(array) + tmp = tsd.values.reshape(tsd.shape[0], -1) + tmp2 = np.zeros_like(tmp) + for i in range(tmp.shape[-1]): + tmp2[:,i] = np.convolve(tmp[:,i], array, mode='full')[5:-4] + np.testing.assert_array_almost_equal( + tmp2, + tsd2.values.reshape(tsd2.shape[0], -1) + ) + + with pytest.raises(AssertionError) as e_info: + tsd.convolve([1,2,3]) + assert str(e_info.value) == "Input should be a 1-d numpy array." + + with pytest.raises(AssertionError) as e_info: + tsd.convolve(np.random.rand(2,3)) + assert str(e_info.value) == "Input should be a one dimensional array." + + ep = nap.IntervalSet(start=[0, 60], end=[40,100]) + tsd3 = tsd.convolve(array, ep) + + for i in range(len(ep)): + tmp2 = tsd.restrict(ep.loc[[i]]).values + tmp2 = tmp2.reshape(tmp2.shape[0], -1) + for j in range(tmp2.shape[-1]): + tmp2[:,j] = np.convolve(tmp2[:,j], array, mode='full')[5:-4] + np.testing.assert_array_almost_equal( + tmp2, + tsd3.restrict(ep.loc[[i]]).values.reshape(tmp2.shape[0], -1) + ) + + # Trim + for trim, sl in zip(['left', 'both', 'right'], [slice(9,None),slice(5,-4),slice(None,-9)]): + tsd2 = tsd.convolve(array, trim=trim) + tmp = tsd.values.reshape(tsd.shape[0], -1) + tmp2 = np.zeros_like(tmp) + for i in range(tmp.shape[-1]): + tmp2[:,i] = np.convolve(tmp[:,i], array, mode='full')[sl] + np.testing.assert_array_almost_equal( + tmp2, + tsd2.values.reshape(tsd2.shape[0], -1) + ) + + with pytest.raises(AssertionError) as e_info: + tsd.convolve(array, trim='a') + assert str(e_info.value) == "Unknow argument. trim should be 'both', 'left' or 'right'." + + def test_smooth(self, tsd): + if not isinstance(tsd, nap.Ts): + from scipy import signal + tsd2 = tsd.smooth(2, 20) + tmp = tsd.values.reshape(tsd.shape[0], -1) + tmp2 = np.zeros_like(tmp) + window = signal.windows.gaussian(20, std=2) + window = window / window.sum() + for i in range(tmp.shape[-1]): + tmp2[:,i] = np.convolve(tmp[:,i], window, mode='full')[10:-9] + np.testing.assert_array_almost_equal( + tmp2, + tsd2.values.reshape(tsd2.shape[0], -1) + ) + + with pytest.raises(AssertionError) as e_info: + tsd.smooth('a', 20) + assert str(e_info.value) == "std should be type int" + + with pytest.raises(AssertionError) as e_info: + tsd.smooth(2, 'b') + assert str(e_info.value) == "size should be type int" + + #################################################### # Test for tsd From 5f3b322f44335613cb1b2abeb793b3beac220406 Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Mon, 11 Dec 2023 18:20:28 -0500 Subject: [PATCH 6/7] Adding working version of perievent continuous --- pynapple/core/jitted_functions.py | 112 +++++++++++++++++++++++++++++- pynapple/core/time_series.py | 8 ++- pynapple/process/perievent.py | 79 ++++++++++++++++++++- tests/test_perievent.py | 11 ++- 4 files changed, 203 insertions(+), 7 deletions(-) diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index fa719822..5604d9bc 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-06 17:51:35 +# @Last Modified time: 2023-12-11 18:05:36 import numpy as np from numba import jit, njit, prange @@ -803,6 +803,116 @@ def pjitconvolve(data_array, array, trim="both"): return new_data_array +@njit(parallel=True) +def jitcontinuous_perievent( + time_array, data_array, time_target_array, starts, ends, windowsize +): + N_samples = len(time_array) + N_target = len(time_target_array) + N_epochs = len(starts) + count = np.zeros((N_epochs, 2), dtype=np.int64) + start_t = np.zeros((N_epochs, 2), dtype=np.int64) + + k = 0 # Epochs + t = 0 # Samples + i = 0 # Target + + while ends[k] < time_array[t] and ends[k] < time_target_array[i]: + k += 1 + + while k < N_epochs: + # Outside + while t < N_samples: + if time_array[t] >= starts[k]: + break + t += 1 + + while i < N_target: + if time_target_array[i] >= starts[k]: + break + i += 1 + + if time_array[t] <= ends[k]: + start_t[k, 0] = t + + if time_target_array[i] <= ends[k]: + start_t[k, 1] = i + + # Inside + while t < N_samples: + if time_array[t] > ends[k]: + break + else: + count[k, 0] += 1 + t += 1 + + while i < N_target: + if time_target_array[i] > ends[k]: + break + else: + count[k, 1] += 1 + i += 1 + + k += 1 + + if k == N_epochs: + break + if t == N_samples: + break + if i == N_target: + break + + new_data_array = np.full( + (np.sum(windowsize) + 1, np.sum(count[:, 1]), *data_array.shape[1:]), np.nan + ) + + if np.all((count[:, 0] * count[:, 1]) > 0): + for k in prange(N_epochs): + if count[k, 0] > 0 and count[k, 1] > 0: + t = start_t[k, 0] + i = start_t[k, 1] + maxt = t + count[k, 0] + maxi = i + count[k, 1] + cnt_i = np.sum(count[0:k, 1]) + + while i < maxi: + interval = abs(time_array[t] - time_target_array[i]) + t_pos = t + t += 1 + while t < maxt: + new_interval = abs(time_array[t] - time_target_array[i]) + if new_interval > interval: + break + else: + interval = new_interval + t_pos = t + t += 1 + + left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) + right = np.minimum(windowsize[1], maxt - t_pos - 1) + center = new_data_array.shape[0] // 2 + 1 + new_data_array[ + center - left - 1 : center + right, cnt_i + ] = data_array[t_pos - left : t_pos + right + 1] + + t -= 1 + i += 1 + cnt_i += 1 + + return new_data_array + + +# time_array = tsd.t +# time_target_array = tref.t +# data_array = tsd.d + +# for i,t in enumerate(tref.restrict(ep).t): +# plot(time_idx + t, new_data_array[:,i]+i*2.0, 'o') +# plot(tsd + i*2.0, color='grey') +# [axvspan(ep.loc[i,'start'], ep.loc[i,'end'], alpha=0.3) for i in range(len(ep))] +# [axvline(t) for t in tref.restrict(ep).t] + + # @jit(nopython=True) # def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5): # y = y.astype(np.float64) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 72a4543e..51e11fc3 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -38,7 +38,7 @@ from .interval_set import IntervalSet from .jitted_functions import ( jitbin, - jitbin_array, + jitbin_array, jitcount, jitremove_nan, jitrestrict, @@ -859,7 +859,11 @@ def convolve(self, array, ep=None, trim="both"): """ assert isinstance(array, np.ndarray), "Input should be a 1-d numpy array." assert array.ndim == 1, "Input should be a one dimensional array." - assert trim in ['both', 'left', 'right'], "Unknow argument. trim should be 'both', 'left' or 'right'." + assert trim in [ + "both", + "left", + "right", + ], "Unknow argument. trim should be 'both', 'left' or 'right'." if ep is None: ep = self.time_support diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index 8492cdb7..f10dec8d 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-30 22:59:00 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-20 12:08:15 +# @Last Modified time: 2023-12-11 17:41:30 import numpy as np from scipy.linalg import hankel @@ -82,7 +82,7 @@ def compute_perievent(data, tref, minmax, time_unit="s"): if tref is not a Ts/Tsd object or if data is not a Ts/Tsd or TsGroup """ if not isinstance(tref, (nap.Ts, nap.Tsd)): - raise RuntimeError("tref should be a Tsd object.") + raise RuntimeError("tref should be a Ts or Tsd object.") if isinstance(minmax, float) or isinstance(minmax, int): minmax = np.array([minmax, minmax], dtype=np.float64) @@ -106,6 +106,81 @@ def compute_perievent(data, tref, minmax, time_unit="s"): raise RuntimeError("Unknown format for data") +def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): + """ + Center contiunous time series around the timestamps given by the tref argument. + minmax indicates the start and end of the window. + + If the input is a n dimensional time series, it returns a n+1 dimensional time series. + + This function assumes a constant sampling rate of the time series. + + Parameters + ---------- + data : Tsd, TsdFrame or TsdTensor + The data to align to tref. + tref : Ts/Tsd + The timestamps of the event to align to + minmax : tuple or int or float + The window size. Can be unequal on each side i.e. (-500, 1000). + time_unit : str, optional + Time units of the minmax ('s' [default], 'ms', 'us'). + + Returns + ------- + Tsd, TsdFrame, TsdTensor + + + Raises + ------ + RuntimeError + if tref is not a Ts/Tsd object or if data is not a Ts/Tsd/Tensor object. + """ + + assert isinstance(tref, (nap.Ts, nap.Tsd)), "tref should be a Ts or Tsd object." + assert isinstance( + data, (nap.Tsd, nap.TsdFrame, nap.TsdTensor) + ), "data should be a Tsd, TsdFrame or TsdTensor." + assert isinstance( + minmax, (float, int, tuple) + ), "minmax should be a tuple or int or float." + assert isinstance(time_unit, str), "time_unit should be a str." + assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us')" + + if ep is None: + ep = data.time_support + else: + assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object." + + if isinstance(minmax, float) or isinstance(minmax, int): + minmax = np.array([minmax, minmax], dtype=np.float64) + + window = np.abs(nap.TsIndex.format_timestamps(np.array(minmax), time_unit)) + + time_array = data.index.values + data_array = data.values + time_target_array = tref.index.values + starts = ep.start.values + ends = ep.end.values + + binsize = time_array[1] - time_array[0] + idx1 = -np.arange(0, window[0] + binsize, binsize)[::-1][:-1] + idx2 = np.arange(0, window[1] + binsize, binsize)[1:] + time_idx = np.hstack((idx1, np.zeros(1), idx2)) + windowsize = np.array([idx1.shape[0], idx2.shape[0]]) + + new_data_array = nap.jitted_functions.jitcontinuous_perievent( + time_array, data_array, time_target_array, starts, ends, windowsize + ) + + time_support = nap.IntervalSet(start=-window[0], end=window[1]) + + if new_data_array.ndim == 2: + return nap.TsdFrame(t=time_idx, d=new_data_array, time_support=time_support) + else: + return nap.TsdTensor(t=time_idx, d=new_data_array, time_support=time_support) + + def compute_event_trigger_average( group, feature, binsize, windowsize, ep, time_units="s" ): diff --git a/tests/test_perievent.py b/tests/test_perievent.py index efe2fd3b..1a75daad 100644 --- a/tests/test_perievent.py +++ b/tests/test_perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:53 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-09-18 16:04:02 +# @Last Modified time: 2023-12-11 18:16:58 #!/usr/bin/env python """Tests of perievent for `pynapple` package.""" @@ -23,6 +23,13 @@ def test_align_tsd(): for i, j in zip(peth.keys(), np.arange(0, 100, 10)): np.testing.assert_array_almost_equal(peth[i].index, np.arange(-10, 10)) +def test_compute_perievent_continuous(): + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.arange(10, 100, 10)) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=(-10, 10)) + + assert isinstance(pe, nap.TsdFrame) + assert pe.shape[1] == len(tref) def test_compute_perievent_with_tsd(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) @@ -57,7 +64,7 @@ def test_compute_perievent_raise_error(): tref = np.arange(10, 100, 10) with pytest.raises(Exception) as e_info: nap.compute_perievent(tsd, tref, minmax=(-10, 10)) - assert str(e_info.value) == "tref should be a Tsd object." + assert str(e_info.value) == "tref should be a Ts or Tsd object." tsd = t = np.arange(100) tref = nap.Ts(t=np.arange(10, 100, 10)) with pytest.raises(Exception) as e_info: From 0fafa017aefa61c30445b8d259b141ec979e408e Mon Sep 17 00:00:00 2001 From: Guillaume Viejo Date: Tue, 12 Dec 2023 18:42:29 -0500 Subject: [PATCH 7/7] Bumping v0.5 --- README.md | 4 +- docs/HISTORY.md | 13 +- .../_tutorial_pynapple_fastplotlib.py | 83 --------- docs/examples/tutorial_HD_dataset.py | 0 docs/index.md | 4 +- docs/old_pages/core.interval_set.md | 13 -- docs/old_pages/core.time_series.md | 15 -- docs/old_pages/core.ts_group.md | 13 -- docs/old_pages/io.cnmfe.md | 13 -- docs/old_pages/io.folder.md | 12 -- docs/old_pages/io.loader.md | 13 -- docs/old_pages/io.md | 11 -- docs/old_pages/io.neurosuite.md | 13 -- docs/old_pages/io.npz.md | 11 -- docs/old_pages/io.nwb.md | 11 -- docs/old_pages/io.phy.md | 12 -- docs/old_pages/io.suite2p.md | 11 -- docs/old_pages/process.correlograms.md | 11 -- docs/old_pages/process.decoding.md | 11 -- docs/old_pages/process.perievent.md | 11 -- docs/old_pages/process.randomize.md | 11 -- docs/old_pages/process.tuning_curves.md | 11 -- pynapple/__init__.py | 2 +- pynapple/core/jitted_functions.py | 4 +- pynapple/process/perievent.py | 94 ++++++----- pyproject.toml | 2 +- setup.py | 2 +- test_dandi.py | 52 ------ test_ndarray.py | 53 ------ test_new_io.py | 37 ---- tests/test_perievent.py | 158 ++++++++++++++++-- tests/test_spike_trigger_average.py | 32 +++- 32 files changed, 241 insertions(+), 502 deletions(-) delete mode 100644 docs/examples/_tutorial_pynapple_fastplotlib.py mode change 100755 => 100644 docs/examples/tutorial_HD_dataset.py delete mode 100644 docs/old_pages/core.interval_set.md delete mode 100644 docs/old_pages/core.time_series.md delete mode 100644 docs/old_pages/core.ts_group.md delete mode 100644 docs/old_pages/io.cnmfe.md delete mode 100644 docs/old_pages/io.folder.md delete mode 100644 docs/old_pages/io.loader.md delete mode 100644 docs/old_pages/io.md delete mode 100644 docs/old_pages/io.neurosuite.md delete mode 100644 docs/old_pages/io.npz.md delete mode 100644 docs/old_pages/io.nwb.md delete mode 100644 docs/old_pages/io.phy.md delete mode 100644 docs/old_pages/io.suite2p.md delete mode 100644 docs/old_pages/process.correlograms.md delete mode 100644 docs/old_pages/process.decoding.md delete mode 100644 docs/old_pages/process.perievent.md delete mode 100644 docs/old_pages/process.randomize.md delete mode 100644 docs/old_pages/process.tuning_curves.md delete mode 100644 test_dandi.py delete mode 100644 test_ndarray.py delete mode 100644 test_new_io.py diff --git a/README.md b/README.md index 0bc976d2..ef22693c 100644 --- a/README.md +++ b/README.md @@ -85,7 +85,7 @@ This procedure will install all the dependencies including - tabulate - h5py -For spyder users, it is recommended to install spyder after installing pynapple with : + Basic Usage diff --git a/docs/HISTORY.md b/docs/HISTORY.md index f02efbf3..9e5961d8 100644 --- a/docs/HISTORY.md +++ b/docs/HISTORY.md @@ -9,10 +9,13 @@ In 2018, Francesco started neuroseries, a Python package built on Pandas. It was In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuroseries and started *pynapple*. The core of pynapple is largely built upon neuroseries. Some of the original changes to TSToolbox made by Luke were included in this package, especially the *time_support* property of all ts/tsd objects. -0.4.2 (2023-11-16) +0.5.0 (2023-12-12) ------------------ -- Removing GUI stack from pynapple. +- Removing GUI stack from pynapple. To create a NWB file, users need to install nwbmatic (https://github.com/pynapple-org/nwbmatic) +- Implementing `compute_perievent_continuous` +- Implementing `convolve` for Tsd, TsdFrame and TsdTensor +- Implementing `smooth` for fast gaussian smoothing of time series 0.4.1 (2023-10-30) @@ -48,10 +51,10 @@ In 2021, Guillaume and other trainees in Adrien's lab decided to fork from neuro 0.3.4 (2023-06-29) ------------------ -- TsGroup.to_tsd and Tsd.to_tsgroup transformations -- Count can take IntervalSet +- `TsGroup.to_tsd` and `Tsd.to_tsgroup` transformations +- `count` can take IntervalSet - Saving to npz functions for all objects. -- tsd.value_from can take TsdFrame +- `tsd.value_from` can take TsdFrame - Warning message for deprecating current IO. diff --git a/docs/examples/_tutorial_pynapple_fastplotlib.py b/docs/examples/_tutorial_pynapple_fastplotlib.py deleted file mode 100644 index 250326fb..00000000 --- a/docs/examples/_tutorial_pynapple_fastplotlib.py +++ /dev/null @@ -1,83 +0,0 @@ -# -*- coding: utf-8 -*- -""" -Fastplotlib -=========== - -Working with calcium data. - -For the example dataset, we will be working with a recording of a freely-moving mouse imaged with a Miniscope (1-photon imaging). The area recorded for this experiment is the postsubiculum - a region that is known to contain head-direction cells, or cells that fire when the animal's head is pointing in a specific direction. - -The NWB file for the example is hosted on [OSF](https://osf.io/sbnaw). We show below how to stream it. - -See the [documentation](https://pynapple-org.github.io/pynapple/) of Pynapple for instructions on installing the package. - -This tutorial was made by Sofia Skromne Carrasco and Guillaume Viejo. - -""" -# %% -# !!! warning -# This tutorial uses seaborn and matplotlib for displaying the figure -# -# You can install all with `pip install matplotlib seaborn tqdm` -# -# mkdocs_gallery_thumbnail_number = 1 -# -# Now, import the necessary libraries: - -import pynapple as nap -import numpy as np -import fastplotlib as fpl -#from PyQt6 import QtWidgets -import imageio.v3 as iio -import sys -# mkdocs_gallery_thumbnail_path = '../_static/fastplotlib_demo.png' - -nwb = nap.load_file("/Users/gviejo/pynapple/Mouse32-220101.nwb") - -units = nwb['units'].getby_category("location")['adn'] - -tmp = units.to_tsd() - -tmp = np.vstack((tmp.index.values, tmp.values)).T - -fplot = fpl.Plot() - -fplot.add_scatter(tmp) - -fplot.graphics[0].cmap = "jet" - -fplot.graphics[0].cmap.values = tmp[:, 1] - -fplot.show() - - -sys.exit() -# %% -# *** - -#frames = iio.imread("/Users/gviejo/pynapple/A0670-221213_filtered.avi") -#frames = frames[:,:,:,0] -frames = np.random.randn(10, 100, 100) - -# %% -# *** -app = QtWidgets.QApplication([]) - -# %% -# *** -iw = fpl.ImageWidget(frames, cmap="gnuplot2") - -iw.show() - -imageVar2 = iw.widget.grab(iw.widget.rect()) #returns QPixMap -imageVar2.save("../_static/fastplotlib_demo.png") #again any file name/path and image type possible here - -iw.close() - -app.exec() - - - - - - diff --git a/docs/examples/tutorial_HD_dataset.py b/docs/examples/tutorial_HD_dataset.py old mode 100755 new mode 100644 diff --git a/docs/index.md b/docs/index.md index a07b494b..52f2a840 100644 --- a/docs/index.md +++ b/docs/index.md @@ -81,7 +81,7 @@ This procedure will install all the dependencies including - tabulate - h5py -For spyder users, it is recommended to install spyder after installing pynapple with : + Basic Usage ----------- diff --git a/docs/old_pages/core.interval_set.md b/docs/old_pages/core.interval_set.md deleted file mode 100644 index 932e023e..00000000 --- a/docs/old_pages/core.interval_set.md +++ /dev/null @@ -1,13 +0,0 @@ -::: pynapple.core.interval_set - handler: python - selection: - docstring_style: numpy - members: - - IntervalSet - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/core.time_series.md b/docs/old_pages/core.time_series.md deleted file mode 100644 index ffd1e579..00000000 --- a/docs/old_pages/core.time_series.md +++ /dev/null @@ -1,15 +0,0 @@ -::: pynapple.core.time_series - handler: python - selection: - docstring_style: numpy - members: - - Tsd - - Ts - - TsdFrame - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/core.ts_group.md b/docs/old_pages/core.ts_group.md deleted file mode 100644 index e410a8ab..00000000 --- a/docs/old_pages/core.ts_group.md +++ /dev/null @@ -1,13 +0,0 @@ -::: pynapple.core.ts_group - handler: python - selection: - docstring_style: numpy - members: - - TsGroup - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/io.cnmfe.md b/docs/old_pages/io.cnmfe.md deleted file mode 100644 index c8305673..00000000 --- a/docs/old_pages/io.cnmfe.md +++ /dev/null @@ -1,13 +0,0 @@ -::: pynapple.io.cnmfe - handler: python - selection: - docstring_style: numpy - members: - - CNMF_E - - Minian - - InscopixCNMFE - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source diff --git a/docs/old_pages/io.folder.md b/docs/old_pages/io.folder.md deleted file mode 100644 index 438d227b..00000000 --- a/docs/old_pages/io.folder.md +++ /dev/null @@ -1,12 +0,0 @@ -::: pynapple.io.folder - handler: python - selection: - docstring_style: numpy - members: - - Folder - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - \ No newline at end of file diff --git a/docs/old_pages/io.loader.md b/docs/old_pages/io.loader.md deleted file mode 100644 index 1e45a7de..00000000 --- a/docs/old_pages/io.loader.md +++ /dev/null @@ -1,13 +0,0 @@ -::: pynapple.io.loader - handler: python - selection: - docstring_style: numpy - members: - - BaseLoader - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/io.md b/docs/old_pages/io.md deleted file mode 100644 index 71dc3085..00000000 --- a/docs/old_pages/io.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.io.misc - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/io.neurosuite.md b/docs/old_pages/io.neurosuite.md deleted file mode 100644 index fbdedf24..00000000 --- a/docs/old_pages/io.neurosuite.md +++ /dev/null @@ -1,13 +0,0 @@ -::: pynapple.io.neurosuite - handler: python - selection: - docstring_style: numpy - members: - - NeuroSuite - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/io.npz.md b/docs/old_pages/io.npz.md deleted file mode 100644 index 736f52c8..00000000 --- a/docs/old_pages/io.npz.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.io.interface_npz - handler: python - selection: - docstring_style: numpy - members: - - NPZFile - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source \ No newline at end of file diff --git a/docs/old_pages/io.nwb.md b/docs/old_pages/io.nwb.md deleted file mode 100644 index fe719f13..00000000 --- a/docs/old_pages/io.nwb.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.io.interface_nwb - handler: python - selection: - docstring_style: numpy - members: - - NWBFile - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source \ No newline at end of file diff --git a/docs/old_pages/io.phy.md b/docs/old_pages/io.phy.md deleted file mode 100644 index 0352ef12..00000000 --- a/docs/old_pages/io.phy.md +++ /dev/null @@ -1,12 +0,0 @@ -::: pynapple.io.phy - handler: python - selection: - docstring_style: numpy - members: - - Phy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - \ No newline at end of file diff --git a/docs/old_pages/io.suite2p.md b/docs/old_pages/io.suite2p.md deleted file mode 100644 index 1849dbbf..00000000 --- a/docs/old_pages/io.suite2p.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.io.suite2p - handler: python - selection: - docstring_style: numpy - members: - - Suite2P - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source diff --git a/docs/old_pages/process.correlograms.md b/docs/old_pages/process.correlograms.md deleted file mode 100644 index 21b32034..00000000 --- a/docs/old_pages/process.correlograms.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.process.correlograms - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/process.decoding.md b/docs/old_pages/process.decoding.md deleted file mode 100644 index 247d8399..00000000 --- a/docs/old_pages/process.decoding.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.process.decoding - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/process.perievent.md b/docs/old_pages/process.perievent.md deleted file mode 100644 index eb7da3a3..00000000 --- a/docs/old_pages/process.perievent.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.process.perievent - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/process.randomize.md b/docs/old_pages/process.randomize.md deleted file mode 100644 index 6c8f3467..00000000 --- a/docs/old_pages/process.randomize.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.process.randomize - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/docs/old_pages/process.tuning_curves.md b/docs/old_pages/process.tuning_curves.md deleted file mode 100644 index 4d943a3b..00000000 --- a/docs/old_pages/process.tuning_curves.md +++ /dev/null @@ -1,11 +0,0 @@ -::: pynapple.process.tuning_curves - handler: python - selection: - docstring_style: numpy - rendering: - show_root_heading: false - show_source: true - show_category_heading: false - members_order: source - - diff --git a/pynapple/__init__.py b/pynapple/__init__.py index ed30dc54..e26a04cd 100644 --- a/pynapple/__init__.py +++ b/pynapple/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1" +__version__ = "0.5.0" from .core import * from .io import * from .process import * diff --git a/pynapple/core/jitted_functions.py b/pynapple/core/jitted_functions.py index 5604d9bc..340451c5 100644 --- a/pynapple/core/jitted_functions.py +++ b/pynapple/core/jitted_functions.py @@ -2,7 +2,7 @@ # @Author: guillaume # @Date: 2022-10-31 16:44:31 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-11 18:05:36 +# @Last Modified time: 2023-12-12 16:50:36 import numpy as np from numba import jit, njit, prange @@ -890,7 +890,7 @@ def jitcontinuous_perievent( left = np.minimum(windowsize[0], t_pos - start_t[k, 0]) right = np.minimum(windowsize[1], maxt - t_pos - 1) - center = new_data_array.shape[0] // 2 + 1 + center = windowsize[0] + 1 new_data_array[ center - left - 1 : center + right, cnt_i ] = data_array[t_pos - left : t_pos + right + 1] diff --git a/pynapple/process/perievent.py b/pynapple/process/perievent.py index f10dec8d..d4a3cf95 100644 --- a/pynapple/process/perievent.py +++ b/pynapple/process/perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-01-30 22:59:00 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-11 17:41:30 +# @Last Modified time: 2023-12-12 18:17:42 import numpy as np from scipy.linalg import hankel @@ -54,18 +54,21 @@ def _align_tsd(tsd, tref, window, time_support): def compute_perievent(data, tref, minmax, time_unit="s"): """ - Center ts/tsd/tsgroup object around the timestamps given by the tref argument. - minmax indicates the start and end of the window. + Center the timestamps of a time series object or a time series group around the timestamps given by the `tref` argument. + `minmax` indicates the start and end of the window. If `minmax=(-5, 10)`, the window will be from -5 second to 10 second. + If `minmax=10`, the window will be from -10 second to 10 second. + + To center continuous time series around a set of timestamps, you can use `compute_perievent_continuous`. Parameters ---------- - data : Ts/Tsd/TsGroup + data : Ts, Tsd or TsGroup The data to align to tref. If Ts/Tsd, returns a TsGroup. If TsGroup, returns a dictionnary of TsGroup - tref : Ts/Tsd + tref : Ts or Tsd The timestamps of the event to align to - minmax : tuple or int or float + minmax : tuple, int or float The window size. Can be unequal on each side i.e. (-500, 1000). time_unit : str, optional Time units of the minmax ('s' [default], 'ms', 'us'). @@ -81,8 +84,15 @@ def compute_perievent(data, tref, minmax, time_unit="s"): RuntimeError if tref is not a Ts/Tsd object or if data is not a Ts/Tsd or TsGroup """ - if not isinstance(tref, (nap.Ts, nap.Tsd)): - raise RuntimeError("tref should be a Ts or Tsd object.") + assert isinstance(tref, (nap.Ts, nap.Tsd)), "tref should be a Ts or Tsd object." + assert isinstance( + data, (nap.Ts, nap.Tsd, nap.TsGroup) + ), "data should be a Ts, Tsd or TsGroup." + assert isinstance( + minmax, (float, int, tuple) + ), "minmax should be a tuple or int or float." + assert isinstance(time_unit, str), "time_unit should be a str." + assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'" if isinstance(minmax, float) or isinstance(minmax, int): minmax = np.array([minmax, minmax], dtype=np.float64) @@ -99,19 +109,17 @@ def compute_perievent(data, tref, minmax, time_unit="s"): return toreturn - elif isinstance(data, (nap.Ts, nap.Tsd)): - return _align_tsd(data, tref, window, time_support) - else: - raise RuntimeError("Unknown format for data") + return _align_tsd(data, tref, window, time_support) def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): """ - Center contiunous time series around the timestamps given by the tref argument. - minmax indicates the start and end of the window. + Center continuous time series around the timestamps given by the 'tref' argument. + `minmax` indicates the start and end of the window. If `minmax=(-5, 10)`, the window will be from -5 second to 10 second. + If `minmax=10`, the window will be from -10 second to 10 second. - If the input is a n dimensional time series, it returns a n+1 dimensional time series. + To realign timestamps around a set of timestamps, you can use `compute_perievent_continuous`. This function assumes a constant sampling rate of the time series. @@ -119,22 +127,25 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): ---------- data : Tsd, TsdFrame or TsdTensor The data to align to tref. - tref : Ts/Tsd + tref : Ts or Tsd The timestamps of the event to align to minmax : tuple or int or float The window size. Can be unequal on each side i.e. (-500, 1000). + ep : IntervalSet, optional + The epochs to perform the operation. If None, the default is the time support of the data. time_unit : str, optional Time units of the minmax ('s' [default], 'ms', 'us'). Returns ------- - Tsd, TsdFrame, TsdTensor - + TsdFrame, TsdTensor + If `data` is a one-dimensional Tsd, the output is a TsdFrame. Each column is one timestamps from `tref`. + If `data` is a TsdFrame or TsdTensor, the output is a TsdTensor with one more dimension. The first dimension is always time and the second dimension is the 'tref' timestamps. Raises ------ RuntimeError - if tref is not a Ts/Tsd object or if data is not a Ts/Tsd/Tensor object. + if tref is not a Ts/Tsd object or if data is not a Tsd/TsdFrame/TsdTensor object. """ assert isinstance(tref, (nap.Ts, nap.Tsd)), "tref should be a Ts or Tsd object." @@ -145,7 +156,7 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): minmax, (float, int, tuple) ), "minmax should be a tuple or int or float." assert isinstance(time_unit, str), "time_unit should be a str." - assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us')" + assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'" if ep is None: ep = data.time_support @@ -182,14 +193,14 @@ def compute_perievent_continuous(data, tref, minmax, ep=None, time_unit="s"): def compute_event_trigger_average( - group, feature, binsize, windowsize, ep, time_units="s" + group, feature, binsize, windowsize, ep, time_unit="s" ): """ - Bin the spike train in binsize and compute the Spike Trigger Average (STA) within windowsize. - If C is the spike count matrix and feature is a Tsd array, the function computes + Bin the spike train in binsize and compute the Event Trigger Average (ETA) within windowsize. + If C is the spike count matrix and `feature` is a Tsd array, the function computes the Hankel matrix H from windowsize=(-t1,+t2) by offseting the Tsd array. - The STA is then defined as the dot product between H and C divided by the number of spikes. + The ETA is then defined as the dot product between H and C divided by the number of events. Parameters ---------- @@ -197,49 +208,56 @@ def compute_event_trigger_average( The group of Ts/Tsd objects that hold the trigger time. feature : Tsd The 1-dimensional feature to average. Can be a TsdFrame with one column only. - binsize : float + binsize : float or int The bin size. Default is second. - If different, specify with the parameter time_units ('s' [default], 'ms', 'us'). + If different, specify with the parameter time_unit ('s' [default], 'ms', 'us'). windowsize : tuple or list of float The window size. Default is second. For example (-1, 1). - If different, specify with the parameter time_units ('s' [default], 'ms', 'us'). + If different, specify with the parameter time_unit ('s' [default], 'ms', 'us'). ep : IntervalSet - The epoch on which STA are computed - time_units : str, optional - The time units of the parameters. They have to be consistent for binsize and windowsize. + The epoch on which ETA are computed + time_unit : str, optional + The time unit of the parameters. They have to be consistent for binsize and windowsize. ('s' [default], 'ms', 'us'). Returns ------- TsdFrame - A TsdFrame of Spike-Trigger Average. Each column is an element from the group. + A TsdFrame of Event-Trigger Average. Each column is an element from the group. Raises ------ RuntimeError if group is not a Ts/Tsd or TsGroup """ - if type(group) is not nap.TsGroup: - raise RuntimeError("Unknown format for group") + assert isinstance(group, nap.TsGroup), "group should be a TsGroup." + assert isinstance( + windowsize, (float, int, tuple) + ), "windowsize should be a tuple or int or float." + assert isinstance(binsize, (float, int)), "binsize should be int or float." + assert isinstance(time_unit, str), "time_unit should be a str." + assert time_unit in ["s", "ms", "us"], "time_unit should be 's', 'ms' or 'us'" + assert isinstance(ep, (nap.IntervalSet)), "ep should be an IntervalSet object." if isinstance(feature, nap.TsdFrame): if feature.shape[1] == 1: feature = feature[:, 0] - if type(feature) is not nap.Tsd: - raise RuntimeError("Feature should be a Tsd or a TsdFrame with one column") + assert isinstance( + feature, nap.Tsd + ), "Feature should be a Tsd or a TsdFrame with one column" binsize = nap.TsIndex.format_timestamps( - np.array([binsize], dtype=np.float64), time_units + np.array([binsize], dtype=np.float64), time_unit )[0] start = np.abs( nap.TsIndex.format_timestamps( - np.array([windowsize[0]], dtype=np.float64), time_units + np.array([windowsize[0]], dtype=np.float64), time_unit )[0] ) end = np.abs( nap.TsIndex.format_timestamps( - np.array([windowsize[1]], dtype=np.float64), time_units + np.array([windowsize[1]], dtype=np.float64), time_unit )[0] ) idx1 = -np.arange(0, start + binsize, binsize)[::-1][:-1] diff --git a/pyproject.toml b/pyproject.toml index ed1298d9..d9863274 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "pynapple" -version = "0.4.1" +version = "0.5.0" description = "PYthon Neural Analysis Package Pour Laboratoires d’Excellence" readme = "README.md" authors = [{ name = "Guillaume Viejo", email = "guillaume.viejo@gmail.com" }] diff --git a/setup.py b/setup.py index 47f2c70a..5489cd00 100644 --- a/setup.py +++ b/setup.py @@ -59,7 +59,7 @@ test_suite='tests', tests_require=test_requirements, url='https://github.com/pynapple-org/pynapple', - version='v0.4.1', + version='v0.5.0', zip_safe=False, long_description_content_type='text/markdown', download_url='https://github.com/pynapple-org/pynapple/archive/refs/tags/v0.4.1.tar.gz' diff --git a/test_dandi.py b/test_dandi.py deleted file mode 100644 index 703ebfae..00000000 --- a/test_dandi.py +++ /dev/null @@ -1,52 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author: gviejo -# @Date: 2023-07-25 18:30:04 -# @Last Modified by: gviejo -# @Last Modified time: 2023-07-27 16:50:38 -import pynwb -from pynwb import NWBHDF5IO, TimeSeries - -# from nwbwidgets import nwb2widget - -from dandi.dandiapi import DandiAPIClient -import pynapple as nap -import numpy as np -import fsspec -from fsspec.implementations.cached import CachingFileSystem - -import pynwb -import h5py - - -from matplotlib.pyplot import * - -# ecephys, Buzsaki Lab (15.2 GB) -dandiset_id, filepath = "000003", "sub-YutaMouse41/sub-YutaMouse41_ses-YutaMouse41-150831_behavior+ecephys.nwb" - - -with DandiAPIClient() as client: - asset = client.get_dandiset(dandiset_id, "draft").get_asset_by_path(filepath) - s3_url = asset.get_content_url(follow_redirects=1, strip_query=True) - - - - -# first, create a virtual filesystem based on the http protocol -fs=fsspec.filesystem("http") - -# create a cache to save downloaded data to disk (optional) -fs = CachingFileSystem( - fs=fs, - cache_storage="nwb-cache", # Local folder for the cache -) - -# next, open the file -file = h5py.File(fs.open(s3_url, "rb")) -io = pynwb.NWBHDF5IO(file=file, load_namespaces=True) - - -##################################### -# Pynapple -##################################### - -nwb = nap.NWBFile(io.read()) diff --git a/test_ndarray.py b/test_ndarray.py deleted file mode 100644 index 664fd374..00000000 --- a/test_ndarray.py +++ /dev/null @@ -1,53 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -class Index: - def __init__(self, index): - self.values = index - - def to_numpy(self): - return np.array(self.values) - -class CustomArray: - def __init__(self, data, index=None, name=None): - self.data = np.array(data) - self.index = Index(index) - self.name = name - - def plot(self, kind='line'): - if kind == 'line': - self._plot_line() - elif kind == 'bar': - self._plot_bar() - # Add more plot types as needed - - def _plot_line(self): - x = self.index if self.index is not None else np.arange(len(self.data)) - plt.plot(self.index, self.data) - plt.xlabel('X-axis') - plt.ylabel('Y-axis') - plt.title(self.name if self.name is not None else 'Custom Array Plot') - plt.show() - - def _plot_bar(self): - x = self.index if self.index is not None else np.arange(len(self.data)) - plt.bar(x, self.data) - plt.xlabel('X-axis') - plt.ylabel('Y-axis') - plt.title(self.name if self.name is not None else 'Custom Array Bar Plot') - - def __array__(self): - return self.data - - def to_numpy(self): - return self.data - -# Example usage: -custom_data = np.array([0, 10, 20, 30, 40, 50]) -custom_index = np.array([0, 100, 200, 300, 400, 500]) -custom_name = 'MyCustomArray' - -custom_series = CustomArray(custom_data, index=custom_index, name=custom_name) -plt.plot(custom_series) # Use plot(custom_series) to create the plot - -plt.show() \ No newline at end of file diff --git a/test_new_io.py b/test_new_io.py deleted file mode 100644 index a26a28b2..00000000 --- a/test_new_io.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -# @Author: Guillaume Viejo -# @Date: 2023-05-15 15:37:03 -# @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-07-08 17:51:46 - -import numpy as np -import pynapple as nap - -path = "/Users/gviejo/Dropbox/MyProject" - -project = nap.load_folder(path) - -data = project['sub-A001']['ses-01']['pynapple'] - -# project = nap.load_project(path) - -# print(project) - -# print(project['sub-A001']) - -# print(project['sub-A001']['ses-01']) - -# print(project['sub-A001']['ses-01']['pynapple']['spikes']) - - -# session = project['sub-A001']['ses-01'] - -# epoch = nap.IntervalSet(start = np.array([0, 3]), end = np.array([1, 6])) - -# session.save(epoch, "stimulus-fish", "Fish pictures to V1") - -# print(session) - -# print(session['stimulus-fish']) - -# print(session.doc('stimulus-fish')) \ No newline at end of file diff --git a/tests/test_perievent.py b/tests/test_perievent.py index 1a75daad..2c8f14b7 100644 --- a/tests/test_perievent.py +++ b/tests/test_perievent.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-03-30 11:16:53 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-12-11 18:16:58 +# @Last Modified time: 2023-12-12 17:45:38 #!/usr/bin/env python """Tests of perievent for `pynapple` package.""" @@ -23,14 +23,6 @@ def test_align_tsd(): for i, j in zip(peth.keys(), np.arange(0, 100, 10)): np.testing.assert_array_almost_equal(peth[i].index, np.arange(-10, 10)) -def test_compute_perievent_continuous(): - tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - tref = nap.Ts(t=np.arange(10, 100, 10)) - pe = nap.compute_perievent_continuous(tsd, tref, minmax=(-10, 10)) - - assert isinstance(pe, nap.TsdFrame) - assert pe.shape[1] == len(tref) - def test_compute_perievent_with_tsd(): tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) tref = nap.Ts(t=np.arange(10, 100, 10)) @@ -60,17 +52,27 @@ def test_compute_perievent_minmax(): np.testing.assert_array_almost_equal(peth[i].values, np.arange(j, j + 20)) def test_compute_perievent_raise_error(): - tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) - tref = np.arange(10, 100, 10) - with pytest.raises(Exception) as e_info: - nap.compute_perievent(tsd, tref, minmax=(-10, 10)) - assert str(e_info.value) == "tref should be a Ts or Tsd object." - tsd = t = np.arange(100) + tsd = nap.Ts(t=np.arange(100)) tref = nap.Ts(t=np.arange(10, 100, 10)) - with pytest.raises(Exception) as e_info: - nap.compute_perievent(tsd, tref, minmax=(-10, 10)) - assert str(e_info.value) == "Unknown format for data" + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent(tsd, [0,1,2], minmax=(-10, 10)) + assert str(e_info.value) == "tref should be a Ts or Tsd object." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent([0,1,2], tref, minmax=(-10, 10)) + assert str(e_info.value) == "data should be a Ts, Tsd or TsGroup." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent(tsd, tref, minmax={0:1}) + assert str(e_info.value) == "minmax should be a tuple or int or float." + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent(tsd, tref, minmax=10, time_unit=1) + assert str(e_info.value) == "time_unit should be a str." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent(tsd, tref, minmax=10, time_unit='a') + assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'" def test_compute_perievent_with_tsgroup(): tsgroup = nap.TsGroup( @@ -101,3 +103,123 @@ def test_compute_perievent_time_units(): for i, j in zip(peth.keys(), np.arange(0, 100, 10)): np.testing.assert_array_almost_equal(peth[i].index, np.arange(-10, 10)) np.testing.assert_array_almost_equal(peth[i].values, np.arange(j, j + 20)) + + +def test_compute_perievent_continuous(): + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.array([20, 60])) + minmax=(-5, 10) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + assert isinstance(pe, nap.TsdFrame) + assert pe.shape[1] == len(tref) + np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[-1]+1)) + tmp = np.array([np.arange(t+minmax[0], t+minmax[1]+1) for t in tref.t]).T + np.testing.assert_array_almost_equal(pe.values, tmp) + + minmax=(5, 10) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + np.testing.assert_array_almost_equal(pe.values, tmp) + + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.array([20, 60])) + minmax=5 + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + assert isinstance(pe, nap.TsdFrame) + assert pe.shape[1] == len(tref) + np.testing.assert_array_almost_equal(pe.index.values, np.arange(-minmax, minmax+1)) + tmp = np.array([np.arange(t-minmax, t+minmax+1) for t in tref.t]).T + np.testing.assert_array_almost_equal(pe.values, tmp) + + tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 3)) + tref = nap.Ts(t=np.array([20, 60])) + minmax=(-5, 10) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + assert isinstance(pe, nap.TsdTensor) + assert pe.d.ndim == 3 + assert pe.shape[1:] == (len(tref), tsd.shape[1]) + np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[-1]+1)) + tmp = np.zeros(pe.shape) + for i,t in enumerate(tref.t): + idx = np.where(tsd.t == t)[0][0] + tmp[:,i,:] = tsd.values[idx+minmax[0]:idx+minmax[1]+1] + np.testing.assert_array_almost_equal(pe.values, tmp) + + tsd = nap.TsdTensor(t=np.arange(100), d=np.random.randn(100, 3, 4)) + tref = nap.Ts(t=np.array([20, 60])) + minmax=(-5, 10) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax) + assert isinstance(pe, nap.TsdTensor) + assert pe.d.ndim == 4 + assert pe.shape[1:] == (len(tref), *tsd.shape[1:]) + np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[-1]+1)) + tmp = np.zeros(pe.shape) + for i,t in enumerate(tref.t): + idx = np.where(tsd.t == t)[0][0] + tmp[:,i,:] = tsd.values[idx+minmax[0]:idx+minmax[1]+1] + np.testing.assert_array_almost_equal(pe.values, tmp) + + +def test_compute_perievent_continuous_time_units(): + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.array([20, 60])) + minmax = (-5, 10) + for tu, fa in zip(["s", "ms", "us"], [1, 1e3, 1e6]): + pe = nap.compute_perievent_continuous(tsd, tref, minmax=(minmax[0] * fa, minmax[1] * fa), time_unit=tu) + np.testing.assert_array_almost_equal(pe.index.values, np.arange(minmax[0], minmax[1]+1)) + tmp = np.array([np.arange(t+minmax[0], t+minmax[1]+1) for t in tref.t]).T + np.testing.assert_array_almost_equal(pe.values, tmp) + + +def test_compute_perievent_continuous_with_ep(): + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.array([10, 50, 80])) + minmax=(-5, 10) + ep = nap.IntervalSet(start = [0, 60], end = [40, 99]) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep) + + assert pe.shape[1] == len(tref)-1 + tmp = np.array([np.arange(t+minmax[0], t+minmax[1]+1) for t in tref.restrict(ep).t]).T + np.testing.assert_array_almost_equal(pe.values, tmp) + + tref = ep.starts + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep) + tmp = np.array([np.arange(t, t+minmax[1]+1) for t in tref.restrict(ep).t]).T + np.testing.assert_array_almost_equal(pe.values[abs(minmax[0]):], tmp) + + tref = ep.ends + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep) + tmp = np.array([np.arange(t+minmax[0], t+1) for t in tref.restrict(ep).t]).T + np.testing.assert_array_almost_equal(pe.values[:-abs(minmax[1])], tmp) + + ep = nap.IntervalSet(start = [100], end = [200]) + tref = nap.Ts(t=np.array([120, 150, 180])) + pe = nap.compute_perievent_continuous(tsd, tref, minmax=minmax, ep=ep) + assert np.all(np.isnan(pe.values)) + + +def test_compute_perievent_continuous_raise_error(): + tsd = nap.Tsd(t=np.arange(100), d=np.arange(100)) + tref = nap.Ts(t=np.arange(10, 100, 10)) + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous(tsd, [0,1,2], minmax=(-10, 10)) + assert str(e_info.value) == "tref should be a Ts or Tsd object." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous([0,1,2], tref, minmax=(-10, 10)) + assert str(e_info.value) == "data should be a Tsd, TsdFrame or TsdTensor." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous(tsd, tref, minmax={0:1}) + assert str(e_info.value) == "minmax should be a tuple or int or float." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous(tsd, tref, minmax=10, time_unit=1) + assert str(e_info.value) == "time_unit should be a str." + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous(tsd, tref, minmax=10, time_unit='a') + assert str(e_info.value) == "time_unit should be 's', 'ms' or 'us'" + + with pytest.raises(AssertionError) as e_info: + nap.compute_perievent_continuous(tsd, tref, minmax=10, ep='a') + assert str(e_info.value) == "ep should be an IntervalSet object." diff --git a/tests/test_spike_trigger_average.py b/tests/test_spike_trigger_average.py index ee741f84..38733d17 100644 --- a/tests/test_spike_trigger_average.py +++ b/tests/test_spike_trigger_average.py @@ -2,7 +2,7 @@ # @Author: gviejo # @Date: 2022-08-29 17:27:02 # @Last Modified by: Guillaume Viejo -# @Last Modified time: 2023-11-20 12:07:53 +# @Last Modified time: 2023-12-12 18:10:30 #!/usr/bin/env python """Tests of spike trigger average for `pynapple` package.""" @@ -43,6 +43,30 @@ def test_compute_spike_trigger_average(): sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) np.testing.assert_array_almost_equal(sta, output) +def test_compute_spike_trigger_average_add_nan(): + ep = nap.IntervalSet(0, 110) + feature = nap.Tsd( + t=np.arange(0, 110, 0.01), d=np.zeros(int(110 / 0.01)), time_support=ep + ) + t1 = np.arange(1, 100) + x = np.arange(100, 10000, 100) + feature[x] = 1.0 + spikes = nap.TsGroup( + {0: nap.Ts(t1), 1: nap.Ts(t1 - 0.1), 2: nap.Ts(t1 + 0.2)}, time_support=ep + ) + + feature[-1001:] = np.nan + + sta = nap.compute_event_trigger_average(spikes, feature, 0.2, (0.6, 0.6), ep) + + output = np.zeros((7, 3)) + output[3, 0] = 0.05 + output[4, 1] = 0.05 + output[2, 2] = 0.05 + + assert isinstance(sta, nap.TsdFrame) + assert sta.shape == output.shape + np.testing.assert_array_almost_equal(sta, output) def test_compute_spike_trigger_average_raise_error(): ep = nap.IntervalSet(0, 101) @@ -55,7 +79,7 @@ def test_compute_spike_trigger_average_raise_error(): with pytest.raises(Exception) as e_info: nap.compute_event_trigger_average(feature, feature, 0.1, (0.5, 0.5), ep) - assert str(e_info.value) == "Unknown format for group" + assert str(e_info.value) == "group should be a TsGroup." feature = nap.TsdFrame( t=np.arange(0, 101, 0.01), d=np.random.rand(int(101 / 0.01), 3), time_support=ep @@ -70,7 +94,7 @@ def test_compute_spike_trigger_average_raise_error(): -def test_compute_spike_trigger_average_time_units(): +def test_compute_spike_trigger_average_time_unit(): ep = nap.IntervalSet(0, 100) feature = pd.Series(index=np.arange(0, 101, 0.01), data=np.zeros(int(101 / 0.01))) t1 = np.arange(1, 100) @@ -93,7 +117,7 @@ def test_compute_spike_trigger_average_time_units(): for tu, fa in zip(["s", "ms", "us"], [1, 1e3, 1e6]): sta = nap.compute_event_trigger_average( - spikes, feature, binsize * fa, tuple(windowsize * fa), ep, time_units=tu + spikes, feature, binsize * fa, tuple(windowsize * fa), ep, time_unit=tu ) assert isinstance(sta, nap.TsdFrame) assert sta.shape == output.shape