Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add groupby and groupby_apply functionality to metadata #383

Draft
wants to merge 8 commits into
base: dev
Choose a base branch
from
93 changes: 92 additions & 1 deletion pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def __getitem__(self, key):
# self[Number]
output = self.values.__getitem__(key)
return IntervalSet(start=output[0], end=output[1], metadata=metadata)
elif isinstance(key, (slice, list, np.ndarray, pd.Series)):
elif isinstance(key, (slice, list, np.ndarray, pd.Series, pd.Index)):
# self[array_like]
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
Expand Down Expand Up @@ -1186,3 +1186,94 @@ def get_info(self, key):
2 3 y
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> times = np.array([[0, 5], [10, 12], [20, 33]])
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> ep = nap.IntervalSet(times, metadata=metadata)

    Grouping by a single column:

    >>> ep.groupby("l2")
    {'x': [0, 1], 'y': [2]}

    Grouping by multiple columns:

    >>> ep.groupby(["l1","l2"])
    {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]}

    Filtering to a specific group using the output dictionary:

    >>> groups = ep.groupby("l2")
    >>> ep[groups["x"]]
    index start end l1 l2
    0 0 5 1 x
    1 10 12 2 x
    shape: (2, 2), time unit: sec.

    Filtering to a specific group using the get_group argument:

    >>> ep.groupby("l2", get_group="x")
    index start end l1 l2
    0 0 5 1 x
    1 10 12 2 x
    shape: (2, 2), time unit: sec.
    """
    # Pure delegation: the grouping logic lives in the metadata mixin.
    return _MetadataMixin.groupby(self, by, get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> times = np.array([[0, 5], [10, 12], [20, 33]])
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> ep = nap.IntervalSet(times, metadata=metadata)

    Apply a numpy function:

    >>> ep.groupby_apply("l2", np.mean)
    {'x': 6.75, 'y': 26.5}

    Apply a custom function:

    >>> ep.groupby_apply("l2", lambda x: x.shape[0])
    {'x': 2, 'y': 1}

    Apply a function with additional arguments:

    >>> ep.groupby_apply("l2", np.mean, axis=1)
    {'x': array([ 2.5, 11. ]), 'y': array([26.5])}

    Applying a function with additional arguments, where the grouped object is not the first argument:

    >>> tsg = nap.TsGroup(
    ...     {
    ...         1: nap.Ts(t=np.arange(0, 40)),
    ...         2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"),
    ...         3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"),
    ...     },
    ... )
    >>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)]))
    >>> func_kwargs = {
    ...     "group": tsg,
    ...     "feature": feature,
    ...     "nb_bins": 2,
    ... }
    >>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, grouped_arg="ep", **func_kwargs)
    {'x': 1 2 3
    0.25 1.025641 1.823362 4.216524
    0.75 NaN NaN NaN,
    'y': 1 2 3
    0.25 NaN NaN NaN
    0.75 1.025641 1.978022 4.835165}
    """
    # Pure delegation: the per-group application lives in the metadata mixin.
    return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs)
79 changes: 77 additions & 2 deletions pynapple/core/metadata_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def get_info(self, key):
# metadata[str] or metadata[[*str]]
return self._metadata[key]

elif isinstance(key, (Number, list, np.ndarray, pd.Series)) or (
elif isinstance(key, (Number, list, np.ndarray, pd.Series, pd.Index)) or (
isinstance(key, tuple)
and (
isinstance(key[1], str)
Expand All @@ -313,7 +313,7 @@ def get_info(self, key):
)
)
):
# assume key is index, or tupe of index and column name
# assume key is index, or tuple of index and column name
# metadata[Number], metadata[array_like], metadata[Any, str], or metadata[Any, [*str]]
return self._metadata.loc[key]

Expand All @@ -325,3 +325,78 @@ def get_info(self, key):
else:
# we don't allow indexing columns with numbers, e.g. metadata[0,0]
raise IndexError(f"Unknown metadata index {key}")

