Skip to content

Commit

Permalink
Adding dropna
Browse files Browse the repository at this point in the history
  • Loading branch information
gviejo committed Nov 20, 2023
1 parent e081914 commit 288662b
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 42 deletions.
29 changes: 27 additions & 2 deletions pynapple/core/jitted_functions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: guillaume
# @Date: 2022-10-31 16:44:31
# @Last Modified by: gviejo
# @Last Modified time: 2023-10-15 16:05:27
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 18:27:43
import numpy as np
from numba import jit

Expand Down Expand Up @@ -749,6 +749,31 @@ def jitin_interval(time_array, starts, ends):
return data


@jit(nopython=True)
def jitremove_nan(time_array, index_nan):
    """
    Locate the boundaries of the contiguous runs of non-NaN samples.

    Parameters
    ----------
    time_array : numpy.ndarray
        Timestamps, one per sample.
    index_nan : numpy.ndarray of bool
        Mask indexed like ``time_array``; True where the sample is
        flagged as NaN.

    Returns
    -------
    tuple of numpy.ndarray
        ``(starts, ends)``: the timestamps of the first and of the last
        sample of every run where ``index_nan`` is False. A run made of a
        single sample yields the same timestamp in both arrays. Both
        arrays are empty when the input is empty.
    """
    n = len(time_array)
    ix_start = np.zeros(n, dtype=np.bool_)
    ix_end = np.zeros(n, dtype=np.bool_)

    if n == 0:
        # An empty input has no element 0 / -1 to inspect below.
        return (time_array[ix_start], time_array[ix_end])

    if not index_nan[0]:  # First start
        ix_start[0] = True

    for t in range(1, n):
        if index_nan[t - 1] and not index_nan[t]:  # NaN -> valid: start
            ix_start[t] = True
        if not index_nan[t - 1] and index_nan[t]:  # valid -> NaN: end
            ix_end[t - 1] = True

    if not index_nan[-1]:  # Last stop
        ix_end[-1] = True

    starts = time_array[ix_start]
    ends = time_array[ix_end]
    return (starts, ends)


@jit(nopython=True)
def jit_poisson_IRLS(X, y, niter=100, tolerance=1e-5):
y = y.astype(np.float64)
Expand Down
69 changes: 37 additions & 32 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-01-27 18:33:31
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-17 16:09:35
# @Last Modified time: 2023-11-19 18:59:08

