Skip to content

Commit

Permalink
groupby and groupby_apply functions with preliminary tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sjvenditto committed Jan 3, 2025
1 parent cb578c2 commit c794f6d
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 52 deletions.
61 changes: 27 additions & 34 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self):
self.metadata_index = self.index

self._metadata = pd.DataFrame(index=self.metadata_index)
self._metadata_groups = None

def __dir__(self):
"""
Expand Down Expand Up @@ -358,48 +357,42 @@ def groupby(self, by=None, get_group=None):
raise ValueError(
f"Metadata column '{b}' not found. Metadata columns are {self.metadata_columns}"
)
self.__dict__["_metadata_groups"] = self._metadata.groupby(by).groups
groups = self._metadata.groupby(by).groups
if get_group is not None:
return self.get_group(get_group)
if get_group not in groups.keys():
raise ValueError(
f"Group '{get_group}' not found in metadata groups. Groups are {list(groups.keys())}"
)
idx = groups[get_group]
return self[np.array(idx)]
else:
return self
return groups

def get_group(self, name):
def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
"""
Get group from metadata groups.
Apply a function to each group in a grouped pynapple object.
Parameters
----------
name : str or tuple of str
Name of the group to return.
by : str or list of str
Metadata column name(s) to group by.
func : function
Function to apply to each group.
kwargs : dict
Additional keyword arguments to pass to the function.
Returns
-------
pynapple object
Pynapple object corresponding to the group 'name'.
Raises
------
RuntimeError
If no grouping has been performed.
ValueError
If group name does not exist.
"""
if self._metadata_groups is None:
raise RuntimeError(
"No grouping has been performed. Please run groupby() first."
)
elif name not in self._metadata_groups.keys():
raise ValueError(
f"Group '{name}' not found in metadata groups. Groups are {list(self._metadata_groups.keys())}"
)

groups = self.groupby(by)

out = {}
if grouped_arg is None:
for group, idx in groups.items():
out[group] = func(self[np.array(idx)], **func_kwargs)
else:
idx = self._metadata_groups[name]
return self[np.array(idx)]
for group, idx in groups.items():
func_kwargs[grouped_arg] = self[np.array(idx)]
out[group] = func(**func_kwargs)

@property
def metadata_groups(self):
"""
Dictionary of metadata groups. Keys are group names and values are the indices of the group. Is None if no grouping has been performed.
"""
return self._metadata_groups
return out
49 changes: 31 additions & 18 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -1069,21 +1069,20 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len):
pytest.skip("groupby not relevant for length 1 objects")

obj.set_info(metadata)
# assert no groups
assert obj._metadata_groups == None
assert obj._metadata_groups == obj.metadata_groups

# pandas groups
groups = obj._metadata.groupby(group)
pd_groups = obj._metadata.groupby(group)

# group by metadata, assert saved groups
obj.groupby(group)
assert obj._metadata_groups.keys() == groups.groups.keys()
for grp, idx in obj._metadata_groups.items():
# group by metadata, assert returned groups
nap_groups = obj.groupby(group)
assert nap_groups.keys() == pd_groups.groups.keys()

for grp, idx in nap_groups.items():
# index same as pandas
assert all(idx == groups.groups[grp])
assert all(idx == pd_groups.groups[grp])

obj_grp = obj.get_group(grp)
# return object with get_group argument
obj_grp = obj.groupby(group, get_group=grp)

# get_group should be the same as indexed object
pd.testing.assert_frame_equal(
Expand All @@ -1095,15 +1094,29 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len):
# columns should be the same
assert all(obj_grp.columns == obj[np.array(idx)].columns)

# get_group should be the same as groupby with specified group
pd.testing.assert_frame_equal(
obj_grp._metadata, obj.groupby(group, grp)._metadata
@pytest.mark.parametrize(
"metadata, group",
[
({"label": [1, 1, 2, 2]}, "label"),
({"l1": [1, 1, 2, 2], "l2": ["a", "b", "b", "b"]}, ["l1", "l2"]),
],
)
@pytest.mark.parametrize("func", [np.mean, np.sum, np.max, np.min])
def test_metadata_groupby_apply_numpy(self, obj, metadata, group, func, obj_len):
if obj_len <= 1:
pytest.skip("groupby not relevant for length 1 objects")

obj.set_info(metadata)
groups = obj.groupby(group)

# apply numpy function through groupby_apply
grouped_out = obj.groupby_apply(group, func)

for grp, idx in groups.items():
# check that the output is the same as applying the function to the indexed object
np.testing.assert_array_almost_equal(
func(obj[np.array(idx)]), grouped_out[grp]
)
# index should be the same for both objects
assert all(obj_grp.index == obj.groupby(group, grp).index)
if isinstance(obj, nap.TsdFrame):
# columns should be the same for tsdframe
assert all(obj_grp.columns == obj.groupby(group, grp).columns)


# test double inheritance
Expand Down

0 comments on commit c794f6d

Please sign in to comment.