Skip to content

Commit 3c58850

Browse files
try lung with leiden
1 parent 135a7f8 commit 3c58850

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

novae/data/dataset.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,21 @@ def _init_dataset(self):
110110

111111
def __len__(self) -> int:
112112
if self.sample_cells is not None:
113-
return min(self.sample_cells, len(self.shuffled_obs_ilocs))
113+
return min(self.sample_cells, self.n_obs)
114114

115115
if self.training:
116-
n_obs = len(self.shuffled_obs_ilocs)
117-
return min(n_obs, max(Nums.MIN_DATASET_LENGTH, int(n_obs * Nums.MAX_DATASET_LENGTH_RATIO)))
116+
return self._clipped_dataset_length()
118117

119118
assert self.single_adata, "Multi-adata mode not supported for inference"
120119

121-
return len(self.obs_ilocs)
120+
return self.n_obs
121+
122+
@property
123+
def n_obs(self):
124+
return sum(adata.n_obs for adata in self.adatas)
125+
126+
def _clipped_dataset_length(self) -> int:
127+
return min(self.n_obs, max(Nums.MIN_DATASET_LENGTH, int(self.n_obs * Nums.MAX_DATASET_LENGTH_RATIO)))
122128

123129
def __getitem__(self, index: int) -> dict[str, Data]:
124130
"""Gets a sample from the dataset, with one "main" graph and its corresponding "view" graph (only during training).

scripts/revision/missing_domains.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,11 @@
66
import novae
77
from novae._constants import Nums
88

9-
Nums.WARMUP_EPOCHS = 4
10-
Nums.WARMUP_ILOCS = 7
9+
Nums.WARMUP_EPOCHS = 3
10+
Nums.WARMUP_ILOCS = 6
1111
Nums.LEVEL_SUBSELECT = 10
1212

13-
suffix = "_sub_select11"
13+
suffix = "_sub_select12"
1414

1515
path = Path("/gpfs/workdir/blampeyq/novae/data/_lung_robustness")
1616

@@ -41,7 +41,8 @@
4141
# model.fine_tune(adatas, min_prototypes_ratio=0.5, reference="largest")
4242
# model.compute_representations(adatas)
4343

44-
obs_key = model.assign_domains(adatas, level=7)
44+
# obs_key = model.assign_domains(adatas, level=7)
45+
obs_key = model.assign_domains(adatas)
4546

4647
model.plot_prototype_weights()
4748
plt.savefig(path / f"prototype_weights{suffix}.pdf", bbox_inches="tight")

0 commit comments

Comments
 (0)