def groupby(self, by, get_group=None):
    """
    Group pynapple object by metadata column(s).

    Parameters
    ----------
    by : str or list of str
        Name, or list of names, of the metadata column(s) used to form the groups.
    get_group : dictionary key, optional
        Key of a single group. When given, the object restricted to that group
        is returned instead of the full mapping.

    Returns
    -------
    dict or pynapple object
        Mapping from each group key to the indices belonging to that group, or,
        if 'get_group' is supplied, the result of indexing the object with the
        indices of that group.

    Raises
    ------
    ValueError
        If a requested metadata column does not exist, or if 'get_group' is not
        one of the group keys.
    """
    # Only str and list inputs are checked up front; anything else is handed
    # to pandas unchanged.
    if isinstance(by, str):
        requested = [by]
    elif isinstance(by, list):
        requested = by
    else:
        requested = []

    for column in requested:
        if column not in self.metadata_columns:
            raise ValueError(
                f"Metadata column '{column}' not found. Metadata columns are {self.metadata_columns}"
            )

    groups = self._metadata.groupby(by).groups

    if get_group is None:
        return groups

    if get_group not in groups:
        raise ValueError(
            f"Group '{get_group}' not found in metadata groups. Groups are {list(groups.keys())}"
        )

    # Index the object itself with the selected group's indices.
    members = np.array(groups[get_group])
    return self[members]

def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
    """
    Apply a function to each group in a grouped pynapple object.

    Parameters
    ----------
    by : str or list of str
        Name, or list of names, of the metadata column(s) used to form the groups.
    func : function
        Callable evaluated once per group.
    grouped_arg : str, optional
        Keyword name under which the grouped object is handed to 'func'. If
        None, the grouped object is passed as the first positional argument.
    func_kwargs : dict
        Extra keyword arguments forwarded to 'func' on every call.

    Returns
    -------
    dict
        Mapping from each group key to the value returned by 'func' for that group.
    """
    results = {}
    for name, members in self.groupby(by).items():
        subset = self[members]
        if grouped_arg is None:
            results[name] = func(subset, **func_kwargs)
        else:
            # func_kwargs is unpacked last, so it wins on a key collision.
            call_kwargs = {grouped_arg: subset, **func_kwargs}
            results[name] = func(**call_kwargs)
    return results
91 changes: 78 additions & 13 deletions pynapple/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1429,19 +1429,6 @@ def get_info(self, key):
>>> import numpy as np
>>> metadata = {"l1": [1, 2, 3], "l2": ["x", "x", "y"]}
>>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata)
>>> print(tsdframe)
Time (s) 0 1 2
---------- -------- -------- --------
0.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0
2.0 1.0 1.0 1.0
3.0 1.0 1.0 1.0
4.0 1.0 1.0 1.0
Metadata
-------- -------- -------- --------
l1 1 2 3
l2 x x y
dtype: float64, shape: (5, 3)

To access a single metadata column:

Expand Down Expand Up @@ -1505,6 +1492,84 @@ def get_info(self, key):
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata)

    Group on one metadata column:

    >>> tsdframe.groupby("l2")
    {'x': [0, 1], 'y': [2]}

    Group on several metadata columns at once:

    >>> tsdframe.groupby(["l1","l2"])
    {(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]}

    Select one group by indexing with an entry of the returned dictionary:

    >>> groups = tsdframe.groupby("l2")
    >>> tsdframe[groups["x"]]
    Time (s) 0 1 2
    ---------- -------- -------- --------
    0.0 1.0 1.0 1.0
    1.0 1.0 1.0 1.0
    Metadata
    -------- -------- -------- --------
    l1 1 2 2
    l2 x x y
    <BLANKLINE>
    dtype: float64, shape: (2, 3)

    Select one group directly through the get_group argument:

    >>> tsdframe.groupby("l2", get_group="x")
    Time (s) 0 1 2
    ---------- -------- -------- --------
    0.0 1.0 1.0 1.0
    1.0 1.0 1.0 1.0
    Metadata
    -------- -------- -------- --------
    l1 1 2 2
    l2 x x y
    <BLANKLINE>
    dtype: float64, shape: (2, 3)
    """
    # Hand the request to the shared metadata mixin implementation.
    return _MetadataMixin.groupby(self, by, get_group=get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> tsdframe = nap.TsdFrame(t=np.arange(5), d=np.ones((5, 3)), metadata=metadata)

    Using a numpy function:

    >>> tsdframe.groupby_apply("l1", np.sum)
    {1: 3.0, 2: 6.0}

    Using a user-defined function:

    >>> tsdframe.groupby_apply("l1", lambda x: x.shape[0])
    {1: 1, 2: 2}

    Forwarding extra keyword arguments to the function:

    >>> tsdframe.groupby_apply("l1", np.sum, axis=0)
    {1: array([1., 1., 1.]), 2: array([2., 2., 2.])}
    """
    # Hand the request to the shared metadata mixin implementation.
    return _MetadataMixin.groupby_apply(
        self, by, func, grouped_arg=grouped_arg, **func_kwargs
    )


