Skip to content

Commit

Permalink
Merge pull request #386 from pynapple-org/fix_nwb_load_metadata
Browse files Browse the repository at this point in the history
Fix loading an IntervalSet from NWB file with metadata
  • Loading branch information
gviejo authored Jan 8, 2025
2 parents 3f59ca9 + 68b638a commit 0688345
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 10 deletions.
9 changes: 9 additions & 0 deletions pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,15 @@ def __init__(
), """
DataFrame must contain columns name "start" and "end" for start and end times.
"""
# try sorting the DataFrame by start times, preserving its end pair, as an effort to preserve metadata
# since metadata would be dropped if starts and ends are sorted separately
# note that if end times are still not sorted, metadata will be dropped
if np.any(start["start"].diff() < 0):
warnings.warn(
"DataFrame is not sorted by start times. Sorting it.", stacklevel=2
)
start = start.sort_values("start").reset_index(drop=True)

metadata = start.drop(columns=["start", "end"])
end = start["end"].values.astype(np.float64)
start = start["start"].values.astype(np.float64)
Expand Down
7 changes: 3 additions & 4 deletions pynapple/io/interface_nwb.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,9 @@ def _make_interval_set(obj, **kwargs):
df = obj.to_dataframe()

if hasattr(df, "start_time") and hasattr(df, "stop_time"):
data = nap.IntervalSet(start=df["start_time"], end=df["stop_time"])
if df.shape[1] > 2:
metadata = df.drop(columns=["start_time", "stop_time"])
data.set_info(metadata)
df = df.rename(columns={"start_time": "start", "stop_time": "end"})
# create from full dataframe to ensure that metadata is associated correctly
data = nap.IntervalSet(df)
return data

else:
Expand Down
22 changes: 16 additions & 6 deletions tests/test_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

"""Tests of decoding for `pynapple` package."""

import pynapple as nap
import numpy as np
import pandas as pd
import pytest

import pynapple as nap


def get_testing_set_1d():
feature = nap.Tsd(t=np.arange(0, 100, 1), d=np.repeat(np.arange(0, 2), 50))
Expand All @@ -36,9 +37,10 @@ def test_decode_1d():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_TsdFrame():
feature, group, tc, ep = get_testing_set_1d()
count = group.count(bin_size=1, ep = ep)
count = group.count(bin_size=1, ep=ep)
decoded, proba = nap.decode_1d(tc, count, ep, bin_size=1)
assert isinstance(decoded, nap.Tsd)
assert isinstance(proba, nap.TsdFrame)
Expand All @@ -50,6 +52,7 @@ def test_decode_1d_with_TsdFrame():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_feature():
feature, group, tc, ep = get_testing_set_1d()
decoded, proba = nap.decode_1d(tc, group, ep, bin_size=1, feature=feature)
Expand All @@ -63,7 +66,8 @@ def test_decode_1d_with_feature():
tmp[50:, 0] = 0.0
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)



def test_decode_1d_with_dict():
feature, group, tc, ep = get_testing_set_1d()
group = dict(group)
Expand All @@ -79,18 +83,21 @@ def test_decode_1d_with_dict():
tmp[0:50, 1] = 0.0
np.testing.assert_array_almost_equal(proba.values, tmp)


def test_decode_1d_with_wrong_feature():
    """Check that `decode_1d` rejects a `feature` argument of unsupported type.

    A plain list is passed as `feature`; the call is expected to raise a
    RuntimeError carrying the exact message asserted below.
    """
    feature, group, tc, ep = get_testing_set_1d()
    with pytest.raises(RuntimeError) as e_info:
        # Only one call belongs here: anything after the raising call inside
        # the `pytest.raises` block would never execute.
        nap.decode_1d(tc, group, ep, bin_size=1, feature=[1, 2, 3])
    assert str(e_info.value) == "Unknown format for feature in decode_1d"


