diff --git a/docs/references.bib b/docs/references.bib index ee93da45..ffecc034 100644 --- a/docs/references.bib +++ b/docs/references.bib @@ -79,8 +79,6 @@ @misc{klein:23 year = {2024} } - - @article{lin:2023, title={Evolutionary-scale prediction of atomic-level protein structure with a language model}, author={Lin, Zeming and Akin, Halil and Rao, Roshan and Hie, Brian and Zhu, Zhongkai and Lu, Wenting and Smetanin, Nikita and Verkuil, Robert and Kabeli, Ori and Shmueli, Yaniv and others}, @@ -91,3 +89,34 @@ @article{lin:2023 year={2023}, publisher={American Association for the Advancement of Science} } + +@InProceedings{feydy:19, + title = {Interpolating between Optimal Transport and MMD using Sinkhorn Divergences}, + author = {Feydy, Jean and S\'{e}journ\'{e}, Thibault and Vialard, Fran\c{c}ois-Xavier and Amari, Shun-ichi and Trouve, Alain and Peyr\'{e}, Gabriel}, + booktitle = {Proceedings of the Twenty-Second International Conference on Artificial Intelligence and Statistics}, + pages = {2681--2690}, + year = {2019}, + editor = {Chaudhuri, Kamalika and Sugiyama, Masashi}, + volume = {89}, + series = {Proceedings of Machine Learning Research}, + month = {16--18 Apr}, + publisher = {PMLR}, + pdf = {http://proceedings.mlr.press/v89/feydy19a/feydy19a.pdf}, + url = {https://proceedings.mlr.press/v89/feydy19a.html}, + abstract = {Comparing probability distributions is a fundamental problem in data sciences. Simple norms and divergences such as the total variation and the relative entropy only compare densities in a point-wise manner and fail to capture the geometric nature of the problem. In sharp contrast, Maximum Mean Discrepancies (MMD) and Optimal Transport distances (OT) are two classes of distances between measures that take into account the geometry of the underlying space and metrize the convergence in law. This paper studies the Sinkhorn divergences, a family of geometric divergences that interpolates between MMD and OT. Relying on a new notion of geometric entropy, we provide theoretical guarantees for these divergences: positivity, convexity and metrization of the convergence in law. On the practical side, we detail a numerical scheme that enables the large scale application of these divergences for machine learning: on the GPU, gradients of the Sinkhorn loss can be computed for batches of a million samples.} +} + +@article{Peidli2024, + title = {scPerturb: harmonized single-cell perturbation data}, + volume = {21}, + ISSN = {1548-7105}, + url = {http://dx.doi.org/10.1038/s41592-023-02144-y}, + DOI = {10.1038/s41592-023-02144-y}, + number = {3}, + journal = {Nature Methods}, + publisher = {Springer Science and Business Media LLC}, + author = {Peidli, Stefan and Green, Tessa D. and Shen, Ciyue and Gross, Torsten and Min, Joseph and Garda, Samuele and Yuan, Bo and Schumacher, Linus J. and Taylor-King, Jake P. and Marks, Debora S. and Luna, Augustin and Bl\"{u}thgen, Nils and Sander, Chris}, + year = {2024}, + month = jan, + pages = {531–540} +} \ No newline at end of file diff --git a/docs/user/index.rst b/docs/user/index.rst index e0d4d0a7..af12c42e 100644 --- a/docs/user/index.rst +++ b/docs/user/index.rst @@ -10,6 +10,7 @@ User API solvers networks datasets + metrics utils training plotting diff --git a/docs/user/metrics.rst b/docs/user/metrics.rst index 1a5cff90..3f1b558e 100644 --- a/docs/user/metrics.rst +++ b/docs/user/metrics.rst @@ -6,9 +6,9 @@ Metrics :toctree: genapi metrics.compute_metrics - metrics.compute_metrics_fast metrics.compute_mean_metrics metrics.compute_scalar_mmd metrics.compute_r_squared metrics.compute_sinkhorn_div metrics.compute_e_distance + metrics.maximum_mean_discrepancy diff --git a/src/cellflow/metrics/_metrics.py b/src/cellflow/metrics/_metrics.py index 7e35f6ff..180e501d 100644 --- a/src/cellflow/metrics/_metrics.py +++ b/src/cellflow/metrics/_metrics.py @@ -1,10 +1,12 @@ +from collections.abc import Sequence + import jax import numpy as np from jax import numpy as jnp from jax.typing import ArrayLike from ott.geometry import costs, pointcloud from ott.tools.sinkhorn_divergence import sinkhorn_divergence -from sklearn.metrics import pairwise_distances, r2_score +from sklearn.metrics import r2_score from sklearn.metrics.pairwise import rbf_kernel __all__ = [ @@ -21,12 +23,38 @@ def compute_r_squared(x: ArrayLike, y: ArrayLike) -> float: - """Compute the R squared between true (x) and predicted (y)""" + """Compute the R squared score between means of the true (x) and predicted (y) distributions. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + + Returns + ------- + A scalar denoting the R squared score. + """ return r2_score(np.mean(x, axis=0), np.mean(y, axis=0)) def compute_sinkhorn_div(x: ArrayLike, y: ArrayLike, epsilon: float = 1e-2) -> float: - """Compute the Sinkhorn divergence between x and y.""" + """Compute the Sinkhorn divergence between x and y as in :cite:`feydy:19`. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + epsilon + The regularization parameter. + + Returns + ------- + A scalar denoting the sinkhorn divergence value. + """ return float( sinkhorn_divergence( pointcloud.PointCloud, @@ -35,15 +63,27 @@ def compute_sinkhorn_div(x: ArrayLike, y: ArrayLike, epsilon: float = 1e-2) -> f cost_fn=costs.SqEuclidean(), epsilon=epsilon, scale_cost=1.0, - ).divergence[0] + )[0] ) def compute_e_distance(x: ArrayLike, y: ArrayLike) -> float: - """Compute the energy distance as in Peidli et al.""" - sigma_X = pairwise_distances(x, x, metric="sqeuclidean").mean() - sigma_Y = pairwise_distances(y, y, metric="sqeuclidean").mean() - delta = pairwise_distances(x, y, metric="sqeuclidean").mean() + """Compute the energy distance between x and y as in :cite:`Peidli2024`. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + + Returns + ------- + A scalar denoting the energy distance value. + """ + sigma_X = pairwise_squeuclidean(x, x).mean() + sigma_Y = pairwise_squeuclidean(y, y).mean() + delta = pairwise_squeuclidean(x, y).mean() return 2 * delta - sigma_X - sigma_Y @@ -54,15 +94,43 @@ def pairwise_squeuclidean(x: ArrayLike, y: ArrayLike) -> ArrayLike: @jax.jit def compute_e_distance_fast(x: ArrayLike, y: ArrayLike) -> float: - """Compute the energy distance as in Peidli et al.""" - sigma_X = pairwise_squeuclidean(x, x).mean() - sigma_Y = pairwise_squeuclidean(y, y).mean() - delta = pairwise_squeuclidean(x, y).mean() - return 2 * delta - sigma_X - sigma_Y + """Compute the energy distance between x and y as in :cite:`Peidli2024`. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + + Returns + ------- + A scalar denoting the energy distance value. + """ + return compute_e_distance(x, y) def compute_metrics(x: ArrayLike, y: ArrayLike) -> dict[str, float]: - """Compute different metrics for x (true) and y (predicted).""" + """Compute a set of metrics between two distributions x and y. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + + Returns + ------- + A dictionary containing the following computed metrics: + + - the r squared score. + - the sinkhorn divergence with ``epsilon = 1.0``. + - the sinkhorn divergence with ``epsilon = 10.0``. + - the sinkhorn divergence with ``epsilon = 100.0``. + - the energy distance value. + - the mean maximum discrepancy loss + """ metrics = {} metrics["r_squared"] = compute_r_squared(x, y) metrics["sinkhorn_div_1"] = compute_sinkhorn_div(x, y, epsilon=1.0) @@ -74,7 +142,21 @@ def compute_metrics(x: ArrayLike, y: ArrayLike) -> dict[str, float]: def compute_mean_metrics(metrics: dict[str, dict[str, float]], prefix: str = "") -> dict[str, list[float]]: - """Compute the mean value of different metrics.""" + """Compute the mean value of different metrics. + + Parameters + ---------- + metrics + A dictionary where the keys indicate the name of the pertubations and the values are + dictionaries containing computed metrics. + prefix + A string definining the prefix of all metrics in the output dictionary. + + Returns + ------- + A dictionary where the keys indicate the metrics and the values contain the average metric + values over all pertubations. + """ metric_names = list(list(metrics.values())[0].keys()) metric_dict: dict[str, list[float]] = {prefix + met_name: [] for met_name in metric_names} for met in metric_names: @@ -95,16 +177,22 @@ def rbf_kernel_fast(x: ArrayLike, y: ArrayLike, gamma: float) -> ArrayLike: def maximum_mean_discrepancy(x: ArrayLike, y: ArrayLike, gamma: float = 1.0, exact: bool = False) -> float: - """Compute the Maximum Mean Discrepancy (MMD) between two samples: x and y. - - Args: - x: a tensor of shape [num_samples, num_features] - y: a tensor of shape [num_samples, num_features] - exact: a bool + """Compute the Maximum Mean Discrepancy (MMD) between two distributions x and y. + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + gamma + Parameter for the rbf kernel. + exact + Use exact or fast rbf kernel. Returns ------- - a scalar denoting the squared maximum mean discrepancy loss. + A scalar denoting the squared maximum mean discrepancy loss. """ kernel = rbf_kernel if exact else rbf_kernel_fast xx = kernel(x, x, gamma) @@ -113,8 +201,22 @@ def maximum_mean_discrepancy(x: ArrayLike, y: ArrayLike, gamma: float = 1.0, exa return xx.mean() + yy.mean() - 2 * xy.mean() -def compute_scalar_mmd(x: ArrayLike, y: ArrayLike, gammas: float | None = None) -> float: - """Compute MMD across different length scales""" +def compute_scalar_mmd(x: ArrayLike, y: ArrayLike, gammas: Sequence[float] | None = None) -> float: + """Compute the Mean Maximum Discrepancy (MMD) across different length scales + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + gammas + A sequence of values for the paramater gamma of the rbf kernel. + + Returns + ------- + A scalar denoting the average MMD over all gammas. + """ if gammas is None: gammas = [2, 1, 0.5, 0.1, 0.01, 0.005] mmds = [maximum_mean_discrepancy(x, y, gamma=gamma) for gamma in gammas] # type: ignore[union-attr] @@ -122,7 +224,23 @@ def compute_scalar_mmd(x: ArrayLike, y: ArrayLike, gammas: float | None = None) def compute_metrics_fast(x: ArrayLike, y: ArrayLike) -> dict[str, float]: - """Compute metrics which are fast to compute.""" + """Compute metrics which are fast to compute + + Parameters + ---------- + x + An array of shape [num_samples, num_features]. + y + An array of shape [num_samples, num_features]. + + Returns + ------- + A dictionary containing the following computed metrics: + + - the r squared score. + - the energy distance value. + - the mean maximum discrepancy loss + """ metrics = {} metrics["r_squared"] = compute_r_squared(x, y) metrics["e_distance"] = compute_e_distance(x, y)