Skip to content

Commit 82a7798

Browse files
Start Preparing train script
1 parent f03c920 commit 82a7798

File tree

8 files changed

+431
-14
lines changed

8 files changed

+431
-14
lines changed

novae_benchmark/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .model import MODEL_DICT, get_model
2+
from .dataset import AnnDataset

novae_benchmark/dataset.py

+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
import os
2+
import numpy as np
3+
import pandas as pd
4+
import scanpy as sc
5+
import anndata
6+
7+
class AnnDataset:
    """Loads spatial-transcriptomics ``.h5ad`` datasets listed in a metadata CSV.

    The metadata file must contain at least the columns ``tissue`` and
    ``dataset_name``; each ``dataset_name`` maps to the file
    ``<data_dir>/<dataset_name>.h5ad``.
    """

    def __init__(self, data_dir, metadata_filename):
        self.data_dir = data_dir
        self.metadata_file = os.path.join(data_dir, metadata_filename)
        self.metadata = pd.read_csv(self.metadata_file)

    def load_data(self, tissue_types, use_common_genes=True, multi_slide=False):
        """Load every dataset belonging to the requested tissues.

        Args:
            tissue_types: Iterable of tissue names to select in the metadata.
            use_common_genes: If True, restrict all slides to the genes they
                share (case-insensitive) before concatenation.
            multi_slide: If True, return one concatenated AnnData; otherwise
                return the list of unmodified AnnData objects.

        Returns:
            When ``multi_slide`` is True, a combined ``AnnData`` (or ``None``
            if nothing was loaded); otherwise the list of the original
            ``AnnData`` objects as read from disk.
        """
        anndata_list = []
        anndata_list_original = []

        for tissue in tissue_types:
            files_to_load = self.metadata[self.metadata['tissue'] == tissue]['dataset_name']

            for dataset_name in files_to_load:
                file_path = os.path.join(self.data_dir, f"{dataset_name}.h5ad")
                adata = sc.read_h5ad(file_path)
                anndata_list_original.append(adata)

                # Convert gene names to lowercase so the common-gene
                # intersection below is case-insensitive across datasets.
                adata.var.index = adata.var.index.str.lower()

                # Record which dataset each cell came from.
                adata.obs['dataset'] = dataset_name

                anndata_list.append(adata)

        if anndata_list:
            if use_common_genes:
                # Find common genes across all datasets.
                common_genes = set(anndata_list[0].var.index)
                for adata in anndata_list[1:]:
                    common_genes.intersection_update(adata.var.index)

                # Sort for a deterministic gene order — iterating a raw set
                # would give a run-dependent column order.
                ordered_genes = sorted(common_genes)

                # Filter each AnnData to include only the common genes.
                anndata_list = [adata[:, ordered_genes] for adata in anndata_list]

            combined_adata = anndata.concat(
                anndata_list,
                axis=0,
                join='inner',
                label='slide_id',
                # NOTE(review): assumes each .h5ad already carries a
                # 'slide_id' obs column — confirm against the input files.
                # .iloc[0] takes the first value positionally; the previous
                # label-based [0] breaks when the obs index has no label 0.
                keys=[adata.obs['slide_id'].iloc[0] for adata in anndata_list],
                pairwise=True
            )
        else:
            combined_adata = None

        if multi_slide:
            return combined_adata
        else:
            return anndata_list_original
60+

