Skip to content

Commit 500039e

Browse files
Adding main script
1 parent 82a7798 commit 500039e

File tree

8 files changed

+257
-62
lines changed

8 files changed

+257
-62
lines changed

config/benchmark_union_cpu.yml

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
dataset:
2+
data_dir: '../../data/spatial'
3+
metadata-file: 'metadata_2024_06_21.csv'
4+
tissues: ['ovarian']
5+
mode: 'union'
6+
7+
params:
8+
model_names: ['SEDR', 'STAGATE', 'SpaceFlow', 'GraphST']
9+
hidden_dim: 64
10+
batch_key: 'ID'
11+
device: 'cpu'
12+
fast_dev_run: True
13+
n_clusters: 7
14+
15+
16+
17+

main.py

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import argparse
2+
import yaml
3+
4+
from novae_benchmark import AnnDataset
5+
from novae_benchmark import get_model
6+
7+
8+
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument('-c', '--config', help='config file to use', type=str)
11+
12+
args = parser.parse_args()
13+
14+
def load_yaml(file_path):
15+
with open(file_path, 'r') as file:
16+
data = yaml.safe_load(file)
17+
return data
18+
19+
config = load_yaml(args.config)
20+
21+
print(config)
22+
23+
if __name__ == "__main__":
24+
print("------------ Loading Dataset ----------------\n")
25+
26+
print('Tissues considered : ')
27+
for tissue in config['dataset']['tissues']:
28+
print('----- ', tissue, '\n')
29+
30+
dataset = AnnDataset(data_dir='../../data/spatial', metadata_filename='metadata_2024_06_21.csv')
31+
adataset = dataset.load_data(tissue_types=[tissue], mode=config['dataset']['mode'])
32+
33+
print("------------ Dataset Loaded ! ----------------\n")
34+
35+
results = {model_name: [] for model_name in config['params']['model_names']}
36+
37+
for model_name in config['params']['model_names']:
38+
print("------------ Loading {} Model ----------------\n".format(model_name))
39+
model = get_model(model_name=model_name, hidden_dim=config['params']['hidden_dim'])
40+
print("------------ Model Loaded ! ----------------\n")
41+
if config['dataset']['mode'] == 'union':
42+
for adata in adataset:
43+
model(adata=adata, n_clusters=config['params']['n_clusters'], batch_key=config['params']['batch_key'],
44+
device=config['params']['device'], fast_dev_run=config['params']['fast_dev_run'])
45+
46+
results[model_name].append(model.model_performances)
47+
else:
48+
model(adata=adataset, n_clusters=config['params']['n_clusters'], batch_key=config['params']['batch_key'],
49+
device=config['params']['device'], fast_dev_run=config['params']['fast_dev_run'])
50+
results[model_name].append(model.model_performances)
51+
52+
print(results)
53+
54+
55+
56+
57+
58+

novae_benchmark/dataset.py

+37-20
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import os
2-
import numpy as np
32
import pandas as pd
43
import scanpy as sc
54
import anndata
@@ -10,7 +9,7 @@ def __init__(self, data_dir, metadata_filename):
109
self.metadata_file = os.path.join(data_dir, metadata_filename)
1110
self.metadata = pd.read_csv(self.metadata_file)
1211

13-
def load_data(self, tissue_types, use_common_genes=True, multi_slide=False):
12+
def load_data(self, tissue_types, mode='union'):
1413
anndata_list = []
1514
anndata_list_original = []
1615

@@ -22,7 +21,6 @@ def load_data(self, tissue_types, use_common_genes=True, multi_slide=False):
2221
adata = sc.read_h5ad(file_path)
2322
anndata_list_original.append(adata)
2423

25-
2624
# Convert gene names to lowercase to handle case insensitivity
2725
adata.var.index = adata.var.index.str.lower()
2826

@@ -32,7 +30,7 @@ def load_data(self, tissue_types, use_common_genes=True, multi_slide=False):
3230
anndata_list.append(adata)
3331

3432
if anndata_list:
35-
if use_common_genes:
33+
if mode=='inter':
3634
# Find common genes across all datasets
3735
common_genes = set(anndata_list[0].var.index)
3836
for adata in anndata_list[1:]:
@@ -41,20 +39,39 @@ def load_data(self, tissue_types, use_common_genes=True, multi_slide=False):
4139
# Filter each AnnData to include only the common genes
4240
anndata_list = [adata[:, list(common_genes)] for adata in anndata_list]
4341