class Tsd(_BaseTsd):
"""
Expand Down
83 changes: 83 additions & 0 deletions pynapple/core/ts_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,3 +1583,86 @@ def get_info(self, key):
2 3 y
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> tmp = {1:nap.Ts(t=np.arange(0,40), time_units='s'),
    ...        2:nap.Ts(t=np.arange(0,40,0.5), time_units='s'),
    ...        3:nap.Ts(t=np.arange(0,40,0.2), time_units='s'),
    ...        }
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> tsgroup = nap.TsGroup(tmp,metadata=metadata)

    Grouping by a single column:

    >>> tsgroup.groupby("l2")
    {'x': [1, 2], 'y': [3]}

    Grouping by multiple columns:

    >>> tsgroup.groupby(["l1","l2"])
    {(1, 'x'): [1], (2, 'x'): [2], (2, 'y'): [3]}

    Filtering to a specific group using the output dictionary:

    >>> groups = tsgroup.groupby("l2")
    >>> tsgroup[groups["x"]]
    Index rate l1 l2
    ------- ------- ---- ----
    1 1.00503 1 x
    2 2.01005 2 x

    Filtering to a specific group using the get_group argument:

    >>> tsgroup.groupby("l2", get_group="x")
    Index rate l1 l2
    ------- ------- ---- ----
    1 1.00503 1 x
    2 2.01005 2 x
    """
    # Pure delegation: the grouping logic lives in the metadata mixin.
    return _MetadataMixin.groupby(self, by, get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, grouped_arg=None, **func_kwargs):
    """
    Examples
    --------
    >>> import pynapple as nap
    >>> import numpy as np
    >>> tmp = {1:nap.Ts(t=np.arange(0,40), time_units='s'),
    ...        2:nap.Ts(t=np.arange(0,40,0.5), time_units='s'),
    ...        3:nap.Ts(t=np.arange(0,40,0.2), time_units='s'),
    ...        }
    >>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
    >>> tsgroup = nap.TsGroup(tmp,metadata=metadata)

    Apply a numpy function:

    >>> tsgroup.groupby_apply("l2", np.mean)
    {'x': 1.5, 'y': 3.0}

    Apply a custom function:

    >>> tsgroup.groupby_apply("l2", lambda x: x.to_tsd().shape[0])
    {'x': 120, 'y': 200}

    Apply a function with additional arguments:

    >>> feature = nap.Tsd(
    ...     t=np.arange(40),
    ...     d=np.concatenate([np.zeros(20), np.ones(20)]),
    ...     time_support=nap.IntervalSet(np.array([[0, 5], [10, 12], [20, 33]])),
    ... )
    >>> tsgroup.groupby_apply("l2", nap.compute_1d_tuning_curves, feature=feature, nb_bins=2)
    {'x': 1 2
    0.25 1.15 2.044444
    0.75 1.15 2.217857,
    'y': 3
    0.25 4.727778
    0.75 5.421429}
    """
    # Pure delegation: the per-group application lives in the metadata mixin.
    return _MetadataMixin.groupby_apply(self, by, func, grouped_arg, **func_kwargs)
Loading
Loading