Skip to content

Commit 50d18f9

Browse files
fix stagate and sedr
1 parent 556051f commit 50d18f9

File tree

4 files changed

+50
-16
lines changed

4 files changed

+50
-16
lines changed

data/README.md

Whitespace-only changes.

novae_benchmark/model/SEDR/clustering_func.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def mclust_R(adata, n_clusters, use_rep="SEDR", key_added="SEDR", random_seed=20
5050
"""
5151
import os
5252

53-
os.environ["R_HOME"] = "/scbio4/tools/R/R-4.0.3_openblas/R-4.0.3"
53+
# os.environ["R_HOME"] = "/scbio4/tools/R/R-4.0.3_openblas/R-4.0.3"
5454
modelNames = "EEE"
5555

5656
np.random.seed(random_seed)

novae_benchmark/model/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from . import SEDR
22
from . import STAGATE_pyG as STAGATE
3-
from .trainer import MODEL_DICT, get_model
3+
from .build import MODEL_DICT, get_model

novae_benchmark/model/trainer.py novae_benchmark/model/build.py

+48-14
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
from . import SEDR, STAGATE
88

9+
DEFAULT_N_CLUSTERS = 7
10+
911

1012
class Model:
1113
def __init__(self, model_name: str, hidden_dim: int) -> None:
@@ -14,34 +16,64 @@ def __init__(self, model_name: str, hidden_dim: int) -> None:
1416
self.hidden_dim = hidden_dim
1517

1618
def preprocess(self, adata: AnnData):
19+
"""
20+
Preprocess the data before training the model. Raw counts can be found in `adata.layers["count"]`.
21+
"""
1722
adata.X = adata.layers["count"]
1823
sc.pp.normalize_total(adata)
1924
sc.pp.log1p(adata)
20-
return adata
2125

22-
def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu") -> None:
26+
def train(
27+
self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False
28+
) -> None:
29+
"""
30+
Train the model. Use `fast_dev_run` to run only a few epochs (for testing purposes).
31+
"""
2332
raise NotImplementedError
2433

2534
def inference(self, adata: AnnData) -> np.ndarray:
35+
"""
36+
Runs inference. The output should be stored in `adata.obsm[self.model_name]`.
37+
"""
2638
assert self.model_name in adata.obsm.keys()
2739

28-
def cluster(self, adata: AnnData, n_clusters: int):
40+
def cluster(self, adata: AnnData, n_clusters: int = DEFAULT_N_CLUSTERS):
41+
"""
42+
Clusters the data. The output should be stored in `adata.obs[self.model_name]`.
43+
"""
2944
raise NotImplementedError
3045

3146
def __call__(
32-
self, adata: AnnData, batch_key: str | None, n_clusters: int, device: str = "cpu"
47+
self,
48+
adata: AnnData,
49+
n_clusters: int = DEFAULT_N_CLUSTERS,
50+
batch_key: str | None = None,
51+
device: str = "cpu",
52+
fast_dev_run: bool = False,
3353
) -> tuple[np.ndarray, pd.Series]:
54+
"""
55+
Runs all steps, i.e. preprocessing -> training -> inference -> clustering.
56+
57+
Returns:
58+
A numpy array of shape (n_cells, hidden_dim) and a pandas Series with the cluster labels.
59+
"""
3460
self.preprocess(adata)
35-
self.train(adata, batch_key, device)
61+
self.train(adata, batch_key=batch_key, device=device, fast_dev_run=fast_dev_run)
3662
self.inference(adata)
3763
self.cluster(adata, n_clusters)
64+
65+
adata.obs[self.model_name] = adata.obs[self.model_name].astype("category")
66+
67+
assert adata.obsm[self.model_name].shape[1] == self.hidden_dim
68+
assert len(adata.obs[self.model_name].cat.categories) == n_clusters
69+
3870
return adata.obsm[self.model_name], adata.obs[self.model_name]
3971

4072

4173
class STAGATEModel(Model):
4274
RAD_CUTOFF = 25
4375

44-
def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
76+
def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False):
4577
if batch_key is None:
4678
STAGATE.Cal_Spatial_Net(adata, rad_cutoff=self.RAD_CUTOFF)
4779
else:
@@ -54,12 +86,14 @@ def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
5486
adata.uns["Spatial_Net"] = pd.concat([adata_.uns["Spatial_Net"] for adata_ in adatas])
5587
print("\nConcatenated:", adata)
5688

57-
adata = STAGATE.train_STAGATE(adata, key_added=self.model_name, device=device)
89+
adata = STAGATE.train_STAGATE(
90+
adata, key_added=self.model_name, device=device, n_epochs=2 if fast_dev_run else 1000
91+
)
5892
return adata
5993

60-
def cluster(self, adata: AnnData, n_clusters: int):
94+
def cluster(self, adata: AnnData, n_clusters: int = DEFAULT_N_CLUSTERS):
6195
STAGATE.mclust_R(adata, used_obsm=self.model_name, num_cluster=n_clusters)
62-
adata.obs[self.model_name] = adata.obs["m_clust"]
96+
adata.obs[self.model_name] = adata.obs["mclust"]
6397

6498

6599
class SEDRModel(Model):
@@ -74,15 +108,15 @@ def preprocess(self, adata: AnnData):
74108
def cluster(self, adata: AnnData, n_clusters: int):
75109
SEDR.mclust_R(adata, n_clusters, use_rep=self.model_name, key_added=self.model_name)
76110

77-
def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
111+
def train(self, adata: AnnData, batch_key: str | None = None, device: str = "cpu", fast_dev_run: bool = False):
78112
graph_dict = SEDR.graph_construction(adata, 6)
79113

80-
sedr_net = SEDR.Sedr(adata.obsm["X_pca"], graph_dict)
114+
sedr_net = SEDR.Sedr(adata.obsm["X_pca"], graph_dict, device=device)
81115
using_dec = True
82116
if using_dec:
83-
sedr_net.train_with_dec()
117+
sedr_net.train_with_dec(epochs=2 if fast_dev_run else 200)
84118
else:
85-
sedr_net.train_without_dec()
119+
sedr_net.train_without_dec(epochs=2 if fast_dev_run else 200)
86120
sedr_feat, _, _, _ = sedr_net.process()
87121
adata.obsm[self.model_name] = sedr_feat
88122

@@ -93,7 +127,7 @@ def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
93127
}
94128

95129

96-
def get_model(model_name: str, hidden_dim: int) -> Model:
130+
def get_model(model_name: str, hidden_dim: int = 64) -> Model:
97131
assert model_name in MODEL_DICT.keys()
98132

99133
return MODEL_DICT[model_name](model_name, hidden_dim)

0 commit comments

Comments
 (0)