44-
combined_adata = anndata.concat(
45-
anndata_list,
46-
axis=0,
47-
join='inner',
48-
label='slide_id',
49-
keys=[adata.obs['slide_id'][0] for adata in anndata_list],
50-
pairwise=True
51-
)
52-
53-
else:
54-
combined_adata = None
55-
56-
if multi_slide:
57-
return combined_adata
42+
if mode=='inter':
43+
combined_adata = anndata.concat(
44+
anndata_list,
45+
axis=0,
46+
join='inner',
47+
label='slide_id',
48+
keys=[adata.obs['slide_id'][0] for adata in anndata_list],
49+
pairwise=True
50+
)
51+
return combined_adata
52+
else:
53+
# Group by gene panels
54+
gene_panels = {}
55+
for adata in anndata_list:
56+
genes = tuple(sorted(adata.var.index))
57+
if genes not in gene_panels:
58+
gene_panels[genes] = []
59+
gene_panels[genes].append(adata)
60+
61+
# Concatenate within each group
62+
concatenated_adatas = []
63+
for genes, adatas in gene_panels.items():
64+
65+
concatenated_adata = anndata.concat(
66+
adatas,
67+
axis=0,
68+
join='inner',
69+
label='slide_id',
70+
keys=[adata.obs['slide_id'][0] for adata in adatas],
71+
pairwise=True
72+
)
73+
concatenated_adatas.append(concatenated_adata)
74+
75+
return concatenated_adatas
5876
else:
59-
return anndata_list_original
60-
77+
return None

