Improve metrics documentation #204

Merged
merged 4 commits on Apr 25, 2025
33 changes: 31 additions & 2 deletions docs/references.bib
@@ -79,8 +79,6 @@ @misc{klein:23
year = {2024}
}



@article{lin:2023,
title={Evolutionary-scale prediction of atomic-level protein structure with a language model},
author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yaniv and others},
@@ -91,3 +89,34 @@ @article{lin:2023
year={2023},
publisher={American Association for the Advancement of Science}
}

@InProceedings{feydy:19,
title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences},
author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouve, Alain and Peyr\'{e}, Gabriel},
booktitle = {Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics},
pages = {2681--2690},
year = {2019},
editor = {Chaudhuri, Kamalika and Sugiyama, Masashi},
volume = {89},
series = {Proceedings of Machine Learning Research},
month = {16--18 Apr},
publisher = {PMLR},
pdf = {http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf},
url = {https://proceedings.mlr.press/v89/feydy19a.html},
abstract = {Comparing probability distributions is a fundamental problem in data sciences. Simple norms and divergences such as the total variation and the relative entropy only compare densities in a point-wise manner and fail to capture the geometric nature of the problem. In sharp contrast, Maximum Mean Discrepancies (MMD) and Optimal Transport distances (OT) are two classes of distances between measures that take into account the geometry of the underlying space and metrize the convergence in law. This paper studies the Sinkhorn divergences, a family of geometric divergences that interpolates between MMD and OT. Relying on a new notion of geometric entropy, we provide theoretical guarantees for these divergences: positivity, convexity and metrization of the convergence in law. On the practical side, we detail a numerical scheme that enables the large scale application of these divergences for machine learning: on the GPU, gradients of the Sinkhorn loss can be computed for batches of a million samples.}
}

@article{Peidli2024,
title = {scPerturb: harmonized single-cell perturbation data},
volume = {21},
ISSN = {1548-7105},
url = {http://dx.doi.org/10.1038/s41592-023-02144-y},
DOI = {10.1038/s41592-023-02144-y},
number = {3},
journal = {Nature Methods},
publisher = {Springer Science and Business Media LLC},
author = {Peidli, Stefan and Green, Tessa D. and Shen, Ciyue and Gross, Torsten and Min, Joseph and Garda, Samuele and Yuan, Bo and Schumacher, Linus J. and Taylor-King, Jake P. and Marks, Debora S. and Luna, Augustin and Bl\"{u}thgen, Nils and Sander, Chris},
year = {2024},
month = jan,
pages = {531--540}
}
1 change: 1 addition & 0 deletions docs/user/index.rst
@@ -10,6 +10,7 @@ User API
solvers
networks
datasets
metrics
utils
training
plotting
Expand Down
2 changes: 1 addition & 1 deletion docs/user/metrics.rst
@@ -6,9 +6,9 @@ Metrics
:toctree: genapi

metrics.compute_metrics
metrics.compute_metrics_fast
metrics.compute_mean_metrics
metrics.compute_scalar_mmd
metrics.compute_r_squared
metrics.compute_sinkhorn_div
metrics.compute_e_distance
metrics.maximum_mean_discrepancy
168 changes: 143 additions & 25 deletions src/cellflow/metrics/_metrics.py
@@ -1,10 +1,12 @@
from collections.abc import Sequence

import jax
import numpy as np
from jax import numpy as jnp
from jax.typing import ArrayLike
from ott.geometry import costs, pointcloud
from ott.tools.sinkhorn_divergence import sinkhorn_divergence
from sklearn.metrics import pairwise_distances, r2_score
from sklearn.metrics import r2_score
from sklearn.metrics.pairwise import rbf_kernel

__all__ = [
@@ -21,12 +23,38 @@


def compute_r_squared(x: ArrayLike, y: ArrayLike) -> float:
"""Compute the R squared between true (x) and predicted (y)"""
"""Compute the R squared score between means of the true (x) and predicted (y) distributions.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].

Returns
-------
A scalar denoting the R squared score.
"""
return r2_score(np.mean(x, axis=0), np.mean(y, axis=0))
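
A minimal usage sketch for reviewers (not part of this diff; it assumes ``cellflow.metrics`` is the public import path). Since only per-feature means are compared, the two arrays may differ in ``num_samples``:

import numpy as np
from cellflow import metrics  # assumed public import path

rng = np.random.default_rng(0)
x_true = rng.normal(size=(100, 50))                                # true samples
y_pred = rng.normal(loc=x_true.mean(0), scale=0.1, size=(80, 50))  # predictions near the true means
score = metrics.compute_r_squared(x_true, y_pred)                  # r2_score over the two mean vectors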


def compute_sinkhorn_div(x: ArrayLike, y: ArrayLike, epsilon: float = 1e-2) -> float:
"""Compute the Sinkhorn divergence between x and y."""
"""Compute the Sinkhorn divergence between x and y as in :cite:`feydy:19`.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].
epsilon
The regularization parameter.

Returns
-------
A scalar denoting the Sinkhorn divergence value.
"""
return float(
sinkhorn_divergence(
pointcloud.PointCloud,
@@ -35,15 +63,27 @@ def compute_sinkhorn_div(x: ArrayLike, y: ArrayLike, epsilon: float = 1e-2) -> float:
cost_fn=costs.SqEuclidean(),
epsilon=epsilon,
scale_cost=1.0,
).divergence[0]
)[0]
)
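
A usage sketch (again assuming the ``cellflow.metrics`` import path): per :cite:`feydy:19`, Sinkhorn divergences interpolate between optimal transport and MMD, so a small ``epsilon`` gives OT-like values and a large ``epsilon`` MMD-like values:

import numpy as np
from cellflow import metrics  # assumed public import path

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 10))
y = rng.normal(loc=0.5, size=(256, 10))
d_ot_like = metrics.compute_sinkhorn_div(x, y, epsilon=1e-2)   # near the OT end of the interpolation
d_mmd_like = metrics.compute_sinkhorn_div(x, y, epsilon=10.0)  # near the MMD end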


def compute_e_distance(x: ArrayLike, y: ArrayLike) -> float:
"""Compute the energy distance as in Peidli et al."""
sigma_X = pairwise_distances(x, x, metric="sqeuclidean").mean()
sigma_Y = pairwise_distances(y, y, metric="sqeuclidean").mean()
delta = pairwise_distances(x, y, metric="sqeuclidean").mean()
"""Compute the energy distance between x and y as in :cite:`Peidli2024`.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].

Returns
-------
A scalar denoting the energy distance value.
"""
sigma_X = pairwise_squeuclidean(x, x).mean()
sigma_Y = pairwise_squeuclidean(y, y).mean()
delta = pairwise_squeuclidean(x, y).mean()
return 2 * delta - sigma_X - sigma_Y
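
As a cross-check of the formula above, an illustrative NumPy re-derivation (not part of this diff; it materializes the full pairwise distance matrices, so it is only practical for small sample sizes):

import numpy as np

def e_distance_numpy(x: np.ndarray, y: np.ndarray) -> float:
    # Squared Euclidean pairwise distances via explicit broadcasting.
    def sq_dists(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        return ((a[:, None, :] - b[None, :, :]) ** 2).sum(-1)

    sigma_x = sq_dists(x, x).mean()
    sigma_y = sq_dists(y, y).mean()
    delta = sq_dists(x, y).mean()
    return 2 * delta - sigma_x - sigma_y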


@@ -54,15 +94,43 @@ def pairwise_squeuclidean(x: ArrayLike, y: ArrayLike) -> ArrayLike:

@jax.jit
def compute_e_distance_fast(x: ArrayLike, y: ArrayLike) -> float:
"""Compute the energy distance as in Peidli et al."""
sigma_X = pairwise_squeuclidean(x, x).mean()
sigma_Y = pairwise_squeuclidean(y, y).mean()
delta = pairwise_squeuclidean(x, y).mean()
return 2 * delta - sigma_X - sigma_Y
"""Compute the energy distance between x and y as in :cite:`Peidli2024`.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].

Returns
-------
A scalar denoting the energy distance value.
"""
return compute_e_distance(x, y)
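
Since the wrapper is ``@jax.jit``-compiled, the first call traces and compiles for a given input shape and later calls reuse the compiled function. A small sketch (assuming the ``cellflow.metrics`` import path):

import jax
from cellflow import metrics  # assumed public import path

x = jax.random.normal(jax.random.PRNGKey(0), (128, 32))
y = jax.random.normal(jax.random.PRNGKey(1), (128, 32))
d = metrics.compute_e_distance_fast(x, y)  # first call compiles for this shape
d = metrics.compute_e_distance_fast(x, y)  # later calls reuse the compiled trace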


def compute_metrics(x: ArrayLike, y: ArrayLike) -> dict[str, float]:
"""Compute different metrics for x (true) and y (predicted)."""
"""Compute a set of metrics between two distributions x and y.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].

Returns
-------
A dictionary containing the following computed metrics:

- the R squared score.
- the Sinkhorn divergence with ``epsilon = 1.0``.
- the Sinkhorn divergence with ``epsilon = 10.0``.
- the Sinkhorn divergence with ``epsilon = 100.0``.
- the energy distance value.
- the maximum mean discrepancy loss.
"""
metrics = {}
metrics["r_squared"] = compute_r_squared(x, y)
metrics["sinkhorn_div_1"] = compute_sinkhorn_div(x, y, epsilon=1.0)
@@ -74,7 +142,21 @@ def compute_metrics(x: ArrayLike, y: ArrayLike) -> dict[str, float]:


def compute_mean_metrics(metrics: dict[str, dict[str, float]], prefix: str = "") -> dict[str, list[float]]:
"""Compute the mean value of different metrics."""
"""Compute the mean value of different metrics.

Parameters
----------
metrics
A dictionary where the keys indicate the name of the perturbations and the values are
dictionaries containing computed metrics.
prefix
A string defining the prefix prepended to all metric names in the output dictionary.

Returns
-------
A dictionary where the keys indicate the metrics and the values contain the average metric
values over all perturbations.
"""
metric_names = list(list(metrics.values())[0].keys())
metric_dict: dict[str, list[float]] = {prefix + met_name: [] for met_name in metric_names}
for met in metric_names:
@@ -95,16 +177,22 @@ def rbf_kernel_fast(x: ArrayLike, y: ArrayLike, gamma: float) -> ArrayLike:


def maximum_mean_discrepancy(x: ArrayLike, y: ArrayLike, gamma: float = 1.0, exact: bool = False) -> float:
"""Compute the Maximum Mean Discrepancy (MMD) between two samples: x and y.

Args:
x: a tensor of shape [num_samples, num_features]
y: a tensor of shape [num_samples, num_features]
exact: a bool
"""Compute the Maximum Mean Discrepancy (MMD) between two distributions x and y.

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].
gamma
Parameter for the rbf kernel.
exact
Whether to use the exact scikit-learn rbf kernel or the module's fast ``rbf_kernel_fast``.

Returns
-------
a scalar denoting the squared maximum mean discrepancy loss.
A scalar denoting the squared maximum mean discrepancy loss.
"""
kernel = rbf_kernel if exact else rbf_kernel_fast
xx = kernel(x, x, gamma)
@@ -113,16 +201,46 @@ def maximum_mean_discrepancy(x: ArrayLike, y: ArrayLike, gamma: float = 1.0, exact: bool = False) -> float:
return xx.mean() + yy.mean() - 2 * xy.mean()
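
A usage sketch contrasting the two kernel paths (assuming the ``cellflow.metrics`` import path); both compute the biased V-statistic estimator above and should agree up to numerical precision:

import numpy as np
from cellflow import metrics  # assumed public import path

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 20))
y = rng.normal(loc=1.0, size=(200, 20))
mmd_exact = metrics.maximum_mean_discrepancy(x, y, gamma=1.0, exact=True)  # sklearn rbf_kernel
mmd_fast = metrics.maximum_mean_discrepancy(x, y, gamma=1.0, exact=False)  # module's rbf_kernel_fast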


def compute_scalar_mmd(x: ArrayLike, y: ArrayLike, gammas: float | None = None) -> float:
"""Compute MMD across different length scales"""
def compute_scalar_mmd(x: ArrayLike, y: ArrayLike, gammas: Sequence[float] | None = None) -> float:
"""Compute the Mean Maximum Discrepancy (MMD) across different length scales

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].
gammas
A sequence of values for the parameter gamma of the rbf kernel.

Returns
-------
A scalar denoting the average MMD over all gammas.
"""
if gammas is None:
gammas = [2, 1, 0.5, 0.1, 0.01, 0.005]
mmds = [maximum_mean_discrepancy(x, y, gamma=gamma) for gamma in gammas] # type: ignore[union-attr]
return np.nanmean(np.array(mmds))
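
A final usage sketch (assuming the ``cellflow.metrics`` import path); ``np.nanmean`` makes the average robust to individual length scales yielding NaN:

import numpy as np
from cellflow import metrics  # assumed public import path

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 10))
y = rng.normal(loc=0.5, size=(100, 10))
mmd_default = metrics.compute_scalar_mmd(x, y)                    # averages over [2, 1, 0.5, 0.1, 0.01, 0.005]
mmd_custom = metrics.compute_scalar_mmd(x, y, gammas=[0.5, 1.0])  # user-chosen length scales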


def compute_metrics_fast(x: ArrayLike, y: ArrayLike) -> dict[str, float]:
Collaborator: can we expose it, but not add to the docs? to keep it for backward compatibility, but it's a bit redundant.

"""Compute metrics which are fast to compute."""
"""Compute metrics which are fast to compute

Parameters
----------
x
An array of shape [num_samples, num_features].
y
An array of shape [num_samples, num_features].

Returns
-------
A dictionary containing the following computed metrics:

- the R squared score.
- the energy distance value.
- the maximum mean discrepancy loss.
"""
metrics = {}
metrics["r_squared"] = compute_r_squared(x, y)
metrics["e_distance"] = compute_e_distance(x, y)