"""
Expand Down Expand Up @@ -39,6 +39,7 @@
jitbin,
jitbin_array,
jitcount,
jitremove_nan,
jitrestrict,
jitthreshold,
jittsrestrict,
Expand Down Expand Up @@ -787,43 +788,47 @@ def get(self, start, end=None, time_units="s"):
idx_end = np.searchsorted(time_array, end, side="right")
return self[idx_start:idx_end]

def dropna(self, update_time_support=True):
    """Drop every row containing NaNs.

    By default, the time support is updated to start and end around the
    time points that are not NaN. To change this behavior, set
    ``update_time_support=False``.

    Parameters
    ----------
    update_time_support : bool, optional
        If True (the default), a new time support is built from the runs
        of consecutive non-NaN time points. If False, the NaN rows are
        simply removed by boolean indexing and no time support is
        recomputed here.

    Returns
    -------
    Tsd, TsdFrame or TsdTensor
        The time series without the NaNs. When every row contains a NaN,
        an empty object of the same class is returned; when no row
        contains a NaN, the object itself is returned.
    """
    # A time point is flagged as soon as any value along the non-time
    # axes is NaN.
    index_nan = np.any(np.isnan(self.values), axis=tuple(range(1, self.ndim)))

    if np.all(index_nan):  # In case it's only NaNs
        return self.__class__(
            t=np.array([]), d=np.empty((0,) + tuple(self.shape[1:]))
        )

    if not np.any(index_nan):  # Nothing to drop
        return self

    if not update_time_support:
        return self[~index_nan]

    time_array = self.index.values
    starts, ends = jitremove_nan(time_array, index_nan)

    # A run made of a single time point has start == end; widen it by a
    # small epsilon so the interval is not zero-length. The epsilon is
    # 1e-6 in the units of the index (i.e. 1 microsecond if the index is
    # in seconds -- TODO confirm the unit).
    to_fix = starts == ends
    if np.any(to_fix):
        ends[to_fix] += 1e-6

    ep = IntervalSet(starts, ends)

    return self.__class__(
        t=time_array[~index_nan], d=self.values[~index_nan], time_support=ep
    )


class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd):
Expand Down
9 changes: 7 additions & 2 deletions pynapple/process/perievent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-01-30 22:59:00
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-16 11:34:48
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 19:13:24

import numpy as np
from scipy.linalg import hankel
Expand Down Expand Up @@ -168,6 +168,11 @@ def compute_event_trigger_average(

tmp = feature.bin_average(binsize, ep)

# Check for any NaNs in feature
if np.any(np.isnan(tmp)):
tmp = tmp.dropna()
count = count.restrict(tmp.time_support)

# Build the Hankel matrix
n_p = len(idx1)
n_f = len(idx2)
Expand Down
9 changes: 6 additions & 3 deletions tests/test_numpy_compatibility.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: Guillaume Viejo
# @Date: 2023-09-18 18:11:24
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-08 18:14:12
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 16:55:26



Expand All @@ -17,7 +17,10 @@

tsd = nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 3), time_units="s")

tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))
# tsd = nap.TsdFrame(t=np.arange(100), d=np.random.randn(100, 6))

# Inject NaNs into the fixture; `np.nan` is used because the `np.NaN`
# alias no longer exists in NumPy 2.0.
tsd.d[tsd.values > 0.9] = np.nan


@pytest.mark.parametrize(
"tsd",
Expand Down
35 changes: 32 additions & 3 deletions tests/test_time_series.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# -*- coding: utf-8 -*-
# @Author: gviejo
# @Date: 2022-04-01 09:57:55
# @Last Modified by: gviejo
# @Last Modified time: 2023-11-08 18:46:52
# @Last Modified by: Guillaume Viejo
# @Last Modified time: 2023-11-19 18:48:57
#!/usr/bin/env python

"""Tests of time series for `pynapple` package."""
Expand Down Expand Up @@ -271,7 +271,7 @@ def __init__(self):
@pytest.mark.parametrize(
"tsd",
[
nap.Tsd(t=np.arange(100), d=np.arange(100), time_units="s"),
nap.Tsd(t=np.arange(100), d=np.random.rand(100), time_units="s"),
nap.TsdFrame(t=np.arange(100), d=np.random.rand(100, 5), time_units="s"),
nap.TsdTensor(t=np.arange(100), d=np.random.rand(100, 5, 2), time_units="s"),
nap.Ts(t=np.arange(100), time_units="s"),
Expand Down Expand Up @@ -393,6 +393,35 @@ def test_get_timepoint(self, tsd):
np.testing.assert_array_equal(tsd.get(1), tsd[1])
np.testing.assert_array_equal(tsd.get(1000), tsd[-1])

def test_dropna(self, tsd):
    """Check `dropna` with no NaN, some NaNs, and only NaNs."""
    # Ts objects are skipped (presumably because they carry no values
    # that could hold NaNs -- confirm against the Ts class).
    if isinstance(tsd, nap.Ts):
        return

    # NOTE(review): the steps below write NaNs into `tsd.values` in place;
    # if the parametrized object is shared with other tests they will see
    # the modified data -- confirm this is intended.

    # 1. Without NaNs, dropna must leave the data untouched.
    new_tsd = tsd.dropna()
    np.testing.assert_array_equal(tsd.index.values, new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values, new_tsd.values)

    # 2. With some NaNs, only the rows free of NaN must remain.
    tsd.values[tsd.values > 0.9] = np.nan
    new_tsd = tsd.dropna()
    # No NaN at all may survive (not merely "not everything is NaN").
    assert not np.any(np.isnan(new_tsd.values))
    tokeep = ~np.isnan(tsd.values).reshape(len(tsd), -1).any(axis=1)
    np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)

    # Restricting the original to the new time support must give the
    # same result as dropna itself.
    newtsd2 = tsd.restrict(new_tsd.time_support)
    np.testing.assert_array_equal(newtsd2.index.values, new_tsd.index.values)
    np.testing.assert_array_equal(newtsd2.values, new_tsd.values)

    # 3. With update_time_support=False the time support is unchanged.
    new_tsd = tsd.dropna(update_time_support=False)
    np.testing.assert_array_equal(tsd.index.values[tokeep], new_tsd.index.values)
    np.testing.assert_array_equal(tsd.values[tokeep], new_tsd.values)
    pd.testing.assert_frame_equal(new_tsd.time_support, tsd.time_support)

    # 4. With only NaNs, both the data and the time support are empty.
    tsd.values[:] = np.nan
    new_tsd = tsd.dropna()
    assert len(new_tsd) == 0
    assert len(new_tsd.time_support) == 0


####################################################
# Test for tsd
####################################################
Expand Down

0 comments on commit 288662b

Please sign in to comment.