6
6
7
7
from . import SEDR , STAGATE
8
8
9
+ DEFAULT_N_CLUSTERS = 7
10
+
9
11
10
12
class Model :
11
13
def __init__ (self , model_name : str , hidden_dim : int ) -> None :
@@ -14,34 +16,64 @@ def __init__(self, model_name: str, hidden_dim: int) -> None:
14
16
self .hidden_dim = hidden_dim
15
17
16
18
def preprocess (self , adata : AnnData ):
19
+ """
20
+ Preprocess the data before training the model. Raw counts can be found in `adata.layers["count"]`
21
+ """
17
22
adata .X = adata .layers ["count" ]
18
23
sc .pp .normalize_total (adata )
19
24
sc .pp .log1p (adata )
20
- return adata
21
25
22
- def train (self , adata : AnnData , batch_key : str | None , device : str = "cpu" ) -> None :
26
+ def train (
27
+ self , adata : AnnData , batch_key : str | None = None , device : str = "cpu" , fast_dev_run : bool = False
28
+ ) -> None :
29
+ """
30
+ Train the model. Use `fast_dev_run` to run only a few epochs (for testing purposes).
31
+ """
23
32
raise NotImplementedError
24
33
25
34
def inference (self , adata : AnnData ) -> np .ndarray :
35
+ """
36
+ Runs inference. The output should be stored in `adata.obsm[self.model_name]`.
37
+ """
26
38
assert self .model_name in adata .obsm .keys ()
27
39
28
- def cluster (self , adata : AnnData , n_clusters : int ):
40
+ def cluster (self , adata : AnnData , n_clusters : int = DEFAULT_N_CLUSTERS ):
41
+ """
42
+ Clusters the data. The output should be stored in `adata.obs[self.model_name]`.
43
+ """
29
44
raise NotImplementedError
30
45
31
46
def __call__ (
32
- self , adata : AnnData , batch_key : str | None , n_clusters : int , device : str = "cpu"
47
+ self ,
48
+ adata : AnnData ,
49
+ n_clusters : int = DEFAULT_N_CLUSTERS ,
50
+ batch_key : str | None = None ,
51
+ device : str = "cpu" ,
52
+ fast_dev_run : bool = False ,
33
53
) -> tuple [np .ndarray , pd .Series ]:
54
+ """
55
+ Runs all steps, i.e preprocessing -> training -> inference -> clustering.
56
+
57
+ Returns:
58
+ A numpy array of shape (n_cells, hidden_dim) and a pandas Series with the cluster labels.
59
+ """
34
60
self .preprocess (adata )
35
- self .train (adata , batch_key , device )
61
+ self .train (adata , batch_key = batch_key , device = device , fast_dev_run = fast_dev_run )
36
62
self .inference (adata )
37
63
self .cluster (adata , n_clusters )
64
+
65
+ adata .obs [self .model_name ] = adata .obs [self .model_name ].astype ("category" )
66
+
67
+ assert adata .obsm [self .model_name ].shape [1 ] == self .hidden_dim
68
+ assert len (adata .obs [self .model_name ].cat .categories ) == n_clusters
69
+
38
70
return adata .obsm [self .model_name ], adata .obs [self .model_name ]
39
71
40
72
41
73
class STAGATEModel (Model ):
42
74
RAD_CUTOFF = 25
43
75
44
- def train (self , adata : AnnData , batch_key : str | None , device : str = "cpu" ):
76
+ def train (self , adata : AnnData , batch_key : str | None = None , device : str = "cpu" , fast_dev_run : bool = False ):
45
77
if batch_key is None :
46
78
STAGATE .Cal_Spatial_Net (adata , rad_cutoff = self .RAD_CUTOFF )
47
79
else :
@@ -54,12 +86,14 @@ def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
54
86
adata .uns ["Spatial_Net" ] = pd .concat ([adata_ .uns ["Spatial_Net" ] for adata_ in adatas ])
55
87
print ("\n Concatenated:" , adata )
56
88
57
- adata = STAGATE .train_STAGATE (adata , key_added = self .model_name , device = device )
89
+ adata = STAGATE .train_STAGATE (
90
+ adata , key_added = self .model_name , device = device , n_epochs = 2 if fast_dev_run else 1000
91
+ )
58
92
return adata
59
93
60
- def cluster (self , adata : AnnData , n_clusters : int ):
94
+ def cluster (self , adata : AnnData , n_clusters : int = DEFAULT_N_CLUSTERS ):
61
95
STAGATE .mclust_R (adata , used_obsm = self .model_name , num_cluster = n_clusters )
62
- adata .obs [self .model_name ] = adata .obs ["m_clust " ]
96
+ adata .obs [self .model_name ] = adata .obs ["mclust " ]
63
97
64
98
65
99
class SEDRModel (Model ):
@@ -74,15 +108,15 @@ def preprocess(self, adata: AnnData):
74
108
def cluster (self , adata : AnnData , n_clusters : int ):
75
109
SEDR .mclust_R (adata , n_clusters , use_rep = self .model_name , key_added = self .model_name )
76
110
77
- def train (self , adata : AnnData , batch_key : str | None , device : str = "cpu" ):
111
+ def train (self , adata : AnnData , batch_key : str | None = None , device : str = "cpu" , fast_dev_run : bool = False ):
78
112
graph_dict = SEDR .graph_construction (adata , 6 )
79
113
80
- sedr_net = SEDR .Sedr (adata .obsm ["X_pca" ], graph_dict )
114
+ sedr_net = SEDR .Sedr (adata .obsm ["X_pca" ], graph_dict , device = device )
81
115
using_dec = True
82
116
if using_dec :
83
- sedr_net .train_with_dec ()
117
+ sedr_net .train_with_dec (epochs = 2 if fast_dev_run else 200 )
84
118
else :
85
- sedr_net .train_without_dec ()
119
+ sedr_net .train_without_dec (epochs = 2 if fast_dev_run else 200 )
86
120
sedr_feat , _ , _ , _ = sedr_net .process ()
87
121
adata .obsm [self .model_name ] = sedr_feat
88
122
@@ -93,7 +127,7 @@ def train(self, adata: AnnData, batch_key: str | None, device: str = "cpu"):
93
127
}
94
128
95
129
96
- def get_model (model_name : str , hidden_dim : int ) -> Model :
130
+ def get_model (model_name : str , hidden_dim : int = 64 ) -> Model :
97
131
assert model_name in MODEL_DICT .keys ()
98
132
99
133
return MODEL_DICT [model_name ](model_name , hidden_dim )
0 commit comments