Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev #193

Merged
merged 2 commits into from
Oct 30, 2023
Merged

Dev #193

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-01-27 18:33:31
# @Last Modified by: gviejo
# @Last Modified time: 2023-10-18 11:16:43
# @Last Modified time: 2023-10-27 14:26:55

"""

Expand Down Expand Up @@ -729,7 +729,7 @@ def find_support(self, min_gap, time_units="s"):

Parameters
----------
min_gap : float
min_gap : float or int
minimal interval between timestamps
time_units : str, optional
Time units of min gap
Expand All @@ -754,6 +754,27 @@ def find_support(self, min_gap, time_units="s"):

return IntervalSet(start=starts, end=ends)

def get(self, start, end, time_units="s"):
"""Slice the time series from start to end such that all the timestamps satisfy start<=t<=end.

By default, the time support doesn't change. If you want to change the

Parameters
----------
start : float or int
The start
end : float or int
The end
"""
assert isinstance(start, Number), "start should be a float or int"
assert isinstance(end, Number), "end should be a float or int"
assert start < end, "Start should not precede end"
start, end = TsIndex.format_timestamps(np.array([start, end]), time_units)
time_array = self.index.values
idx_start = np.searchsorted(time_array, start)
idx_end = np.searchsorted(time_array, end, side="right")
return self[idx_start:idx_end]


class TsdTensor(NDArrayOperatorsMixin, _AbstractTsd):
"""
Expand Down Expand Up @@ -840,9 +861,22 @@ def __repr__(self):

def create_str(array):
if array.ndim == 1:
return (
"[" + array[0].__repr__() + " ... " + array[0].__repr__() + "]"
)
if len(array) > 2:
return (
"["
+ array[0].__repr__()
+ " ... "
+ array[-1].__repr__()
+ "]"
)
elif len(array) == 2:
return (
"[" + array[0].__repr__() + "," + array[1].__repr__() + "]"
)
elif len(array) == 1:
return "[" + array[0].__repr__() + "]"
else:
return "[]"
else:
return "[" + create_str(array[0]) + " ...]"

Expand Down
44 changes: 43 additions & 1 deletion tests/test_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# @Author: gviejo
# @Date: 2022-04-01 09:57:55
# @Last Modified by: gviejo
# @Last Modified time: 2023-10-18 11:18:07
# @Last Modified time: 2023-10-29 16:19:01
#!/usr/bin/env python

"""Tests of time series for `pynapple` package."""
Expand Down Expand Up @@ -246,6 +246,25 @@ def test_properties():
assert tsd.ndim == 1
assert tsd.size == 100

with pytest.raises(RuntimeError):
tsd.rate = 0

def test_abstract_class():
class DummyTsd(nap.core.time_series._AbstractTsd):
def __init__(self):
super().__init__()

tsd = DummyTsd()
assert np.isnan(tsd.rate)
assert isinstance(tsd.index, nap.TsIndex)
assert isinstance(tsd.values, np.ndarray)

# assert tsd.__repr__() == "<class '__main__.DummyTsd'>"

with pytest.raises(IndexError):
tsd['a']


####################################################
# General test for time series
####################################################
Expand Down Expand Up @@ -348,6 +367,22 @@ def test_restrict_inherit_time_support(self, tsd):
np.testing.assert_approx_equal(tsd2.time_support.start[0], ep.start[0])
np.testing.assert_approx_equal(tsd2.time_support.end[0], ep.end[0])

def test_get(self, tsd):
tsd2 = tsd.get(10, 20)
assert len(tsd2) == 11
np.testing.assert_array_equal(tsd2.index.values, tsd.index.values[10:21])
if not isinstance(tsd, nap.Ts):
np.testing.assert_array_equal(tsd2.values, tsd.values[10:21])

with pytest.raises(Exception):
tsd.get(20, 10)

with pytest.raises(Exception):
tsd.get(10, [20])

with pytest.raises(Exception):
tsd.get([10], 20)


####################################################
# Test for tsd
Expand Down Expand Up @@ -634,6 +669,13 @@ def test_str_indexing(self, tsdframe):
np.testing.assert_array_almost_equal(tsdframe.values[:,0], tsdframe['a'])
np.testing.assert_array_almost_equal(tsdframe.values[:,[0,2]], tsdframe[['a', 'c']])

with pytest.raises(Exception):
tsdframe['d']

with pytest.raises(Exception):
tsdframe[['d', 'e']]


def test_operators(self, tsdframe):
v = tsdframe.values

Expand Down
Loading