Skip to content

Commit

Permalink
add groupby and get_group, saving groups as a dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Dec 2, 2024
1 parent 2eea6bc commit cb578c2
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 0 deletions.
78 changes: 78 additions & 0 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(self):
self.metadata_index = self.index

self._metadata = pd.DataFrame(index=self.metadata_index)
self._metadata_groups = None

def __dir__(self):
"""
Expand Down Expand Up @@ -325,3 +326,80 @@ def get_info(self, key):
else:
# we don't allow indexing columns with numbers, e.g. metadata[0,0]
raise IndexError(f"Unknown metadata index {key}")

def groupby(self, by=None, get_group=None):
    """
    Group pynapple object by metadata column(s).

    The grouping is computed by pandas on the metadata table and stored on
    the object as a dictionary (group name -> index of the group members),
    replacing any previously stored grouping.

    Parameters
    ----------
    by : str or list of str
        Metadata column name(s) to group by. Must be supplied; the ``None``
        default only exists so that a missing value yields a clear error.
    get_group : str, optional
        Name of the group to return.

    Returns
    -------
    pynapple object
        Original pynapple object with groups set, or pynapple object corresponding to 'get_group' if it has been supplied.

    Raises
    ------
    TypeError
        If 'by' is not supplied.
    ValueError
        If metadata column does not exist.
    """
    if by is None:
        # pandas would raise a TypeError mentioning 'by' and 'level' here;
        # raise the same exception type early with an actionable message.
        raise TypeError(
            "groupby() requires 'by': a metadata column name or a list of metadata column names."
        )

    # Only str and list inputs are validated against the metadata columns;
    # any other kind of key is handed to pandas unchanged.
    if isinstance(by, str):
        requested = [by]
    elif isinstance(by, list):
        requested = by
    else:
        requested = []
    for column in requested:
        if column not in self.metadata_columns:
            raise ValueError(
                f"Metadata column '{column}' not found. Metadata columns are {self.metadata_columns}"
            )

    # NOTE(review): written through __dict__, presumably to bypass a custom
    # __setattr__ on the host class -- confirm.
    self.__dict__["_metadata_groups"] = self._metadata.groupby(by).groups
    if get_group is not None:
        return self.get_group(get_group)
    return self

def get_group(self, name):
    """
    Return the part of the object belonging to one stored metadata group.

    Parameters
    ----------
    name : str or tuple of str
        Key of the group to fetch, as it appears in the stored groups.

    Returns
    -------
    pynapple object
        The object indexed by the members of group 'name'.

    Raises
    ------
    RuntimeError
        If no grouping has been stored yet.
    ValueError
        If 'name' is not one of the stored groups.
    """
    groups = self._metadata_groups

    # Nothing stored yet: the caller has to run groupby() first.
    if groups is None:
        raise RuntimeError(
            "No grouping has been performed. Please run groupby() first."
        )

    if name not in groups:
        raise ValueError(
            f"Group '{name}' not found in metadata groups. Groups are {list(groups)}"
        )

    # The group's members are converted to a numpy array before indexing.
    members = np.array(groups[name])
    return self[members]

@property
def metadata_groups(self):
    """
    Stored metadata groups.

    Returns
    -------
    dict or None
        Mapping from group name to the indices of that group, or None when
        no grouping has been performed.
    """
    stored_groups = self._metadata_groups
    return stored_groups
53 changes: 53 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,10 +643,13 @@ def test_tsgroup_metadata_future_warnings():
@pytest.fixture
def clear_metadata(obj):
    """Strip metadata columns and any stored grouping from ``obj``, then return it."""
    # For a TsGroup every column except "rate" is dropped; for any other
    # object all metadata columns are dropped.
    keep_rate = isinstance(obj, nap.TsGroup)
    droppable = (
        [name for name in obj.metadata_columns if name != "rate"]
        if keep_rate
        else obj.metadata_columns
    )
    obj._metadata.drop(columns=droppable, inplace=True)

    # Forget any grouping stored by an earlier groupby() call.
    obj.__dict__["_metadata_groups"] = None
    return obj


Expand Down Expand Up @@ -1052,6 +1055,56 @@ def test_save_and_load_npz(self, obj, obj_len):
# cleaning
Path("obj.npz").unlink()

# class Test_Metadata_Group:

@pytest.mark.parametrize(
    "metadata, group",
    [
        ({"label": [1, 1, 2, 2]}, "label"),
        ({"l1": [1, 1, 2, 2], "l2": ["a", "b", "b", "b"]}, ["l1", "l2"]),
    ],
)
def test_metadata_groupby(self, obj, metadata, group, obj_len):
    """
    Check that groupby() stores the same groups as pandas and that
    get_group(), direct indexing, and groupby(..., get_group) all agree.
    """
    if obj_len <= 1:
        pytest.skip("groupby not relevant for length 1 objects")

    obj.set_info(metadata)
    # before any grouping, both the private attribute and the public
    # property must be None (identity check, not equality)
    assert obj._metadata_groups is None
    assert obj.metadata_groups is None

    # reference grouping computed directly with pandas
    expected = obj._metadata.groupby(group).groups

    # group by metadata, assert saved groups
    obj.groupby(group)
    assert obj._metadata_groups.keys() == expected.keys()

    # iterate over a snapshot: groupby(group, grp) below rebinds
    # obj._metadata_groups on every call
    for grp, idx in list(obj._metadata_groups.items()):
        # index same as pandas
        assert all(idx == expected[grp])

        obj_grp = obj.get_group(grp)

        # get_group should be the same as indexed object
        indexed = obj[np.array(idx)]
        pd.testing.assert_frame_equal(obj_grp._metadata, indexed._metadata)
        # index should be the same for both objects
        assert all(obj_grp.index == indexed.index)
        if isinstance(obj, nap.TsdFrame):
            # columns should be the same
            assert all(obj_grp.columns == indexed.columns)

        # get_group should be the same as groupby with specified group
        regrouped = obj.groupby(group, grp)
        pd.testing.assert_frame_equal(obj_grp._metadata, regrouped._metadata)
        # index should be the same for both objects
        assert all(obj_grp.index == regrouped.index)
        if isinstance(obj, nap.TsdFrame):
            # columns should be the same for tsdframe
            assert all(obj_grp.columns == regrouped.columns)


# test double inheritance
def get_defined_members(cls):
Expand Down

0 comments on commit cb578c2

Please sign in to comment.