Skip to content

Add groupby and groupby_apply functionality to metadata #383

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
Mar 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
cb578c2
add groupby and get_group, saving groups as a dictionary
sjvenditto Dec 2, 2024
c794f6d
groupby and groupby_apply functions with preliminary tests
sjvenditto Jan 3, 2025
7404ad4
Merge branch 'dev' into metadata
sjvenditto Jan 3, 2025
87a1505
update docstrings
sjvenditto Jan 3, 2025
5bac58d
grouping examples for intervalset, also let object index be of type p…
sjvenditto Jan 3, 2025
070b771
some tsdframe grouping examples
sjvenditto Jan 3, 2025
8ba5abf
tsgroup examples
sjvenditto Jan 3, 2025
81b079d
merge redundant tests, use dict comprehension
sjvenditto Jan 6, 2025
bcd7953
Merge branch 'dev' into metadata
sjvenditto Feb 26, 2025
8c51ce6
rename grouped_arg
sjvenditto Feb 26, 2025
507b690
test for tuning curves
sjvenditto Feb 26, 2025
ddcd8c8
more tests
sjvenditto Feb 26, 2025
cd8de28
isort
sjvenditto Feb 26, 2025
0aa343c
switch lambda to def
sjvenditto Feb 26, 2025
2ead26c
accidental typing
sjvenditto Feb 26, 2025
49b48c0
test for metadata index error
sjvenditto Feb 26, 2025
7009c92
groupby errors
sjvenditto Feb 26, 2025
60c713e
test for adding function kwargs in groupby apply
sjvenditto Feb 26, 2025
4ad24d7
Apply suggestions from code review
sjvenditto Feb 28, 2025
e960f53
fix groupby for TsdFrame to index columns, fix doc strings
sjvenditto Feb 28, 2025
f8c4f5c
update user guide, fix grouping for TsdFrames with column string names
sjvenditto Feb 28, 2025
acc1631
fix docstring formatting
sjvenditto Feb 28, 2025
517f71d
minor fixes in user guide
sjvenditto Feb 28, 2025
df5fb82
another user guide fix
sjvenditto Feb 28, 2025
b061f5c
some test updates
sjvenditto Feb 28, 2025
ed2b49a
code review changes
sjvenditto Mar 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 80 additions & 18 deletions doc/user_guide/03_metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,25 +36,26 @@ import pynapple as nap

# input parameters for TsGroup
group = {
1: nap.Ts(t=np.sort(np.random.uniform(0, 100, 10))),
2: nap.Ts(t=np.sort(np.random.uniform(0, 100, 20))),
3: nap.Ts(t=np.sort(np.random.uniform(0, 100, 30))),
1: nap.Ts(t=np.sort(np.random.uniform(0, 100, 100))),
2: nap.Ts(t=np.sort(np.random.uniform(0, 100, 200))),
3: nap.Ts(t=np.sort(np.random.uniform(0, 100, 300))),
4: nap.Ts(t=np.sort(np.random.uniform(0, 100, 400))),
}

# input parameters for IntervalSet
starts = [0,10,20]
ends = [5,15,25]
starts = [0,35,70]
ends = [30,65,100]

# input parameters for TsdFrame
t = np.arange(5)
d = np.ones((5,3))
d = np.tile([1,2,3], (5, 1))
columns = ["a", "b", "c"]
```

### `TsGroup`
Metadata added to `TsGroup` must match the number of `Ts`/`Tsd` objects, or the length of its `index` property.
```{code-cell} ipython3
metadata = {"region": ["pfc", "ofc", "hpc"]}
metadata = {"region": ["pfc", "pfc", "hpc", "hpc"]}

tsgroup = nap.TsGroup(group, metadata=metadata)
print(tsgroup)
Expand All @@ -64,7 +65,7 @@ When initializing with a DataFrame, the index must align with the input dictiona
```{code-cell} ipython3
metadata = pd.DataFrame(
index=group.keys(),
data=["pfc", "ofc", "hpc"],
data=["pfc", "pfc", "hpc", "hpc"],
columns=["region"]
)

Expand All @@ -88,7 +89,7 @@ print(intervalset)
Metadata can be initialized as a DataFrame using the metadata argument, or it can be inferred when initializing an `IntervalSet` with a DataFrame.
```{code-cell} ipython3
df = pd.DataFrame(
data=[[0, 5, 1, "left"], [10, 15, 0, "right"], [20, 25, 1, "left"]],
data=[[0, 30, 1, "left"], [35, 65, 0, "right"], [70, 100, 1, "left"]],
columns=["start", "end", "reward", "choice"]
)