def test_decode_1d_with_time_units():
    """Check that `decode_1d` recovers the feature for each `time_units` value.

    The bin size is passed as 1.0, 1e3 and 1e6 together with the units
    "s", "ms" and "us" respectively; every run must decode values that
    match the original feature.
    """
    feature, group, tc, ep = get_testing_set_1d()
    # Map each time unit to the bin size expressed in that unit.
    bin_size_by_unit = {"s": 1.0, "ms": 1e3, "us": 1e6}
    for unit, bin_size in bin_size_by_unit.items():
        result = nap.decode_1d(tc, group, ep, bin_size, time_units=unit)
        decoded = result[0]
        np.testing.assert_array_almost_equal(feature.values, decoded.values)


def test_decoded_1d_raise_errors():
feature, group, tc, ep = get_testing_set_1d()
with pytest.raises(Exception) as e_info:
Expand Down Expand Up @@ -150,9 +157,10 @@ def test_decode_2d():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)


def test_decode_2d_with_TsdFrame():
features, group, tc, ep, xy = get_testing_set_2d()
count = group.count(bin_size=1, ep = ep)
count = group.count(bin_size=1, ep=ep)
decoded, proba = nap.decode_2d(tc, count, ep, 1, xy)

assert isinstance(decoded, nap.TsdFrame)
Expand All @@ -169,7 +177,8 @@ def test_decode_2d_with_TsdFrame():
tmp[1:50:2, 0] = 1
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)



def test_decode_2d_with_dict():
features, group, tc, ep, xy = get_testing_set_2d()
group = dict(group)
Expand All @@ -190,6 +199,7 @@ def test_decode_2d_with_dict():
tmp[51:100:2, 1] = 1
np.testing.assert_array_almost_equal(proba[:, :, 1], tmp)


def test_decode_2d_with_feature():
features, group, tc, ep, xy = get_testing_set_2d()
decoded, proba = nap.decode_2d(tc, group, ep, 1, xy)
Expand Down
38 changes: 38 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,44 @@ def test_create_iset_from_df_with_metadata():
np.testing.assert_array_almost_equal(df.end.values, ep.end)


@pytest.mark.parametrize(
    "df, expected",
    [
        # start times are unsorted; the dataframe is sorted and metadata is kept
        (
            pd.DataFrame(
                {
                    "start": [25.0, 0.0, 10.0, 16.0],
                    "end": [40.0, 5.0, 15.0, 20.0],
                    "label": np.arange(4),
                }
            ),
            ["DataFrame is not sorted by start times"],
        ),
        # start times are unsorted and, once rows are sorted by start, the end
        # times are still out of order; the dataframe is sorted and metadata
        # is dropped
        (
            pd.DataFrame(
                {
                    "start": [25, 0, 10, 16],
                    "end": [40, 20, 15, 20],
                    "label": np.arange(4),
                }
            ),
            ["DataFrame is not sorted by start times", "dropping metadata"],
        ),
    ],
)
def test_create_iset_from_df_with_metadata_sort(df, expected):
    """Check the warnings raised when an IntervalSet is built from an unsorted DataFrame.

    Parameters
    ----------
    df : pandas.DataFrame
        Input with "start", "end" and one metadata column ("label").
    expected : list of str
        Substrings, each of which must appear in at least one recorded warning.

    When no "dropping metadata" warning is expected, the resulting IntervalSet
    must round-trip to the input DataFrame sorted by "start".
    """
    with warnings.catch_warnings(record=True) as caught:
        # Record every warning regardless of the ambient filter configuration
        # or of warnings already emitted from the same location.
        warnings.simplefilter("always")
        ep = nap.IntervalSet(df)

    messages = [str(record.message) for record in caught]
    for fragment in expected:
        assert any(
            fragment in message for message in messages
        ), f"no recorded warning contains {fragment!r}; got {messages!r}"

    if "dropping metadata" not in expected:
        pd.testing.assert_frame_equal(
            ep.as_dataframe(), df.sort_values("start").reset_index(drop=True)
        )


@pytest.mark.parametrize(
"index",
[
Expand Down

0 comments on commit 0688345

Please sign in to comment.