From cb578c28ac410a1ed7cf5eaf407357602cfac8c7 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 2 Dec 2024 14:59:19 -0500 Subject: [PATCH 1/7] add groupby and get_group, saving groups as a dictionary --- pynapple/core/metadata_class.py | 78 +++++++++++++++++++++++++++++++++ tests/test_metadata.py | 53 ++++++++++++++++++++++ 2 files changed, 131 insertions(+) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 8e494b2e..7eda994e 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -38,6 +38,7 @@ def __init__(self): self.metadata_index = self.index self._metadata = pd.DataFrame(index=self.metadata_index) + self._metadata_groups = None def __dir__(self): """ @@ -325,3 +326,80 @@ def get_info(self, key): else: # we don't allow indexing columns with numbers, e.g. metadata[0,0] raise IndexError(f"Unknown metadata index {key}") + + def groupby(self, by=None, get_group=None): + """ + Group pynapple object by metadata column(s). + + Parameters + ---------- + by : str or list of str + Metadata column name(s) to group by. + get_group : str, optional + Name of the group to return. + + Returns + ------- + pynapple object + Original pynapple object with groups set, or pynapple object corresponding to 'get_group' if it has been supplied. + + Raises + ------ + ValueError + If metadata column does not exist. + """ + if isinstance(by, str) and by not in self.metadata_columns: + raise ValueError( + f"Metadata column '{by}' not found. Metadata columns are {self.metadata_columns}" + ) + elif isinstance(by, list): + for b in by: + if b not in self.metadata_columns: + raise ValueError( + f"Metadata column '{b}' not found. Metadata columns are {self.metadata_columns}" + ) + self.__dict__["_metadata_groups"] = self._metadata.groupby(by).groups + if get_group is not None: + return self.get_group(get_group) + else: + return self + + def get_group(self, name): + """ + Get group from metadata groups. + + Parameters + ---------- + name : str or tuple of str + Name of the group to return. + + Returns + ------- + pynapple object + Pynapple object corresponding to the group 'name'. + + Raises + ------ + RuntimeError + If no grouping has been performed. + ValueError + If group name does not exist. + """ + if self._metadata_groups is None: + raise RuntimeError( + "No grouping has been performed. Please run groupby() first." + ) + elif name not in self._metadata_groups.keys(): + raise ValueError( + f"Group '{name}' not found in metadata groups. Groups are {list(self._metadata_groups.keys())}" + ) + else: + idx = self._metadata_groups[name] + return self[np.array(idx)] + + @property + def metadata_groups(self): + """ + Dictionary of metadata groups. Keys are group names and values are the indices of the group. Is None if no grouping has been performed. + """ + return self._metadata_groups diff --git a/tests/test_metadata.py b/tests/test_metadata.py index e4403001..02784dd1 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -643,10 +643,13 @@ def test_tsgroup_metadata_future_warnings(): @pytest.fixture def clear_metadata(obj): if isinstance(obj, nap.TsGroup): + # clear metadata columns columns = [col for col in obj.metadata_columns if col != "rate"] else: columns = obj.metadata_columns obj._metadata.drop(columns=columns, inplace=True) + # clear metadata groups + obj.__dict__["_metadata_groups"] = None return obj @@ -1052,6 +1055,56 @@ def test_save_and_load_npz(self, obj, obj_len): # cleaning Path("obj.npz").unlink() + # class Test_Metadata_Group: + + @pytest.mark.parametrize( + "metadata, group", + [ + ({"label": [1, 1, 2, 2]}, "label"), + ({"l1": [1, 1, 2, 2], "l2": ["a", "b", "b", "b"]}, ["l1", "l2"]), + ], + ) + def test_metadata_groupby(self, obj, metadata, group, obj_len): + if obj_len <= 1: + pytest.skip("groupby not relevant for length 1 objects") + + obj.set_info(metadata) + # assert no groups + assert obj._metadata_groups == None + assert obj._metadata_groups == obj.metadata_groups + + # pandas groups + groups = obj._metadata.groupby(group) + + # group by metadata, assert saved groups + obj.groupby(group) + assert obj._metadata_groups.keys() == groups.groups.keys() + for grp, idx in obj._metadata_groups.items(): + # index same as pandas + assert all(idx == groups.groups[grp]) + + obj_grp = obj.get_group(grp) + + # get_group should be the same as indexed object + pd.testing.assert_frame_equal( + obj_grp._metadata, obj[np.array(idx)]._metadata + ) + # index should be the same for both objects + assert all(obj_grp.index == obj[np.array(idx)].index) + if isinstance(obj, nap.TsdFrame): + # columns should be the same + assert all(obj_grp.columns == obj[np.array(idx)].columns) + + # get_group should be the same as groupby with specified group + pd.testing.assert_frame_equal( + obj_grp._metadata, obj.groupby(group, grp)._metadata + ) + # index should be the same for both objects + assert all(obj_grp.index == obj.groupby(group, grp).index) + if isinstance(obj, nap.TsdFrame): + # columns should be the same for tsdframe + assert all(obj_grp.columns == obj.groupby(group, grp).columns) + # test double inheritance def get_defined_members(cls): From c794f6d7dd5404dc380bc4513c2a5672e4788990 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 3 Jan 2025 14:54:06 -0500 Subject: [PATCH 2/7] groupby and groupby_apply functions with preliminary tests --- pynapple/core/metadata_class.py | 61 +++++++++++++++------------------ tests/test_metadata.py | 49 ++++++++++++++++---------- 2 files changed, 58 insertions(+), 52 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 7eda994e..21ff6a12 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -38,7 +38,6 @@ def __init__(self): self.metadata_index = self.index self._metadata = pd.DataFrame(index=self.metadata_index) - self._metadata_groups = None def __dir__(self): """ @@ -358,48 +357,42 @@ def groupby(self, by=None, get_group=None): raise ValueError( f"Metadata column '{b}' not found. Metadata columns are {self.metadata_columns}" ) - self.__dict__["_metadata_groups"] = self._metadata.groupby(by).groups + groups = self._metadata.groupby(by).groups if get_group is not None: - return self.get_group(get_group) + if get_group not in groups.keys(): + raise ValueError( + f"Group '{get_group}' not found in metadata groups. Groups are {list(groups.keys())}" + ) + idx = groups[get_group] + return self[np.array(idx)] else: - return self + return groups - def get_group(self, name): + def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): """ - Get group from metadata groups. + Apply a function to each group in a grouped pynapple object. Parameters ---------- - name : str or tuple of str - Name of the group to return. + by : str or list of str + Metadata column name(s) to group by. + func : function + Function to apply to each group. + kwargs : dict + Additional keyword arguments to pass to the function. - Returns - ------- - pynapple object - Pynapple object corresponding to the group 'name'. - Raises - ------ - RuntimeError - If no grouping has been performed. - ValueError - If group name does not exist. """ - if self._metadata_groups is None: - raise RuntimeError( - "No grouping has been performed. Please run groupby() first." - ) - elif name not in self._metadata_groups.keys(): - raise ValueError( - f"Group '{name}' not found in metadata groups. Groups are {list(self._metadata_groups.keys())}" - ) + + groups = self.groupby(by) + + out = {} + if grouped_arg is None: + for group, idx in groups.items(): + out[group] = func(self[np.array(idx)], **func_kwargs) else: - idx = self._metadata_groups[name] - return self[np.array(idx)] + for group, idx in groups.items(): + func_kwargs[grouped_arg] = self[np.array(idx)] + out[group] = func(**func_kwargs) - @property - def metadata_groups(self): - """ - Dictionary of metadata groups. Keys are group names and values are the indices of the group. Is None if no grouping has been performed. - """ - return self._metadata_groups + return out diff --git a/tests/test_metadata.py b/tests/test_metadata.py index 02784dd1..edc9ce34 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1069,21 +1069,20 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len): pytest.skip("groupby not relevant for length 1 objects") obj.set_info(metadata) - # assert no groups - assert obj._metadata_groups == None - assert obj._metadata_groups == obj.metadata_groups # pandas groups - groups = obj._metadata.groupby(group) + pd_groups = obj._metadata.groupby(group) - # group by metadata, assert saved groups - obj.groupby(group) - assert obj._metadata_groups.keys() == groups.groups.keys() - for grp, idx in obj._metadata_groups.items(): + # group by metadata, assert returned groups + nap_groups = obj.groupby(group) + assert nap_groups.keys() == pd_groups.groups.keys() + + for grp, idx in nap_groups.items(): # index same as pandas - assert all(idx == groups.groups[grp]) + assert all(idx == pd_groups.groups[grp]) - obj_grp = obj.get_group(grp) + # return object with get_group argument + obj_grp = obj.groupby(group, get_group=grp) # get_group should be the same as indexed object pd.testing.assert_frame_equal( @@ -1095,15 +1094,29 @@ def test_metadata_groupby(self, obj, metadata, group, obj_len): # columns should be the same assert all(obj_grp.columns == obj[np.array(idx)].columns) - # get_group should be the same as groupby with specified group - pd.testing.assert_frame_equal( - obj_grp._metadata, obj.groupby(group, grp)._metadata + @pytest.mark.parametrize( + "metadata, group", + [ + ({"label": [1, 1, 2, 2]}, "label"), + ({"l1": [1, 1, 2, 2], "l2": ["a", "b", "b", "b"]}, ["l1", "l2"]), + ], + ) + @pytest.mark.parametrize("func", [np.mean, np.sum, np.max, np.min]) + def test_metadata_groupby_apply_numpy(self, obj, metadata, group, func, obj_len): + if obj_len <= 1: + pytest.skip("groupby not relevant for length 1 objects") + + obj.set_info(metadata) + groups = obj.groupby(group) + + # apply numpy function through groupby_apply + grouped_out = obj.groupby_apply(group, func) + + for grp, idx in groups.items(): + # check that the output is the same as applying the function to the indexed object + np.testing.assert_array_almost_equal( + func(obj[np.array(idx)]), grouped_out[grp] ) - # index should be the same for both objects - assert all(obj_grp.index == obj.groupby(group, grp).index) - if isinstance(obj, nap.TsdFrame): - # columns should be the same for tsdframe - assert all(obj_grp.columns == obj.groupby(group, grp).columns) # test double inheritance From 87a150517892630d67ecd0cace9b4633a87c194f Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 3 Jan 2025 15:32:13 -0500 Subject: [PATCH 3/7] update docstrings --- pynapple/core/metadata_class.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 21ff6a12..ecf6d6b4 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -326,7 +326,7 @@ def get_info(self, key): # we don't allow indexing columns with numbers, e.g. metadata[0,0] raise IndexError(f"Unknown metadata index {key}") - def groupby(self, by=None, get_group=None): + def groupby(self, by, get_group=None): """ Group pynapple object by metadata column(s). @@ -334,13 +334,13 @@ def groupby(self, by=None, get_group=None): ---------- by : str or list of str Metadata column name(s) to group by. - get_group : str, optional + get_group : dictionary key, optional Name of the group to return. Returns ------- - pynapple object - Original pynapple object with groups set, or pynapple object corresponding to 'get_group' if it has been supplied. + dict or pynapple object + Dictionary of object indices (dictionary values) corresponding to each group (dictionary keys), or pynapple object corresponding to 'get_group' if it has been supplied. Raises ------ @@ -378,10 +378,15 @@ def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): Metadata column name(s) to group by. func : function Function to apply to each group. - kwargs : dict + grouped_arg : str, optional + Name of the function argument that the grouped object should be passed as. If none, the grouped object is passed as the first positional argument. + func_kwargs : dict Additional keyword arguments to pass to the function. - + Returns + ------- + dict + Dictionary of results from applying the function to each group, where the keys are the group names and the values are the results. """ groups = self.groupby(by) From 5bac58d389bb953d402950ba80b30251cec9314b Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 3 Jan 2025 16:10:20 -0500 Subject: [PATCH 4/7] grouping examples for intervalset, also let object index be of type pd.Index --- pynapple/core/interval_set.py | 93 ++++++++++++++++++++++++++++++++- pynapple/core/metadata_class.py | 4 +- 2 files changed, 94 insertions(+), 3 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 44103d80..8778364b 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -427,7 +427,7 @@ def __getitem__(self, key): # self[Number] output = self.values.__getitem__(key) return IntervalSet(start=output[0], end=output[1], metadata=metadata) - elif isinstance(key, (slice, list, np.ndarray, pd.Series)): + elif isinstance(key, (slice, list, np.ndarray, pd.Series, pd.Index)): # self[array_like] output = self.values.__getitem__(key) metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True) @@ -1186,3 +1186,94 @@ def get_info(self, key): 2 3 y """ return _MetadataMixin.get_info(self, key) + + @add_meta_docstring("groupby") + def groupby(self, by, get_group=None): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> times = np.array([[0, 5], [10, 12], [20, 33]]) + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> ep = nap.IntervalSet(tmp,metadata=metadata) + + Grouping by a single column: + + >>> ep.groupby("l2") + {'x': [0, 1], 'y': [2]} + + Grouping by multiple columns: + + >>> ep.groupby(["l1","l2"]) + {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]} + + Filtering to a specific group using the output dictionary: + + >>> groups = ep.groupby("l2") + >>> ep[groups["x"]] + index start end l1 l2 + 0 0 5 1 x + 1 10 12 2 x + shape: (2, 2), time unit: sec. + + Filtering to a specific group using the get_group argument: + + >>> ep.groupby("l2", get_group="x") + index start end l1 l2 + 0 0 5 1 x + 1 10 12 2 x + shape: (2, 2), time unit: sec. + """ + return _MetadataMixin.groupby(self, by, get_group) + + @add_meta_docstring("groupby_apply") + def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> times = np.array([[0, 5], [10, 12], [20, 33]]) + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> ep = nap.IntervalSet(tmp,metadata=metadata) + + Apply a numpy function:: + + >>> ep.groupby_apply("l2", np.mean) + {'x': 6.75, 'y': 26.5} + + Apply a custom function: + + >>> ep.groupby_apply("l2", lambda x: x.shape[0]) + {'x': 2, 'y': 1} + + Apply a function with additional arguments: + + >>> ep.groupby_apply("l2", np.mean, axis=1) + {'x': array([ 2.5, 11. ]), 'y': array([26.5])} + + Applying a function with additional arguments, where the grouped object is not the first argument: + + >>> tsg = nap.TsGroup( + >>> { + >>> 1: nap.Ts(t=np.arange(0, 40)), + >>> 2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"), + >>> 3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"), + >>> }, + >>> ) + >>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)])) + >>> func_kwargs = { + >>> "group": tsg, + >>> "feature": feature, + >>> "nb_bins": 2, + >>> } + >>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, grouped_arg="ep", **func_kwargs) + {'x': 1 2 3 + 0.25 1.025641 1.823362 4.216524 + 0.75 NaN NaN NaN, + 'y': 1 2 3 + 0.25 NaN NaN NaN + 0.75 1.025641 1.978022 4.835165} + """ + return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index ecf6d6b4..49f5407a 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -303,7 +303,7 @@ def get_info(self, key): # metadata[str] or metadata[[*str]] return self._metadata[key] - elif isinstance(key, (Number, list, np.ndarray, pd.Series)) or ( + elif isinstance(key, (Number, list, np.ndarray, pd.Series, pd.Index)) or ( isinstance(key, tuple) and ( isinstance(key[1], str) @@ -313,7 +313,7 @@ def get_info(self, key): ) ) ): - # assume key is index, or tupe of index and column name + # assume key is index, or tuple of index and column name # metadata[Number], metadata[array_like], metadata[Any, str], or metadata[Any, [*str]] return self._metadata.loc[key] From 070b771a4d84e9eca2e81e9965bfba1f304ed483 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 3 Jan 2025 16:22:47 -0500 Subject: [PATCH 5/7] some tsdframe grouping examples --- pynapple/core/time_series.py | 91 ++++++++++++++++++++++++++++++------ 1 file changed, 78 insertions(+), 13 deletions(-) diff --git a/pynapple/core/time_series.py b/pynapple/core/time_series.py index 72d33d5a..4ef7fc17 100644 --- a/pynapple/core/time_series.py +++ b/pynapple/core/time_series.py @@ -1429,19 +1429,6 @@ def get_info(self, key): >>> import numpy as np >>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"]} >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata) - >>> print(tsdframe) - Time (s) 0 1 2 - ---------- -------- -------- -------- - 0.0 1.0 1.0 1.0 - 1.0 1.0 1.0 1.0 - 2.0 1.0 1.0 1.0 - 3.0 1.0 1.0 1.0 - 4.0 1.0 1.0 1.0 - Metadata - -------- -------- -------- -------- - l1 1 2 3 - l2 x x y - dtype: float64, shape: (5, 3) To access a single metadata column: @@ -1505,6 +1492,84 @@ def get_info(self, key): """ return _MetadataMixin.get_info(self, key) + @add_meta_docstring("groupby") + def groupby(self, by, get_group=None): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata) + + Grouping by a single column: + + >>> tsdframe.groupby("l2") + {'x': [0, 1], 'y': [2]} + + Grouping by multiple columns: + + >>> tsdframe.groupby(["l1","l2"]) + {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]} + + Filtering to a specific group using the output dictionary: + + >>> groups = tsdframe.groupby("l2") + >>> tsdframe[groups["x"]] + Time (s) 0 1 2 + ---------- -------- -------- -------- + 0.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 + Metadata + -------- -------- -------- -------- + l1 1 2 2 + l2 x x y + + dtype: float64, shape: (2, 3) + + Filtering to a specific group using the get_group argument: + + >>> tsdframe.groupby("l2", get_group="x") + Time (s) 0 1 2 + ---------- -------- -------- -------- + 0.0 1.0 1.0 1.0 + 1.0 1.0 1.0 1.0 + Metadata + -------- -------- -------- -------- + l1 1 2 2 + l2 x x y + + dtype: float64, shape: (2, 3) + """ + return _MetadataMixin.groupby(self, by, get_group) + + @add_meta_docstring("groupby_apply") + def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata) + + Apply a numpy function: + + >>> tsdframe.groupby_apply("l1", np.sum) + {1: 3.0, 2: 6.0} + + Apply a custom function: + + >>> tsdframe.groupby_apply("l1", lambda x: x.shape[0]) + {1: 1, 2: 2} + + Apply a function with additional arguments: + + >>> tsdframe.groupby_apply("l1", np.sum, axis=0) + {1: array([1., 1., 1.]), 2: array([2., 2., 2.])} + """ + return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs) + class Tsd(_BaseTsd): """ From 8ba5abffd586f83ee08369653139780807c5552a Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Fri, 3 Jan 2025 16:41:02 -0500 Subject: [PATCH 6/7] tsgroup examples --- pynapple/core/interval_set.py | 12 ++--- pynapple/core/ts_group.py | 83 +++++++++++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 6 deletions(-) diff --git a/pynapple/core/interval_set.py b/pynapple/core/interval_set.py index 8778364b..3e5c00b0 100644 --- a/pynapple/core/interval_set.py +++ b/pynapple/core/interval_set.py @@ -1256,12 +1256,12 @@ def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): Applying a function with additional arguments, where the grouped object is not the first argument: >>> tsg = nap.TsGroup( - >>> { - >>> 1: nap.Ts(t=np.arange(0, 40)), - >>> 2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"), - >>> 3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"), - >>> }, - >>> ) + ... { + ... 1: nap.Ts(t=np.arange(0, 40)), + ... 2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"), + ... 3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"), + ... }, + ... ) >>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)])) >>> func_kwargs = { >>> "group": tsg, diff --git a/pynapple/core/ts_group.py b/pynapple/core/ts_group.py index 487dcc56..5d004ccc 100644 --- a/pynapple/core/ts_group.py +++ b/pynapple/core/ts_group.py @@ -1583,3 +1583,86 @@ def get_info(self, key): 2 3 y """ return _MetadataMixin.get_info(self, key) + + @add_meta_docstring("groupby") + def groupby(self, by, get_group=None): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> tmp = {0:nap.Ts(t=np.arange(0,40), time_units='s'), + ... 1:nap.Ts(t=np.arange(0,40,0.5), time_units='s'), + ... 2:nap.Ts(t=np.arange(0,40,0.25), time_units='s'), + ... } + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> tsgroup = nap.TsGroup(tmp,metadata=metadata) + + Grouping by a single column: + + >>> tsgroup.groupby("l2") + {'x': [0, 1], 'y': [2]} + + Grouping by multiple columns: + + >>> tsgroup.groupby(["l1","l2"]) + {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]} + + Filtering to a specific group using the output dictionary: + + >>> groups = tsgroup.groupby("l2") + >>> tsgroup[groups["x"]] + Index rate l1 l2 + ------- ------- ---- ---- + 1 1.00503 1 x + 2 2.01005 2 x + + Filtering to a specific group using the get_group argument: + + >>> ep.groupby("l2", get_group="x") + Index rate l1 l2 + ------- ------- ---- ---- + 1 1.00503 1 x + 2 2.01005 2 x + """ + return _MetadataMixin.groupby(self, by, get_group) + + def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): + """ + Examples + -------- + >>> import pynapple as nap + >>> import numpy as np + >>> tmp = {0:nap.Ts(t=np.arange(0,40), time_units='s'), + ... 1:nap.Ts(t=np.arange(0,40,0.5), time_units='s'), + ... 2:nap.Ts(t=np.arange(0,40,0.25), time_units='s'), + ... } + >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]} + >>> tsgroup = nap.TsGroup(tmp,metadata=metadata) + + Apply a numpy function: + + >>> tsgroup.groupby_apply("l2", np.mean) + {'x': 1.5, 'y': 3.0} + + Apply a custom function: + + >>> tsgroup.groupby_apply("l2", lambda x: x.to_tsd().shape[0]) + {'x': 120, 'y': 200} + + Apply a function with additional arguments: + + >>> feature = nap.Tsd( + ... t=np.arange(40), + ... d=np.concatenate([np.zeros(20), np.ones(20)]), + ... time_support=nap.IntervalSet(np.array([[0, 5], [10, 12], [20, 33]])), + ... ) + >>> tsgroup.groupby_apply("l2", nap.compute_1d_tuning_curves, feature=feature, nb_bins=2) + {'x': 1 2 + 0.25 1.15 2.044444 + 0.75 1.15 2.217857, + 'y': 3 + 0.25 4.727778 + 0.75 5.421429} + """ + return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs) From 81b079d0acb077c38f9eefb49ec14526e7a7c3e0 Mon Sep 17 00:00:00 2001 From: sjvenditto Date: Mon, 6 Jan 2025 10:47:58 -0500 Subject: [PATCH 7/7] merge redundant tests, use dict comprehension --- pynapple/core/metadata_class.py | 11 ++++---- tests/test_metadata.py | 46 +++++---------------------------- 2 files changed, 12 insertions(+), 45 deletions(-) diff --git a/pynapple/core/metadata_class.py b/pynapple/core/metadata_class.py index 49f5407a..6f6ed2a9 100644 --- a/pynapple/core/metadata_class.py +++ b/pynapple/core/metadata_class.py @@ -391,13 +391,12 @@ def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs): groups = self.groupby(by) - out = {} if grouped_arg is None: - for group, idx in groups.items(): - out[group] = func(self[np.array(idx)], **func_kwargs) + out = {k: func(self[v], **func_kwargs) for k, v in groups.items()} else: - for group, idx in groups.items(): - func_kwargs[grouped_arg] = self[np.array(idx)] - out[group] = func(**func_kwargs) + out = { + k: func(**{grouped_arg: self[v], **func_kwargs}) + for k, v in groups.items() + } return out diff --git a/tests/test_metadata.py b/tests/test_metadata.py index edc9ce34..168bf0e1 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -1141,52 +1141,20 @@ def get_defined_members(cls): } -def test_no_conflict_between_intervalset_and_metadatamixin(): - from pynapple.core import IntervalSet +@pytest.mark.parametrize( + "nap_class", [nap.core.IntervalSet, nap.core.TsdFrame, nap.core.TsGroup] +) +def test_no_conflict_between_class_and_metadatamixin(nap_class): from pynapple.core.metadata_class import _MetadataMixin # Adjust import as needed - iset_members = get_defined_members(IntervalSet) + iset_members = get_defined_members(nap_class) metadatamixin_members = get_defined_members(_MetadataMixin) # Check for any overlapping names between IntervalSet and _MetadataMixin conflicting_members = iset_members.intersection(metadatamixin_members) - # set_info and get_info will conflict - assert len(conflicting_members) == 2, ( + # set_info, get_info, groupby, and groupby_apply are overwritten for class-specific examples in docstrings + assert len(conflicting_members) == 4, ( f"Conflict detected! The following methods/attributes are " f"overwritten in IntervalSet: {conflicting_members}" ) - - -def test_no_conflict_between_tsdframe_and_metadatamixin(): - from pynapple.core import TsdFrame - from pynapple.core.metadata_class import _MetadataMixin # Adjust import as needed - - tsdframe_members = get_defined_members(TsdFrame) - metadatamixin_members = get_defined_members(_MetadataMixin) - - # Check for any overlapping names between TsdFrame and _MetadataMixin - conflicting_members = tsdframe_members.intersection(metadatamixin_members) - - # set_info and get_info will conflict - assert len(conflicting_members) == 2, ( - f"Conflict detected! The following methods/attributes are " - f"overwritten in TsdFrame: {conflicting_members}" - ) - - -def test_no_conflict_between_tsgroup_and_metadatamixin(): - from pynapple.core import TsGroup - from pynapple.core.metadata_class import _MetadataMixin # Adjust import as needed - - tsgroup_members = get_defined_members(TsGroup) - metadatamixin_members = get_defined_members(_MetadataMixin) - - # Check for any overlapping names between TsdFrame and _MetadataMixin - conflicting_members = tsgroup_members.intersection(metadatamixin_members) - - # set_info and get_info will conflict - assert len(conflicting_members) == 2, ( - f"Conflict detected! The following methods/attributes are " - f"overwritten in TsGroup: {conflicting_members}" - )