Expand All @@ -101,7 +102,8 @@ Metadata added to `TsdFrame` must match the number of data columns, or the lengt
```{code-cell} ipython3
metadata = {
"color": ["red", "blue", "green"],
"position": [10,20,30]
"position": [10,20,30],
"label": ["x", "x", "y"]
}

tsdframe = nap.TsdFrame(d=d, t=t, columns=["a", "b", "c"], metadata=metadata)
Expand All @@ -112,8 +114,8 @@ When initializing with a DataFrame, the DataFrame index must match the `TsdFrame
```{code-cell} ipython3
metadata = pd.DataFrame(
index=["a", "b", "c"],
data=[["red", 10], ["blue", 20], ["green", 30]],
columns=["color", "position"],
data=[["red", 10, "x"], ["blue", 20, "x"], ["green", 30, "y"]],
columns=["color", "position", "label"],
)

tsdframe = nap.TsdFrame(d=d, t=t, columns=["a", "b", "c"], metadata=metadata)
Expand All @@ -130,21 +132,21 @@ The remaining metadata examples will be shown on a `TsGroup` object; however, al
### `set_info`
Metadata can be passed as a dictionary or pandas DataFrame as the first positional argument, or metadata can be passed as name-value keyword arguments.
```{code-cell} ipython3
tsgroup.set_info(unit_type=["multi", "single", "single"])
tsgroup.set_info(unit_type=["multi", "single", "single", "single"])
print(tsgroup)
```

### Using dictionary-like keys (square brackets)
Most metadata names can set as a dictionary-like key (i.e. using square brackets). The only exceptions are for `IntervalSet`, where the names "start" and "end" are reserved for class properties.
```{code-cell} ipython3
tsgroup["depth"] = [0, 1, 2]
tsgroup["depth"] = [0, 1, 2, 3]
print(tsgroup)
```

### Using attribute assignment
If the metadata name is unique from other class attributes and methods, and it is formatted properly (i.e. only alpha-numeric characters and underscores), it can be set as an attribute (i.e. using a `.` followed by the metadata name).
```{code-cell} ipython3
tsgroup.label=["MUA", "good", "good"]
tsgroup.label=["MUA", "good", "good", "good"]
print(tsgroup)
```

Expand Down Expand Up @@ -177,20 +179,80 @@ print(tsgroup.region)
User-set metadata is mutable and can be overwritten.
```{code-cell} ipython3
print(tsgroup, "\n")
tsgroup.set_info(region=["A", "B", "C"])
tsgroup.set_info(label=["A", "B", "C", "D"])
print(tsgroup)
```

## Allowed data types
As long as the length of the metadata container matches the length of the object (number of columns for `TsdFrame` and number of indices for `IntervalSet` and `TsGroup`), elements of the metadata can be any data type.
```{code-cell} ipython3
tsgroup.coords = [[1,0],[0,1],[1,1]]
tsgroup.coords = [[1,0],[0,1],[1,1],[2,1]]
print(tsgroup.coords)
```

## Using metadata to slice objects
Metadata can be used to slice or filter objects based on metadata values.
```{code-cell} ipython3
print(tsgroup[tsgroup.label == "good"])
print(tsgroup[tsgroup.label == "A"])
```

## `groupby`: Using metadata to group objects
Similar to pandas, metadata can be used to group objects based on one or more metadata columns using the object method `groupby`, where the first argument is the metadata columns name(s) to group by. This function returns a dictionary with keys corresponding to unique groups and values corresponding to object indices belonging to each group.
```{code-cell} ipython3
print(tsgroup,"\n")
print(tsgroup.groupby("region"))
```

Grouping by multiple metadata columns should be passed as a list.
```{code-cell} ipython3
tsgroup.groupby(["region","unit_type"])
```

The optional argument `get_group` can be provided to return a new object corresponding to a specific group.
```{code-cell} ipython3
tsgroup.groupby("region", get_group="hpc")
```

## `groupby_apply`: Applying functions to object groups
The `groupby_apply` object method allows a specific function to be applied to object groups. The first argument, same as `groupby`, is the metadata column(s) used to group the object. The second argument is the function to apply to each group. If only these two arguments are supplied, it is assumed that the grouped object is the first and only input to the applied function. This function returns a dictionary, where keys correspond to each unique group, and values correspond to the function output on each group.
```{code-cell} ipython3
print(tsdframe,"\n")
print(tsdframe.groupby_apply("label", np.mean))
```

If the applied function requires additional inputs, these can be passed as additional keyword arguments into `groupby_apply`.
```{code-cell} ipython3
feature = nap.Tsd(t=np.arange(100), d=np.repeat([0,1], 50))
tsgroup.groupby_apply(
"region",
nap.compute_1d_tuning_curves,
feature=feature,
nb_bins=2)
```

Alternatively, an anonymous function can be passed instead that defines additional arguments.
```{code-cell} ipython3
func = lambda x: nap.compute_1d_tuning_curves(x, feature=feature, nb_bins=2)
tsgroup.groupby_apply("region", func)
```

An anonymous function can also be used to apply a function where the grouped object is not the first input.
```{code-cell} ipython3
func = lambda x: nap.compute_1d_tuning_curves(
group=tsgroup,
feature=feature,
nb_bins=2,
ep=x)
intervalset.groupby_apply("choice", func)
```

Alternatively, the optional parameter `input_key` can be passed to specify which keyword argument the grouped object corresponds to. Other required arguments of the applied function need to be passed as keyword arguments.
```{code-cell} ipython3
intervalset.groupby_apply(
"choice",
nap.compute_1d_tuning_curves,
input_key="ep",
group=tsgroup,
feature=feature,
nb_bins=2)
```
106 changes: 105 additions & 1 deletion pynapple/core/interval_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def __getitem__(self, key):
output = self.values.__getitem__(key)
metadata = self._metadata.iloc[key].reset_index(drop=True)
return IntervalSet(start=output[:, 0], end=output[:, 1], metadata=metadata)
elif isinstance(key, pd.Series):
elif isinstance(key, (pd.Series, pd.Index)):
# use loc for metadata
output = self.values.__getitem__(key)
metadata = _MetadataMixin.__getitem__(self, key).reset_index(drop=True)
Expand Down Expand Up @@ -1198,3 +1198,107 @@ def get_info(self, key):
2 3 y
"""
return _MetadataMixin.get_info(self, key)

@add_meta_docstring("groupby")
def groupby(self, by, get_group=None):
"""
Examples
--------
>>> import pynapple as nap
>>> import numpy as np
>>> times = np.array([[0, 5], [10, 12], [20, 33]])
>>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
>>> ep = nap.IntervalSet(times,metadata=metadata)
>>> print(ep)
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
2 20 33 2 y
shape: (3, 2), time unit: sec.


Grouping by a single column:

>>> ep.groupby("l2")
{'x': [0, 1], 'y': [2]}

Grouping by multiple columns:

>>> ep.groupby(["l1","l2"])
{(1, 'x'): [0], (2, 'x'): [1], (2, 'y'): [2]}

Filtering to a specific group using the output dictionary:

>>> groups = ep.groupby("l2")
>>> ep[groups["x"]]
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
shape: (2, 2), time unit: sec.

Filtering to a specific group using the get_group argument:

>>> ep.groupby("l2", get_group="x")
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
shape: (2, 2), time unit: sec.
"""
return _MetadataMixin.groupby(self, by, get_group)

@add_meta_docstring("groupby_apply")
def groupby_apply(self, by, func, input_key=None, **func_kwargs):
"""
Examples
--------
>>> import pynapple as nap
>>> import numpy as np
>>> times = np.array([[0, 5], [10, 12], [20, 33]])
>>> metadata = {"l1": [1, 2, 2], "l2": ["x", "x", "y"]}
>>> ep = nap.IntervalSet(times,metadata=metadata)
>>> print(ep)
index start end l1 l2
0 0 5 1 x
1 10 12 2 x
2 20 33 2 y
shape: (3, 2), time unit: sec.

Apply a numpy function::

>>> ep.groupby_apply("l2", np.mean)
{'x': 6.75, 'y': 26.5}

Apply a custom function:

>>> ep.groupby_apply("l2", lambda x: x.shape[0])
{'x': 2, 'y': 1}

Apply a function with additional arguments:

>>> ep.groupby_apply("l2", np.mean, axis=1)
{'x': array([ 2.5, 11. ]), 'y': array([26.5])}

Applying a function with additional arguments, where the grouped object is not the first argument:

>>> tsg = nap.TsGroup(
... {
... 1: nap.Ts(t=np.arange(0, 40)),
... 2: nap.Ts(t=np.arange(0, 40, 0.5), time_units="s"),
... 3: nap.Ts(t=np.arange(0, 40, 0.2), time_units="s"),
... },
... )
>>> feature = nap.Tsd(t=np.arange(40), d=np.concatenate([np.zeros(20), np.ones(20)]))
>>> func_kwargs = {
>>> "group": tsg,
>>> "feature": feature,
>>> "nb_bins": 2,
>>> }
>>> ep.groupby_apply("l2", nap.compute_1d_tuning_curves, input_key="ep", **func_kwargs)
{'x': 1 2 3
0.25 1.025641 1.823362 4.216524
0.75 NaN NaN NaN,
'y': 1 2 3
0.25 NaN NaN NaN
0.75 1.025641 1.978022 4.835165}
"""
return _MetadataMixin.groupby_apply(self, by, func, input_key, **func_kwargs)
Loading
Loading