|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import scanpy as sc |
| 5 | +from anndata import AnnData |
| 6 | +from sklearn import metrics |
| 7 | + |
| 8 | + |
# Names of the available metrics; these are the keys of the dict returned by `evaluate_latent`.
ALL_METRICS = ["FIDE", "JSD", "SVG"]
# Key of `adata.obsp` holding the adjacency matrix whose non-zero entries are read as edges by `fide_score`.
ADJ = "spatial_distances"
# Small constant added inside the logarithm of `_entropy`, so that zero probabilities do not produce log(0).
EPS = 1e-8
| 12 | + |
| 13 | + |
def mean_fide_score(
    adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None, n_classes: int | None = None
) -> float:
    """Mean FIDE score over all slides. A high score indicates a great domain continuity.

    Args:
        adatas: An `AnnData` object, or a list of `AnnData` objects.
        obs_key: Key of `adata.obs` containing the domain label of each observation.
        slide_key: Optional key of `adata.obs` containing the slide ID of each observation. When provided, each `AnnData` object is split into one subset per slide; otherwise each `AnnData` object is scored as a whole.
        n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison.

    Returns:
        The FIDE score averaged over all slides.
    """
    return np.mean(
        [fide_score(adata, obs_key, n_classes=n_classes) for adata in _iter_uid(adatas, slide_key=slide_key)]
    )
| 31 | + |
| 32 | + |
def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> float:
    """F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity.

    Note:
        The F1-score is computed for every class, then all F1-scores are averaged. If some classes
        are not predicted, the `n_classes` argument allows to pad with zeros before averaging the F1-scores.

    Args:
        adata: An `AnnData` object. Its adjacency matrix is read from `adata.obsp["spatial_distances"]`.
        obs_key: Key of `adata.obs` containing the domain label of each observation.
        n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison.

    Returns:
        The FIDE score.

    Raises:
        AssertionError: If `n_classes` is lower than the number of classes that received an F1-score.
    """
    # Use a local categorical series instead of overwriting `adata.obs[obs_key]`,
    # so that computing the metric does not modify the caller's object.
    labels = adata.obs[obs_key].astype("category")

    # Every non-zero entry (i, j) of the adjacency matrix is an edge between observations i and j.
    i_left, i_right = adata.obsp[ADJ].nonzero()
    classes_left, classes_right = labels.iloc[i_left].values, labels.iloc[i_right].values

    # Drop the edges for which at least one endpoint has a missing label.
    where_valid = ~classes_left.isna() & ~classes_right.isna()
    classes_left, classes_right = classes_left[where_valid], classes_right[where_valid]

    # One F1-score per class: the labels on one side of the edges are scored against the other side.
    f1_scores = metrics.f1_score(classes_left, classes_right, average=None)

    if n_classes is None:
        return f1_scores.mean()

    assert n_classes >= len(f1_scores), f"Expected {n_classes=}, but found {len(f1_scores)}, which is greater"

    # Classes that were never predicted count as an F1-score of zero.
    return np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()
| 64 | + |
| 65 | + |
def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, slide_key: str | None = None) -> float:
    """Jensen-Shannon divergence (JSD) over all slides. The lower the score, the more similar the label distributions of the slides.

    Args:
        adatas: One or a list of AnnData object(s)
        obs_key: Key of `adata.obs` containing the label of each observation.
        slide_key: Optional key of `adata.obs` containing the slide ID of each observation. When provided, each `AnnData` object is split into one subset per slide.

    Returns:
        The Jensen-Shannon divergence score for all slides
    """
    # One {label: count} mapping per slide.
    counts_per_slide = []
    for adata in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key):
        counts = adata.obs[obs_key].value_counts(sort=False)
        counts_per_slide.append(dict(zip(counts.index, counts.values)))

    # Align the counts on the union of all labels (first-seen order). A slide subset may have
    # lost the categories it does not contain, so stacking the raw count vectors positionally
    # could produce rows of different lengths or misaligned columns.
    labels = list(dict.fromkeys(label for counts in counts_per_slide for label in counts))
    distributions = np.array([[counts.get(label, 0) for label in labels] for counts in counts_per_slide])

    return _jensen_shannon_divergence(distributions)
| 83 | + |
| 84 | + |
def mean_svg_score(adata: AnnData | list[AnnData], obs_key: str, slide_key: str = None, n_top_genes: int = 3) -> float:
    """Average the SVG score over every slide. Higher values indicate better niche-specific (spatially variable) genes.

    Args:
        adata: A single `AnnData` object, or a list of them.
        obs_key: Key of `adata.obs` used to group the observations when ranking genes.
        slide_key: Optional key of `adata.obs` containing slide IDs. When given, each `AnnData` object is split per slide before scoring.
        n_top_genes: Number of top-ranked genes kept per group when computing each slide's score.

    Returns:
        The mean of the per-slide SVG scores.
    """
    slides = _iter_uid(adata, slide_key=slide_key)
    per_slide_scores = [svg_score(slide, obs_key, n_top_genes=n_top_genes) for slide in slides]
    return np.mean(per_slide_scores)