novae_benchmark/model/GraphST/graphst_model.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,10 @@ def __init__(self,
9292

9393
fix_seed(self.random_seed)
9494

95-
if batch_key:
96-
list_adatas = [self.adata[self.adata.obs[batch_key] == b].copy() for b in self.adata.obs[batch_key].unique()]
97-
else:
95+
if batch_key is None:
9896
list_adatas = [self.adata]
97+
else:
98+
list_adatas = [self.adata[self.adata.obs[batch_key] == b].copy() for b in self.adata.obs[batch_key].unique()]
9999

100100
for adata in list_adatas:
101101
if 'highly_variable' not in adata.var.keys():

novae_benchmark/model/SpaceFlow/spaceflow_model.py

+22-12
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch.nn as nn
1212
import matplotlib.pyplot as plt
1313
from scipy.spatial import distance_matrix
14+
from scipy.sparse import block_diag
1415
from torch_geometric.nn import GCNConv, DeepGraphInfomax
1516
from sklearn.neighbors import kneighbors_graph
1617

@@ -115,7 +116,7 @@ def prepare_figure(self, rsz=4., csz=4., wspace=.4, hspace=.5, left=0.125, right
115116
plt.subplots_adjust(wspace=wspace, hspace=hspace, left=left, right=right, bottom=bottom, top=top)
116117
return fig, axs
117118

118-
def preprocessing_data(self, n_top_genes=None, n_neighbors=10):
119+
def preprocessing_data(self, n_top_genes=None, n_neighbors=10, batch_key=None):
119120
"""
120121
Preprocessing the spatial transcriptomics data
121122
Generates: `self.adata_filtered`: (n_cells, n_locations) `numpy.ndarray`
@@ -131,19 +132,28 @@ def preprocessing_data(self, n_top_genes=None, n_neighbors=10):
131132
:return: a geometry-aware spatial proximity graph of the spatial spots of cells
132133
:rtype: class:`scipy.sparse.csr_matrix`
133134
"""
134-
adata = self.adata
135-
if not adata:
135+
adatas = self.adata
136+
sc.pp.normalize_total(adatas, target_sum=1e4)
137+
sc.pp.log1p(adatas)
138+
sc.pp.highly_variable_genes(adatas, n_top_genes=n_top_genes, flavor='cell_ranger', subset=True)
139+
sc.pp.pca(adatas)
140+
if batch_key is None:
141+
list_adatas = [adatas]
142+
else:
143+
list_adatas = [adatas[self.adata.obs[batch_key] == b].copy() for b in adatas.obs[batch_key].unique()]
144+
if not adatas:
136145
print("No annData object found, please run SpaceFlow.SpaceFlow(expr_data, spatial_locs) first!")
137146
return
138-
sc.pp.normalize_total(adata, target_sum=1e4)
139-
sc.pp.log1p(adata)
140-
sc.pp.highly_variable_genes(adata, n_top_genes=n_top_genes, flavor='cell_ranger', subset=True)
141-
sc.pp.pca(adata)
142-
spatial_locs = adata.obsm['spatial']
143-
spatial_graph = self.graph_alpha(spatial_locs, n_neighbors=n_neighbors)
144-
145-
self.adata_preprocessed = adata
146-
self.spatial_graph = spatial_graph
147+
spatial_graphs = []
148+
adatas_preprocessed = []
149+
for adata in list_adatas:
150+
spatial_locs = adata.obsm['spatial']
151+
spatial_graphs.append(self.graph_alpha(spatial_locs, n_neighbors=n_neighbors))
152+
adatas_preprocessed.append(adata)
153+
154+
155+
self.adata_preprocessed = sc.concat(adatas_preprocessed)
156+
self.spatial_graph = block_diag(spatial_graphs, format='csr')
147157

148158
def graph_alpha(self, spatial_locs, n_neighbors=10):
149159
"""

novae_benchmark/model/build.py

+8-9
Original file line numberDiff line numberDiff line change
@@ -53,25 +53,24 @@ def __call__(
5353
batch_key: str | None = None,
5454
device: str = "cpu",
5555
fast_dev_run: bool = False,
56-
multi_slide:bool = False,
5756
) -> tuple[np.ndarray, pd.Series]:
5857
"""
5958
Runs all steps, i.e preprocessing -> training -> inference -> clustering.
6059
6160
Returns:
6261
A numpy array of shape (n_cells, hidden_dim) and a pandas Series with the cluster labels.
6362
"""
64-
print("--------------- {}: Preprocessing Started-------------------".format(self.model_name))
63+
print("--------------- {}: Preprocessing Started-------------------\n".format(self.model_name))
6564
self.preprocess(adata)
66-
print("--------------- {}: Preprocessing Finished-------------------".format(self.model_name))
67-
print("--------------- {}: Training Started-------------------".format(self.model_name))
65+
print("--------------- {}: Preprocessing Finished-------------------\n".format(self.model_name))
66+
print("--------------- {}: Training Started-------------------\n".format(self.model_name))
6867
self.train(adata, batch_key=batch_key, device=device, fast_dev_run=fast_dev_run)
69-
print("--------------- {}: Training Finished-------------------".format(self.model_name))
70-
print("--------------- {}: Clustering Started-------------------".format(self.model_name))
68+
print("--------------- {}: Training Finished-------------------\n".format(self.model_name))
69+
print("--------------- {}: Clustering Started-------------------\n".format(self.model_name))
7170
self.cluster(adata, n_clusters)
72-
print("--------------- {}: Clustering Finished-------------------".format(self.model_name))
71+
print("--------------- {}: Clustering Finished-------------------\n".format(self.model_name))
7372
self.evaluate(adata, batch_key, n_clusters)
74-
print("--------------- {}: Evaluation completed-------------------".format(self.model_name))
73+
print("--------------- {}: Evaluation completed-------------------\n".format(self.model_name))
7574
print(self.model_performances)
7675

7776

@@ -135,7 +134,7 @@ def preprocess(self, adata: AnnData):
135134

136135
def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False):
137136
spaceflow_net = SpaceFlow.Spaceflow(adata=adata)
138-
spaceflow_net.preprocessing_data(n_top_genes=self.N_TOP_GENES)
137+
spaceflow_net.preprocessing_data(n_top_genes=self.N_TOP_GENES, batch_key=batch_key)
139138
spaceflow_embedding = spaceflow_net.train(z_dim=self.hidden_dim, epochs=2 if fast_dev_run else 1000)
140139
adata.obsm[self.model_name] = spaceflow_embedding
141140

scripts/train.py

-10
This file was deleted.

0 commit comments

Comments
 (0)