Merged

41 commits
0623b9b
changed _scale_func_group to also work with 3D objects, updated doc c…
agerardy Sep 25, 2025
9ef4818
added 3d object fixture to conftest
agerardy Oct 7, 2025
4d5873c
attempted 3D version of scale_norm
agerardy Oct 7, 2025
0247e5a
test for scale_norm
agerardy Oct 7, 2025
f579bc6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 7, 2025
7fd87aa
fixed scale_norm and added 3D functionality and tests for minmax_norm…
agerardy Oct 9, 2025
a6117ad
Merge branch '944-longitudinal-normalization' of github.com:theislab/…
agerardy Oct 9, 2025
07d187c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 9, 2025
2260102
added 3D functionality and tests for all normalization functions. rem…
agerardy Oct 14, 2025
fefda75
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 14, 2025
799e84f
updated normalization to correctly work with selected variables. adde…
agerardy Oct 20, 2025
654d355
Merge branch 'main' into 944-longitudinal-normalization
agerardy Oct 20, 2025
e7887d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2025
155a72f
fixed small 3d fixture not returning anything
agerardy Oct 20, 2025
0efca8a
minor comment edits
agerardy Oct 20, 2025
cbb07d6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2025
cd26ac2
Changed logic to work with layers and R as just a layer
agerardy Oct 20, 2025
1f3c554
removed unnecessary comments and a nonfunctional test
agerardy Oct 20, 2025
f7e6bfa
3D normalization tests now work with edata_blobs_timeseries_small and…
agerardy Oct 20, 2025
da9cc22
removed unmecessary copy and fixed docstrings
agerardy Oct 20, 2025
82bd238
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 20, 2025
fa7becc
added type parameters to new functions and removed old comments
agerardy Oct 21, 2025
e8808a7
Merge branch 'main' into 944-longitudinal-normalization
Zethson Oct 22, 2025
81e4873
refined docstrings in test_normalization to be more informative
agerardy Oct 29, 2025
e38bc83
update examples for test functions
agerardy Oct 30, 2025
0566a17
Merge branch 'main' into 944-longitudinal-normalization
agerardy Nov 12, 2025
32d3869
Merge branch '944-longitudinal-normalization' of github.com:theislab/…
agerardy Nov 12, 2025
312ec7a
Merge remote-tracking branch 'origin/main' into 944-longitudinal-norm…
agerardy Nov 26, 2025
bed0dd4
updated normalization functions and tests to work with layers instead…
agerardy Nov 26, 2025
5d3381f
combined most tests into a basic test_norm_3D and test_norm_3D_precis…
agerardy Nov 28, 2025
f1b2e94
updated examples to use layers. numbers are not real examples anymore…
agerardy Nov 28, 2025
89a14ca
fixed unnecessary copy(), uses array_not_implemented error, import DE…
agerardy Nov 28, 2025
4416c3b
redid examples with real data and different layer name
agerardy Dec 3, 2025
a0a494a
split up precise tests, removed unnecessary shape checks, added comme…
agerardy Dec 3, 2025
882bfc2
group wise normalization now works with 3D and with dask arrays. fixe…
agerardy Dec 3, 2025
faa8ef5
simplified test_norm_3D to only test basic functions. removed NotImpl…
agerardy Dec 3, 2025
b7c06ce
removed group wise normalization for dask arrays, raise notImplemente…
agerardy Dec 9, 2025
028456a
forgot to tupdate some tests with NotImplementedError
agerardy Dec 9, 2025
f7af4ec
removed redundant assignments and use _offset_negative_values instead…
agerardy Dec 9, 2025
b8de4a9
added registered type to .register and cast to float32 instead of 64 …
agerardy Dec 10, 2025
1d54535
made test_norm_power_integers more lenient to allow casts to float32 …
agerardy Dec 10, 2025
158 changes: 123 additions & 35 deletions ehrapy/preprocessing/_normalization.py
@@ -30,7 +30,12 @@ def _scale_func_group(
copy: bool,
norm_name: str,
) -> EHRData | AnnData | None:
"""Apply scaling function to selected columns of edata, either globally or per group."""
"""Apply scaling function to selected columns of edata, either globally or per group.

Supports both 2D and 3D data:
- 2D data (n_obs × n_var): Standard normalization across observations
- 3D data (n_obs × n_var × n_timestamps): Per-variable normalization across samples and timestamps
"""
if group_key is not None and group_key not in edata.obs_keys():
raise KeyError(f"group key '{group_key}' not found in edata.obs.")
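
A minimal NumPy sketch of the global-versus-per-group dispatch implemented below (the names `X`, `groups`, and `scale` are illustrative stand-ins, not part of the ehrapy API):

import numpy as np
import pandas as pd

def scale(a: np.ndarray) -> np.ndarray:
    # column-wise z-score; stand-in for any scale_func
    return (a - a.mean(axis=0)) / a.std(axis=0)

X = np.array([[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]])
groups = pd.Series(["a", "a", "b", "b"])

X_global = scale(X)  # group_key is None: one fit over all observations

X_grouped = X.copy()  # group_key given: fit within each group independently
for g in groups.unique():
    idx = (groups == g).to_numpy()
    X_grouped[idx] = scale(X_grouped[idx])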

@@ -43,25 +48,58 @@

edata = _prep_edata_norm(edata, copy)

if layer is None:
var_values = edata[:, vars].X.copy()
else:
var_values = edata[:, vars].layers[layer].copy()

if group_key is None:
var_values = scale_func(var_values)

if hasattr(edata, "R") and edata.R is not None and edata.R.ndim == 3:
if layer is None:
var_values = edata.R[:, :, :].copy()
else:
var_values = edata.layers[layer][:, :, :].copy()

n_obs, n_var, n_timestamps = var_values.shape
if group_key is None:
for var_idx in range(n_var):
var_data = var_values[:, var_idx, :].reshape(-1, 1)
var_data = scale_func(var_data)
var_values[:, var_idx, :] = var_data.reshape(n_obs, n_timestamps)
else:
for group in edata.obs[group_key].unique():
group_idx = edata.obs[group_key] == group
group_data = var_values[group_idx]
n_obs_group = group_data.shape[0]
for var_idx in range(n_var):
var_data = group_data[:, var_idx, :].reshape(-1, 1)
var_data = scale_func(var_data)
var_values[group_idx, var_idx, :] = var_data.reshape(n_obs_group, n_timestamps)

# Write back to edata.R or edata.layers[layer]
if layer is None:
edata.R = edata.R.astype(var_values.dtype)
edata.R[:, :, :] = var_values
else:
edata.layers[layer] = edata.layers[layer].astype(var_values.dtype)
edata.layers[layer][:, :, :] = var_values
else:
for group in edata.obs[group_key].unique():
group_idx = edata.obs[group_key] == group
var_values[group_idx] = scale_func(var_values[group_idx])
# 2D normalization (AnnData or 2D EHRData)
if layer is None:
var_values = edata[:, vars].X.copy()
else:
var_values = edata[:, vars].layers[layer].copy()

if var_values.ndim == 2:
if group_key is None:
var_values = scale_func(var_values)
else:
for group in edata.obs[group_key].unique():
group_idx = edata.obs[group_key] == group
var_values[group_idx] = scale_func(var_values[group_idx])
else:
raise ValueError(f"Unsupported data dimensionality: {var_values.ndim}D. Expected 2D or 3D data.")

if layer is None:
edata.X = edata.X.astype(var_values.dtype)
edata[:, vars].X = var_values
else:
edata.layers[layer] = edata.layers[layer].astype(var_values.dtype)
edata[:, vars].layers[layer] = var_values
if layer is None:
edata.X = edata.X.astype(var_values.dtype)
edata[:, vars].X = var_values
else:
edata.layers[layer] = edata.layers[layer].astype(var_values.dtype)
edata[:, vars].layers[layer] = var_values

_record_norm(edata, vars, norm_name)
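
To make the 3D branch above concrete: each variable's (n_obs × n_timestamps) slice is flattened to a single column, scaled, and reshaped back, so one set of statistics is computed per variable across all samples and timestamps. A standalone sketch with scikit-learn on a synthetic array (not ehrapy data):

import numpy as np
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
R = rng.normal(loc=5.0, scale=2.0, size=(5, 3, 4))  # n_obs x n_var x n_timestamps

for var_idx in range(R.shape[1]):
    flat = R[:, var_idx, :].reshape(-1, 1)       # (n_obs * n_timestamps, 1)
    flat = StandardScaler().fit_transform(flat)  # one mean/std per variable
    R[:, var_idx, :] = flat.reshape(R.shape[0], R.shape[2])

# each variable slice is now zero-mean, unit-variance across obs and time
assert np.allclose(R[:, 0, :].mean(), 0.0, atol=1e-9)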

@@ -89,7 +127,6 @@ def _(arr: DaskArray, **kwargs):


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def scale_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
@@ -103,6 +140,10 @@ def scale_norm(
Functionality is provided by :class:`~sklearn.preprocessing.StandardScaler`, see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.StandardScaler.html for details.
If `edata.X` is a Dask Array, functionality is provided by :class:`~dask_ml.preprocessing.StandardScaler`, see https://ml.dask.org/modules/generated/dask_ml.preprocessing.StandardScaler.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object. Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
vars: List of the names of the numeric variables to normalize.
@@ -113,15 +154,21 @@
**kwargs: Additional arguments passed to the StandardScaler.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated AnnData object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.scale_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.scale_norm(edata_3d, copy=True)
"""
scale_func = _scale_norm_function(edata.X if layer is None else edata.layers[layer], **kwargs)
if hasattr(edata, "R") and edata.R is not None and edata.R.ndim == 3:
arr = edata.R if layer is None else edata.layers[layer]
else:
arr = edata.X if layer is None else edata.layers[layer]
scale_func = _scale_norm_function(arr, **kwargs)

return _scale_func_group(
edata=edata,
@@ -152,7 +199,6 @@ def _(arr: DaskArray, **kwargs):


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def minmax_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
@@ -166,6 +212,10 @@ def minmax_norm(
Functionality is provided by :class:`~sklearn.preprocessing.MinMaxScaler`, see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MinMaxScaler.html for details.
If `edata.X` is a Dask Array, functionality is provided by :class:`~dask_ml.preprocessing.MinMaxScaler`, see https://ml.dask.org/modules/generated/dask_ml.preprocessing.MinMaxScaler.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object.
Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
@@ -177,13 +227,15 @@
**kwargs: Additional arguments passed to the MinMaxScaler.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated AnnData object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.minmax_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.minmax_norm(edata_3d, copy=True)
"""
scale_func = _minmax_norm_function(edata.X if layer is None else edata.layers[layer], **kwargs)
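
For reference, the keyword arguments are forwarded to the scaler, so a custom feature range behaves exactly as in plain scikit-learn (a minimal standalone sketch with an illustrative array):

import numpy as np
from sklearn.preprocessing import MinMaxScaler

X = np.array([[1.0, 100.0], [2.0, 200.0], [3.0, 300.0]])
X_scaled = MinMaxScaler(feature_range=(-1.0, 1.0)).fit_transform(X)
# each column is mapped linearly onto [-1, 1]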

Expand All @@ -208,7 +260,6 @@ def _(arr: np.ndarray):
return sklearn_pp.MaxAbsScaler().fit_transform


@function_2D_only()
@use_ehrdata(deprecated_after="1.0.0")
def maxabs_norm(
edata: EHRData | AnnData,
@@ -221,6 +272,10 @@

Functionality is provided by :class:`~sklearn.preprocessing.MaxAbsScaler`, see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.MaxAbsScaler.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object.
Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
@@ -231,13 +286,15 @@
copy: Whether to return a copy or act in place.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated AnnData object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.maxabs_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.maxabs_norm(edata_3d, copy=True)
"""
scale_func = _maxabs_norm_function(edata.X if layer is None else edata.layers[layer])
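
MaxAbsScaler divides each column by its maximum absolute value, so signs are preserved and results land in [-1, 1]; a quick standalone sketch:

import numpy as np
from sklearn.preprocessing import MaxAbsScaler

X = np.array([[-2.0, 1.0], [1.0, -4.0], [4.0, 2.0]])
X_scaled = MaxAbsScaler().fit_transform(X)
# both columns are divided by 4 -> [[-0.5, 0.25], [0.25, -1.0], [1.0, 0.5]]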

@@ -270,7 +327,6 @@ def _(arr: DaskArray, **kwargs):


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def robust_scale_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
@@ -285,6 +341,10 @@
see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.RobustScaler.html for details.
If `edata.X` is a Dask Array, functionality is provided by :class:`~dask_ml.preprocessing.RobustScaler`, see https://ml.dask.org/modules/generated/dask_ml.preprocessing.RobustScaler.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object.
Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
@@ -296,13 +356,15 @@
**kwargs: Additional arguments passed to the RobustScaler.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated AnnData object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.robust_scale_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.robust_scale_norm(edata_3d, copy=True)
"""
scale_func = _robust_scale_norm_function(edata.X if layer is None else edata.layers[layer], **kwargs)
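
RobustScaler centers each column on its median and scales by the interquartile range, so a single outlier does not compress the rest of the values; a standalone sketch:

import numpy as np
from sklearn.preprocessing import RobustScaler

X = np.array([[1.0], [2.0], [3.0], [4.0], [100.0]])  # one extreme outlier
X_scaled = RobustScaler().fit_transform(X)
# centered on the median (3.0), scaled by the IQR (2.0); 100.0 remains an
# outlier without distorting the scale of the other values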

@@ -335,7 +397,6 @@ def _(arr: DaskArray, **kwargs):


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def quantile_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
@@ -350,6 +411,10 @@
see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.QuantileTransformer.html for details.
If `edata.X` is a Dask Array, functionality is provided by :class:`~dask_ml.preprocessing.QuantileTransformer`, see https://ml.dask.org/modules/generated/dask_ml.preprocessing.QuantileTransformer.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object. Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
vars: List of the names of the numeric variables to normalize.
@@ -360,13 +425,15 @@
**kwargs: Additional arguments passed to the QuantileTransformer.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated data object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.quantile_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.quantile_norm(edata_3d, copy=True)
"""
scale_func = _quantile_norm_function(edata.X if layer is None else edata.layers[layer], **kwargs)
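
QuantileTransformer maps each column onto a target distribution via its empirical quantiles; note that `n_quantiles` must not exceed the number of samples. A standalone sketch with synthetic skewed data:

import numpy as np
from sklearn.preprocessing import QuantileTransformer

rng = np.random.default_rng(0)
X = rng.exponential(size=(100, 2))  # heavily right-skewed input
qt = QuantileTransformer(n_quantiles=100, output_distribution="uniform")
X_uniform = qt.fit_transform(X)  # each column ~ uniform on [0, 1]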

Expand All @@ -392,7 +459,6 @@ def _(arr: np.ndarray, **kwargs):


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def power_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
@@ -406,6 +472,10 @@
Functionality is provided by :class:`~sklearn.preprocessing.PowerTransformer`,
see https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.PowerTransformer.html for details.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Per-variable normalization across samples and timestamps

Args:
edata: Central data object.
Must already be encoded using :func:`~ehrapy.preprocessing.encode`.
@@ -417,13 +487,15 @@
**kwargs: Additional arguments passed to the PowerTransformer.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated data object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.power_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.power_norm(edata_3d, copy=True)
"""
scale_func = _power_norm_function(edata.X if layer is None else edata.layers[layer], **kwargs)
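
PowerTransformer applies a Yeo-Johnson (default) or Box-Cox transform to make each column more Gaussian; Box-Cox requires strictly positive input. A standalone sketch with synthetic data:

import numpy as np
from sklearn.preprocessing import PowerTransformer

rng = np.random.default_rng(0)
X = rng.lognormal(size=(100, 1))  # right-skewed, strictly positive
X_gauss = PowerTransformer(method="box-cox").fit_transform(X)
# the transformed column is roughly Gaussian with zero mean and unit variance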

Expand All @@ -439,7 +511,6 @@ def power_norm(


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def log_norm(
edata: EHRData | AnnData,
vars: str | Sequence[str] | None = None,
Expand All @@ -453,6 +524,10 @@ def log_norm(
Computes :math:`x = \\log(x + offset)`, where :math:`log` denotes the natural logarithm
unless a different base is given and the default :math:`offset` is :math:`1`.

Supports both 2D and 3D data:
- 2D data: Standard normalization across observations
- 3D data: Applied to all elements across samples and timestamps

Args:
edata: Central data object.
vars: List of the names of the numeric variables to normalize.
@@ -463,13 +538,15 @@
copy: Whether to return a copy or act in place.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated data object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object. Also stores a record of applied normalizations as a dictionary in edata.uns["normalization"].

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_norm = ep.pp.log_norm(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_norm = ep.pp.log_norm(edata_3d, copy=True)
"""
if isinstance(vars, str):
vars = [vars]
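
The formula above is an element-wise shifted logarithm; in plain NumPy terms (illustrative values):

import numpy as np

x = np.array([0.0, 1.0, 9.0])
np.log(x + 1)    # natural log with the default offset of 1
np.log10(x + 1)  # the same with base 10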
@@ -537,20 +614,31 @@ def _record_norm(edata: EHRData | AnnData, vars: Sequence[str], method: str) ->


@use_ehrdata(deprecated_after="1.0.0")
@function_2D_only()
def offset_negative_values(edata: EHRData | AnnData, layer: str = None, copy: bool = False) -> EHRData | AnnData | None:
"""Offsets negative values into positive ones with the lowest negative value becoming 0.

This is primarily used to enable the usage of functions such as log_norm that
do not allow negative values for mathematical or technical reasons.

Supports both 2D and 3D data:
- 2D data: Standard offset across observations
- 3D data: Applied to all elements across samples and timestamps

Args:
edata: Central data object.
layer: The layer to offset.
copy: Whether to return a modified copy of the data object.

Returns:
`None` if `copy=False` and modifies the passed edata, else returns an updated data object.
`None` if `copy=False` and modifies the passed edata, else returns an updated edata object.

Examples:
>>> import ehrdata as ed
>>> import ehrapy as ep
>>> edata = ed.dt.mimic_2()
>>> edata_offset = ep.pp.offset_negative_values(edata, copy=True)
>>> # Works automatically with both 2D and 3D data
>>> edata_3d_offset = ep.pp.offset_negative_values(edata_3d, copy=True)
"""
if copy:
edata = edata.copy()
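
The offset described above is simply a shift by the most negative value, applied only when negatives are present (a minimal standalone sketch):

import numpy as np

x = np.array([-3.0, 0.0, 2.0])
if x.min() < 0:
    x = x - x.min()  # lowest value becomes 0; order and spacing preserved
# x is now [0.0, 3.0, 5.0]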