| 100 | + |
| 101 | + |
def svg_score(adata: AnnData, obs_key: str, n_top_genes: int = 3) -> float:
    """Mean score of the top-ranked genes of each group, as ranked by `scanpy.tl.rank_genes_groups`.

    Args:
        adata: An `AnnData` object. The ranking is read back from `adata.uns["rank_genes_groups"]`.
        obs_key: Key of `adata.obs` passed as `groupby` to `scanpy.tl.rank_genes_groups`.
        n_top_genes: Number of leading rows of the ranking scores kept for each group.

    Returns:
        The average, over groups, of the mean score of each group's top genes.
    """
    sc.tl.rank_genes_groups(adata, groupby=obs_key)

    ranking_scores = adata.uns["rank_genes_groups"]["scores"]
    # Record array with one field per group; keep only the first `n_top_genes` rows.
    top_scores: np.recarray = ranking_scores[:n_top_genes]

    group_means = [top_scores[group].mean() for group in top_scores.dtype.names]
    return np.mean(group_means)
| 116 | + |
| 117 | + |
def _jensen_shannon_divergence(distributions: np.ndarray) -> float:
    """Jensen-Shannon divergence (JSD) between several distributions.

    It is computed as the entropy of the mean distribution minus the mean of the individual entropies.
    The lower the score, the better the distribution of clusters among the different batches.

    Args:
        distributions: An array of shape (B x C), where B is the number of batches, and C is the number of clusters. Each row is normalized to sum to one before the divergence is computed.

    Returns:
        A float corresponding to the JSD
    """
    # Normalize every row so that it sums to one.
    row_totals = distributions.sum(axis=1, keepdims=True)
    probabilities = distributions / row_totals

    entropy_of_mean = _entropy(probabilities.mean(axis=0))
    mean_of_entropies = np.mean([_entropy(row) for row in probabilities])

    return entropy_of_mean - mean_of_entropies
| 133 | + |
| 134 | + |
def _entropy(distribution: np.ndarray) -> float:
    """Shannon entropy (natural logarithm) of a distribution.

    Args:
        distribution: An array of probabilities (should sum to one)

    Returns:
        The Shannon entropy
    """
    # EPS is added inside the logarithm so that zero probabilities do not yield log(0).
    log_probabilities = np.log(distribution + EPS)
    return -np.sum(distribution * log_probabilities)
| 145 | + |
| 146 | + |
def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None = None, obs_key: str | None = None):
    """Iterate over the slides of one or several `AnnData` objects.

    Args:
        adatas: An `AnnData` object, or a list of `AnnData` objects.
        slide_key: Optional key of `adata.obs` containing slide IDs. When provided, one copied subset is yielded per unique slide ID of each object; otherwise each object is yielded as is.
        obs_key: Optional key of `adata.obs`. When provided, this column is converted in place to a categorical sharing the same categories across all objects.

    Yields:
        One `AnnData` object per slide.
    """
    if isinstance(adatas, AnnData):
        adatas = [adatas]

    if obs_key is not None:
        # Ordered union of the categories (first-seen order). Unlike a `set`, this keeps the
        # category order deterministic between runs, and it also works on an empty list.
        categories = list(
            dict.fromkeys(
                category
                for adata in adatas
                for category in adata.obs[obs_key].astype("category").cat.categories
            )
        )
        for adata in adatas:
            adata.obs[obs_key] = adata.obs[obs_key].astype("category").cat.set_categories(categories)

    for adata in adatas:
        if slide_key is None:
            yield adata
            continue

        for slide_id in adata.obs[slide_key].unique():
            yield adata[adata.obs[slide_key] == slide_id].copy()
| 162 | + |
| 163 | + |
def evaluate_latent(
    adatas: AnnData | list[AnnData],
    obs_key: str,
    slide_key: str = None,
    n_classes: int | None = None,
    n_top_genes: int = 3,
):
    """Compute the FIDE, JSD and SVG metrics in one call.

    Args:
        adatas: An `AnnData` object, or a list of `AnnData` objects.
        obs_key: Key of `adata.obs` forwarded to every metric.
        slide_key: Optional key of `adata.obs` containing slide IDs, forwarded to every metric.
        n_classes: Optional number of classes, forwarded to `mean_fide_score`.
        n_top_genes: Number of top genes, forwarded to `mean_svg_score`.

    Returns:
        A dict with the keys `"FIDE"`, `"JSD"` and `"SVG"` mapping to the corresponding scores.
    """
    # The entries are evaluated in this order: FIDE, then JSD, then SVG.
    return {
        "FIDE": mean_fide_score(adatas=adatas, obs_key=obs_key, slide_key=slide_key, n_classes=n_classes),
        "JSD": jensen_shannon_divergence(adatas=adatas, obs_key=obs_key, slide_key=slide_key),
        "SVG": mean_svg_score(adata=adatas, obs_key=obs_key, slide_key=slide_key, n_top_genes=n_top_genes),
    }
| 172 | + |
| 173 | + |
| 174 | + |
0 commit comments