diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 7eda994e..21ff6a12 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -38,7 +38,6 @@ def __init__(self): self.metadata_index = self.index self._metadata = pd.DataFrame(index=self.metadata_index) - self._metadata_groups = None def __dir__(self): """ @@ -358,48 +357,42 @@ def groupby(self, by=None, get_group=None): raise ValueError( f"Metadata column '{b}' not found. Metadata columns are {self.metadata_columns}" ) - self.__dict__["_metadata_groups"] = self._metadata.groupby(by).groups + groups = self._metadata.groupby(by).groups if get_group is not None: - return self.get_group(get_group) + if get_group not in groups.keys(): + raise ValueError( + f"Group '{get_group}' not found in metadata groups. Groups are {list(groups.keys())}" + ) + idx = groups[get_group] + return self[np.array(idx)] else: - return self + return groups - def get_group(self, name): + def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): """ - Get group from metadata groups. + Apply a function to each group in a grouped pynapple object. Parameters ---------- - name : str or tuple of str - Name of the group to return. + by : str or list of str + Metadata column name(s) to group by. + func : function + Function to apply to each group. + kwargs : dict + Additional keyword arguments to pass to the function. - Returns - ------- - pynapple object - Pynapple object corresponding to the group 'name'. - Raises - ------ - RuntimeError - If no grouping has been performed. - ValueError - If group name does not exist. """ - if self._metadata_groups is None: - raise RuntimeError( - "No grouping has been performed. Please run groupby() first." - ) - elif name not in self._metadata_groups.keys(): - raise ValueError( - f"Group '{name}' not found in metadata groups. Groups are {list(self._metadata_groups.keys())}" - ) + + groups = self.groupby(by) + + out = {} + if grouped_arg is None: + for group, idx in groups.items(): + out[group] = func(self[np.array(idx)], **func_kwargs) else: - idx = self._metadata_groups[name] - return self[np.array(idx)] + for group, idx in groups.items(): + func_kwargs[grouped_arg] = self[np.array(idx)] + out[group] = func(**func_kwargs) - @property - def metadata_groups(self): - """ - Dictionary of metadata groups. Keys are group names and values are the indices of the group. Is None if no grouping has been performed. - """ - return self._metadata_groups + return out diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 02784dd1..edc9ce34 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1069,21 +1069,20 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len): pytest.skip("groupby not relevant for length 1 objects") obj.set_info(metadata) - # assert no groups - assert obj._metadata_groups == None - assert obj._metadata_groups == obj.metadata_groups # pandas groups - groups = obj._metadata.groupby(group) + pd_groups = obj._metadata.groupby(group) - # group by metadata, assert saved groups - obj.groupby(group) - assert obj._metadata_groups.keys() == groups.groups.keys() - for grp, idx in obj._metadata_groups.items(): + # group by metadata, assert returned groups + nap_groups = obj.groupby(group) + assert nap_groups.keys() == pd_groups.groups.keys() + + for grp, idx in nap_groups.items(): # index same as pandas - assert all(idx == groups.groups[grp]) + assert all(idx == pd_groups.groups[grp]) - obj_grp = obj.get_group(grp) + # return object with get_group argument + obj_grp = obj.groupby(group, get_group=grp) # get_group should be the same as indexed object pd.testing.assert_frame_equal( @@ -1095,15 +1094,29 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len): # columns should be the same assert all(obj_grp.columns == obj[np.array(idx)].columns) - # get_group should be the same as groupby with specified group - pd.testing.assert_frame_equal( - obj_grp._metadata, obj.groupby(group, grp)._metadata + @pytest.mark.parametrize( + "metadata, group", + [ + ({"label": [1, 1, 2, 2]}, "label"), + ({"l1": [1, 1, 2, 2], "l2": ["a", "b", "b", "b"]}, ["l1", "l2"]), + ], + ) + @pytest.mark.parametrize("func", [np.mean, np.sum, np.max, np.min]) + def test_metadata_groupby_apply_numpy(self, obj, metadata, group, func, obj_len): + if obj_len <= 1: + pytest.skip("groupby not relevant for length 1 objects") + + obj.set_info(metadata) + groups = obj.groupby(group) + + # apply numpy function through groupby_apply + grouped_out = obj.groupby_apply(group, func) + + for grp, idx in groups.items(): + # check that the output is the same as applying the function to the indexed object + np.testing.assert_array_almost_equal( + func(obj[np.array(idx)]), grouped_out[grp] ) - # index should be the same for both objects - assert all(obj_grp.index == obj.groupby(group, grp).index) - if isinstance(obj, nap.TsdFrame): - # columns should be the same for tsdframe - assert all(obj_grp.columns == obj.groupby(group, grp).columns) # test double inheritance