novae_benchmark/model/GraphST/graphst_model.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def __init__(self,
2929
lamda1 = 10,
3030
lamda2 = 1,
3131
datatype = '10X',
32-
batch_key='slide_name'
32+
batch_key=None,
3333
):
3434
'''\
3535
@@ -92,7 +92,10 @@ def __init__(self,
9292

9393
fix_seed(self.random_seed)
9494

95-
list_adatas = [self.adata[self.adata.obs[batch_key] == b].copy() for b in self.adata.obs[batch_key].unique()]
95+
if batch_key:
96+
list_adatas = [self.adata[self.adata.obs[batch_key] == b].copy() for b in self.adata.obs[batch_key].unique()]
97+
else:
98+
list_adatas = [self.adata]
9699

97100
for adata in list_adatas:
98101
if 'highly_variable' not in adata.var.keys():

novae_benchmark/model/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
from .build import MODEL_DICT, get_model
44
from . import SpaceFlow
55
from . import cluster_utils
6+
from . import eval_utils
67
from . import GraphST

novae_benchmark/model/build.py

+20-12
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from anndata import AnnData
55
from sklearn.decomposition import PCA
66

7-
from . import SEDR, STAGATE, SpaceFlow, cluster_utils, GraphST
7+
from . import SEDR, STAGATE, SpaceFlow, cluster_utils, eval_utils, GraphST
88

99
DEFAULT_N_CLUSTERS = 7
1010
DEFAULT_RADIUS_CLUSTERS = 50
@@ -32,14 +32,18 @@ def train(
3232
"""
3333
raise NotImplementedError
3434

35-
def cluster(self, adata: AnnData, n_clusters: int = DEFAULT_N_CLUSTERS,
36-
radius : int = DEFAULT_RADIUS_CLUSTERS, method: str = "mclust",
37-
pca: bool = False):
35+
def cluster(self, adata: AnnData, n_clusters: int = DEFAULT_N_CLUSTERS,
36+
method: str = "mclust", pca: bool = False):
3837
"""
3938
Clusters the data. The output should be stored in `adata.obs[self.model_name]`.
4039
"""
4140
cluster_utils.clustering(adata=adata, model_name=self.model_name,
42-
n_clusters=n_clusters, method=method, radius=radius, pca=pca)
41+
n_clusters=n_clusters, method=method, pca=pca)
42+
43+
def evaluate(self, adata: AnnData, batch_key: str | None = None, n_clusters: int = DEFAULT_N_CLUSTERS,
44+
n_top_genes: int =3):
45+
self.model_performances = eval_utils.evaluate_latent(adatas=adata, obs_key=self.model_name, slide_key=batch_key,
46+
n_classes=n_clusters, n_top_genes=n_top_genes)
4347

4448

4549
def __call__(
@@ -49,16 +53,27 @@ def __call__(
4953
batch_key: str | None = None,
5054
device: str = "cpu",
5155
fast_dev_run: bool = False,
56+
multi_slide:bool = False,
5257
) -> tuple[np.ndarray, pd.Series]:
5358
"""
5459
Runs all steps, i.e preprocessing -> training -> inference -> clustering.
5560
5661
Returns:
5762
A numpy array of shape (n_cells, hidden_dim) and a pandas Series with the cluster labels.
5863
"""
64+
print("--------------- {}: Preprocessing Started-------------------".format(self.model_name))
5965
self.preprocess(adata)
66+
print("--------------- {}: Preprocessing Finished-------------------".format(self.model_name))
67+
print("--------------- {}: Training Started-------------------".format(self.model_name))
6068
self.train(adata, batch_key=batch_key, device=device, fast_dev_run=fast_dev_run)
69+
print("--------------- {}: Training Finished-------------------".format(self.model_name))
70+
print("--------------- {}: Clustering Started-------------------".format(self.model_name))
6171
self.cluster(adata, n_clusters)
72+
print("--------------- {}: Clustering Finished-------------------".format(self.model_name))
73+
self.evaluate(adata, batch_key, n_clusters)
74+
print("--------------- {}: Evaluation completed-------------------".format(self.model_name))
75+
print(self.model_performances)
76+
6277

6378
adata.obs[self.model_name] = adata.obs[self.model_name].astype("category")
6479

@@ -88,10 +103,6 @@ def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu
88103
adata, key_added=self.model_name, device=device, n_epochs=2 if fast_dev_run else 1000
89104
)
90105

91-
#def cluster(self, adata: AnnData, n_clusters: int = DEFAULT_N_CLUSTERS):
92-
# STAGATE.mclust_R(adata, used_obsm=self.model_name, num_cluster=n_clusters)
93-
# adata.obs[self.model_name] = adata.obs["mclust"]
94-
95106

96107
class SEDRModel(Model):
97108
def preprocess(self, adata: AnnData):
@@ -102,8 +113,6 @@ def preprocess(self, adata: AnnData):
102113
adata_X = PCA(n_components=200, random_state=42).fit_transform(adata.X)
103114
adata.obsm["X_pca"] = adata_X
104115

105-
#def cluster(self, adata: AnnData, n_clusters: int):
106-
# SEDR.mclust_R(adata, n_clusters, use_rep=self.model_name, key_added=self.model_name)
107116

108117
def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False):
109118
graph_dict = SEDR.graph_construction(adata, 6)
@@ -135,7 +144,6 @@ class GraphSTModel(Model):
135144
def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False):
136145
graphst_net = GraphST.Graphst(adata=adata, device=device, epochs=2 if fast_dev_run else 1000)
137146
adata = graphst_net.train()
138-
139147

140148

141149
MODEL_DICT = {

novae_benchmark/model/eval_utils.py

+174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
from __future__ import annotations
2+
3+
import numpy as np
4+
import scanpy as sc
5+
from anndata import AnnData
6+
from sklearn import metrics
7+
8+
9+
# Names of the metrics computed by `evaluate_latent`.
ALL_METRICS = ["FIDE", "JSD", "SVG"]
# Key of `adata.obsp` holding the spatial-neighbors adjacency matrix.
ADJ = "spatial_distances"
# Small constant added inside log() to avoid log(0).
EPS = 1e-8
12+
13+
14+
def mean_fide_score(
    adatas: AnnData | list[AnnData], obs_key: str, slide_key: str = None, n_classes: int | None = None
) -> float:
    """Average the FIDE score over every slide. A low score indicates a great domain continuity.

    Args:
        adatas: An `AnnData` object, or a list of `AnnData` objects.
        obs_key: Key of `adata.obs` containing the predicted domain labels.
        slide_key: Optional key of `adata.obs` identifying the slide of each cell.
        n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison.

    Returns:
        The FIDE score averaged over all slides.
    """
    per_slide_scores = [
        fide_score(slide, obs_key, n_classes=n_classes)
        for slide in _iter_uid(adatas, slide_key=slide_key)
    ]
    return np.mean(per_slide_scores)
31+
32+
33+
def fide_score(adata: AnnData, obs_key: str, n_classes: int | None = None) -> float:
    """F1-score of intra-domain edges (FIDE). A high score indicates a great domain continuity.

    Note:
        The F1-score is computed for every class, then all F1-scores are averaged. If some classes
        are not predicted, the `n_classes` argument allows to pad with zeros before averaging the F1-scores.

    Args:
        adata: An `AnnData` object
        obs_key: Key of `adata.obs` containing the predicted domain labels.
        n_classes: Optional number of classes. This can be useful if not all classes are predicted, for a fair comparison.

    Returns:
        The FIDE score.
    """
    adata.obs[obs_key] = adata.obs[obs_key].astype("category")

    # Each nonzero entry of the adjacency matrix is an edge between two neighboring cells.
    i_left, i_right = adata.obsp[ADJ].nonzero()
    classes_left, classes_right = adata.obs.iloc[i_left][obs_key].values, adata.obs.iloc[i_right][obs_key].values

    # Drop edges touching a cell with a missing label.
    where_valid = ~classes_left.isna() & ~classes_right.isna()
    classes_left, classes_right = classes_left[where_valid], classes_right[where_valid]

    # Per-class F1 of predicting the right-hand label from the left-hand label along edges.
    f1_scores = metrics.f1_score(classes_left, classes_right, average=None)

    if n_classes is None:
        return f1_scores.mean()

    # Bug fix: `{n_classes:=}` was parsed as a format spec (printing just the
    # number); the debug syntax `{n_classes=}` was intended.
    assert n_classes >= len(f1_scores), f"Expected {n_classes=}, but found {len(f1_scores)}, which is greater"

    # Pad missing classes with a zero F1 before averaging.
    return np.pad(f1_scores, (0, n_classes - len(f1_scores))).mean()
64+
65+
66+
def jensen_shannon_divergence(adatas: AnnData | list[AnnData], obs_key: str, slide_key: str = None) -> float:
    """Jensen-Shannon divergence (JSD) over all slides.

    Args:
        adatas: One or a list of AnnData object(s)
        obs_key: Key of `adata.obs` containing the cluster labels.
        slide_key: Optional key of `adata.obs` identifying the slide of each cell.

    Returns:
        The Jensen-Shannon divergence score for all slides
    """
    # One cluster-count vector per slide; `_iter_uid` unifies the categories
    # first so the vectors align on the same cluster order.
    counts_per_slide = []
    for slide in _iter_uid(adatas, slide_key=slide_key, obs_key=obs_key):
        counts_per_slide.append(slide.obs[obs_key].value_counts(sort=False).values)

    return _jensen_shannon_divergence(np.array(counts_per_slide))
83+
84+
85+
def mean_svg_score(adata: AnnData | list[AnnData], obs_key: str, slide_key: str = None, n_top_genes: int = 3) -> float:
    """Mean SVG score over all slides. A high score indicates better niche-specific genes, or spatially variable genes.

    Args:
        adata: An `AnnData` object, or a list.
        obs_key: Key of `adata.obs` containing the niche labels.
        slide_key: Optional key of `adata.obs` identifying the slide of each cell.
        n_top_genes: Number of top differentially expressed genes considered per niche.

    Returns:
        The mean SVG score across all slides.
    """
    scores = [
        svg_score(slide, obs_key, n_top_genes=n_top_genes)
        for slide in _iter_uid(adata, slide_key=slide_key)
    ]
    return np.mean(scores)
100+
101+
102+
def svg_score(adata: AnnData, obs_key: str, n_top_genes: int = 3) -> float:
    """Average score of the top differentially expressed genes for each niche.

    Args:
        adata: An `AnnData` object
        obs_key: Key of `adata.obs` containing the niche labels.
        n_top_genes: Number of top-ranked genes averaged per niche.

    Returns:
        The average SVG score.
    """
    # Ranks genes per group; results land in `adata.uns["rank_genes_groups"]`.
    sc.tl.rank_genes_groups(adata, groupby=obs_key)
    top_scores: np.recarray = adata.uns["rank_genes_groups"]["scores"][:n_top_genes]
    # One recarray field per group: average the top-gene scores of every group.
    per_group_means = [top_scores[group].mean() for group in top_scores.dtype.names]
    return np.mean(per_group_means)
116+
117+
118+
def _jensen_shannon_divergence(distributions: np.ndarray) -> float:
    """Compute the Jensen-Shannon divergence (JSD) for multiple probability distributions.

    The lower the score, the better distribution of clusters among the different batches.

    Args:
        distributions: An array of shape (B x C), where B is the number of batches, and C is the number of clusters. For each batch, it contains the percentage of each cluster among cells.

    Returns:
        A float corresponding to the JSD
    """
    # Normalize every row into a probability distribution.
    probabilities = distributions / distributions.sum(1)[:, None]
    mixture = np.mean(probabilities, 0)

    # JSD = H(mixture) - mean of the individual entropies.
    mean_individual_entropy = np.mean([_entropy(p) for p in probabilities])
    return _entropy(mixture) - mean_individual_entropy
133+
134+
135+
def _entropy(distribution: np.ndarray) -> float:
    """Shannon entropy of a probability distribution.

    Args:
        distribution: An array of probabilities (should sum to one)

    Returns:
        The Shannon entropy
    """
    # EPS guards against log(0) for zero-probability entries.
    log_probs = np.log(distribution + EPS)
    return -(distribution * log_probs).sum()
145+
146+
147+
def _iter_uid(adatas: AnnData | list[AnnData], slide_key: str | None = None, obs_key: str | None = None):
    """Yield one AnnData per slide.

    If `slide_key` is given, each unique slide of each AnnData is yielded as a
    copy; otherwise the AnnData objects themselves are yielded. When `obs_key`
    is given, its categories are first unified across all objects so that
    downstream per-slide `value_counts` vectors align on the same category set.
    """
    if isinstance(adatas, AnnData):
        adatas = [adatas]

    if obs_key is not None:
        union = set.union(*[set(adata.obs[obs_key].astype("category").cat.categories) for adata in adatas])
        # Bug fix: passing a raw `set` to `set_categories` yields an arbitrary
        # (run-dependent) category order and may be rejected by pandas as
        # non-list-like; sort it for a deterministic, shared order.
        categories = sorted(union)
        for adata in adatas:
            adata.obs[obs_key] = adata.obs[obs_key].astype("category").cat.set_categories(categories)

    for adata in adatas:
        if slide_key is not None:
            for slide_id in adata.obs[slide_key].unique():
                yield adata[adata.obs[slide_key] == slide_id].copy()
        else:
            yield adata
162+
163+
164+
def evaluate_latent(adatas: AnnData | list[AnnData],
                    obs_key: str, slide_key: str = None,
                    n_classes: int | None = None, n_top_genes: int = 3):
    """Compute all benchmark metrics for the labels in `adata.obs[obs_key]`.

    Args:
        adatas: One or a list of AnnData object(s).
        obs_key: Key of `adata.obs` containing the predicted labels.
        slide_key: Optional key of `adata.obs` identifying the slide of each cell.
        n_classes: Optional number of classes (for a fair FIDE comparison).
        n_top_genes: Number of top genes per niche used by the SVG score.

    Returns:
        A dict with keys "FIDE", "JSD" and "SVG".
    """
    return {
        "FIDE": mean_fide_score(adatas=adatas, obs_key=obs_key, slide_key=slide_key, n_classes=n_classes),
        "JSD": jensen_shannon_divergence(adatas=adatas, obs_key=obs_key, slide_key=slide_key),
        "SVG": mean_svg_score(adata=adatas, obs_key=obs_key, slide_key=slide_key, n_top_genes=n_top_genes),
    }
172+
173+
174+

scripts/train.py

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from novae_benchmark import get_model, AnnDataset
2+
3+
4+
def train(tissue_types, model_names, data_dir='../../data/spatial', metadata_filename='metadata_2024_06_21.csv', hidden_dim=64, multi_slide=False):
    """Load the requested tissues and instantiate each benchmark model.

    Args:
        tissue_types: Tissue names to load from the metadata CSV.
        model_names: Names of the models to instantiate (keys of MODEL_DICT).
        data_dir: Directory containing the `.h5ad` files and metadata CSV.
        metadata_filename: Name of the metadata CSV inside `data_dir`.
        hidden_dim: Latent dimension passed to each model.
        multi_slide: Forwarded to `AnnDataset.load_data` — if True, a single
            concatenated AnnData is loaded instead of a list.
    """
    dataset = AnnDataset(data_dir=data_dir, metadata_filename=metadata_filename)
    # Bug fix: `multi_slide` was accepted but never forwarded to `load_data`.
    adataset = dataset.load_data(tissue_types=tissue_types, multi_slide=multi_slide)

    for model_name in model_names:
        # Bug fix: previously passed the whole `model_names` list instead of
        # the current loop variable.
        model = get_model(model_name=model_name, hidden_dim=hidden_dim)
        # TODO(review): training/evaluation on `adataset` not wired up yet
        # (commit message: "Start Preparing train script").
10+

0 commit comments

Comments
 (0)