-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
211 lines (179 loc) · 6.68 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import os
import numpy as np
import math
import matplotlib
import sys
import pathlib
import colorama
import inspect
from typing import Dict
def reproducible_random_choice(n: int, k: int) -> np.ndarray:
    """Draw ``k`` distinct indices from ``range(n)`` using a fixed seed.

    The draw comes from a private ``RandomState`` seeded with 42, so the
    result is identical on every call with the same ``(n, k)`` and matches
    what seeding the global NumPy generator with 42 would produce. The global
    NumPy random state is never read or modified, so it stays intact even when
    the draw fails.

    Args:
        n: Size of the population; indices are drawn from ``0 .. n - 1``.
        k: Number of distinct indices to draw.

    Returns:
        An array of ``k`` distinct indices.

    Raises:
        ValueError: Propagated from NumPy, e.g. when ``k > n``.
    """
    rng = np.random.RandomState(42)
    return rng.choice(n, k, replace=False)
def get_execute_function():
    """Build a predicate that tells a calling script whether it should run.

    The returned function takes no arguments and inspects the environment
    every time it is called:

    * If neither ``SPATIALMUON_NOTEBOOK`` nor ``SPATIALMUON_TEST`` is set it
      returns ``False``.
    * Otherwise the value of the variable that is set is the *target*, and the
      function returns ``True`` when the target equals the caller's filename
      (made relative to this module's directory). When
      ``SPATIALMUON_NOTEBOOK`` is set it also returns ``True`` for callers
      whose filename starts with ``<ipython-input-`` or ``/tmp/ipykernel_``.

    In the ``SPATIALMUON_TEST`` case the matplotlib backend is switched to
    ``Agg`` when no trace function is installed (presumably: when not running
    under a debugger).

    Returns:
        The zero-argument predicate described above.
    """

    def adjust_path(f):
        # Names that are not real files (e.g. notebook cells) pass through.
        if not os.path.isfile(f):
            return f
        parent_path = str(pathlib.Path(__file__).parent.absolute()) + "/"
        assert f.startswith(
            parent_path
        ), f'f = "{f}", parent_path = "{parent_path}"'
        # Drop only the leading prefix; str.replace would also remove any
        # later occurrence of the same directory string inside the path.
        return f[len(parent_path):]

    def execute_():
        in_notebook = "SPATIALMUON_NOTEBOOK" in os.environ
        in_test = "SPATIALMUON_TEST" in os.environ
        if not in_notebook and not in_test:
            return False

        if in_notebook:
            # The two modes are mutually exclusive.
            assert not in_test
            target = os.environ["SPATIALMUON_NOTEBOOK"]
        else:
            if sys.gettrace() is None:
                matplotlib.use("Agg")
            target = os.environ["SPATIALMUON_TEST"]

        # stack()[1] is the frame that called execute_().
        caller_filename = adjust_path(inspect.stack()[1].filename)
        print(f"target = {target}, caller_filename = {caller_filename}")
        return target == caller_filename or (
            in_notebook
            and caller_filename.startswith(("<ipython-input-", "/tmp/ipykernel_"))
        )

    return execute_
# Color chosen by the first call to print_and_return(); None until then.
path_debug = None


def print_and_return(f):
    """Return ``f`` unchanged after checking that the data root is consistent.

    Paths containing ``/a_test/`` are associated with green, all others with
    cyan. The color of the first path seen is remembered in the module-level
    ``path_debug``; every later call asserts that its color is the same one.
    Printing of the colored path is disabled by the local ``PRINT`` switch.
    """
    global path_debug

    if "/a_test/" in f:
        color = colorama.Fore.GREEN
    else:
        color = colorama.Fore.CYAN

    # A "color change" means that there is a bug: somewhere
    # get_execute_function() is called both before and after
    # 'SPATIALMUON_TEST' or 'SPATIALMUON_NOTEBOOK' are set, which mixes test
    # and default paths.
    if path_debug is not None:
        assert path_debug == color
    else:
        path_debug = color

    PRINT = False  # set to True to echo every resolved path
    if PRINT:
        print(f"{color}{f}{colorama.Fore.RESET}")
    return f
# Define file_path(f), which maps a file name to its location in the processed
# data folder. The "a_test" folder is used when SPATIALMUON_TEST is set in the
# environment, the "a" folder otherwise. If __file__ is undefined or no "data"
# entry exists, a hard-coded absolute location is used instead.
try:
    current_file_path = pathlib.Path(__file__).parent.absolute()
    # NOTE(review): "data" is looked up relative to the current working
    # directory, while the paths below are built from this module's directory
    # - confirm this is intended.
    if not os.path.exists("data"):
        raise NameError()

    def file_path(f):
        """Return the path of ``f`` inside the data folder next to this module."""
        folder = (
            "data/spatial_uzh_processed/a_test"
            if "SPATIALMUON_TEST" in os.environ
            else "data/spatial_uzh_processed/a"
        )
        return print_and_return(os.path.join(current_file_path, folder, f))

except NameError:
    print("setting data path manually")

    def file_path(f):
        """Return the path of ``f`` inside the hard-coded data folder."""
        folder = (
            "/data/l989o/data/basel_zurich/spatial_uzh_processed/a_test"
            if "SPATIALMUON_TEST" in os.environ
            else "/data/l989o/data/basel_zurich/spatial_uzh_processed/a"
        )
        return print_and_return(os.path.join(folder, f))
def get_bimap(names_length_map):
    """Build a two-way mapping between flat indices and ``(name, offset)`` pairs.

    Names are visited in the mapping's iteration order; each name contributes
    ``length`` consecutive flat indices with offsets ``0 .. length - 1``.

    Args:
        names_length_map: Mapping from a name to the number of entries it owns.

    Returns:
        A tuple ``(map_left, map_right)`` where ``map_left`` maps a flat index
        to ``(name, offset)`` and ``map_right`` is the inverse mapping.
    """
    pairs = [
        (name, offset)
        for name, length in names_length_map.items()
        for offset in range(length)
    ]
    map_left = dict(enumerate(pairs))
    map_right = {pair: index for index, pair in enumerate(pairs)}
    return map_left, map_right
def print_corrupted_entries_hash(corrupted_entries: np.ndarray, split: str):
    """Print a simple checksum of the positions where ``corrupted_entries == 1``.

    The checksum is the sum of all index coordinates (along every axis) of the
    entries equal to 1.

    Args:
        corrupted_entries: Array whose entries equal to 1 are counted.
        split: Label included in the printed message.
    """
    coordinates = np.nonzero(corrupted_entries == 1)
    h = np.concatenate(coordinates).sum()
    print(f"corrupted entries hash ({split}):", h)
def validate_and_cast_flags(flags: Dict, cast: bool):
    """Validate the experiment flags in place, optionally casting them first.

    Known flags are ``MODEL_NAME``, ``DATASET_NAME``, ``TILE_SIZE``,
    ``GRAPH_SIZE`` and ``PCA``.

    When ``cast`` is true every value must be a string. Each string is first
    checked for a well-formed textual representation and, only if all of them
    pass, converted in place (``TILE_SIZE`` to ``int`` unless it is
    ``'large'``, ``GRAPH_SIZE`` to ``int``, ``PCA`` to ``bool``). Previously
    the result of the format check was discarded, so e.g. ``PCA="maybe"`` was
    silently turned into ``False``.

    Afterwards every value is checked against its allowed set or range.

    Args:
        flags: Mapping from flag name to value; modified in place when
            ``cast`` is true.
        cast: Whether the values are strings that still need to be converted.

    Raises:
        AssertionError: On an unknown flag, a non-string value or a malformed
            string when ``cast`` is true, or a value outside its allowed
            set/range. Raised explicitly so the checks also run under ``-O``.
    """
    validate_type_functions = {
        "MODEL_NAME": lambda s: True,
        "DATASET_NAME": lambda s: True,
        # isdecimal() accepts exactly the digit strings that int() can parse.
        "TILE_SIZE": lambda s: s.isdecimal() or s == 'large',
        "GRAPH_SIZE": lambda s: s.isdecimal(),
        "PCA": lambda s: s in ["True", "False"],
    }
    cast_functions = {
        "MODEL_NAME": lambda s: s,
        "DATASET_NAME": lambda s: s,
        "TILE_SIZE": lambda s: int(s) if s.isdecimal() else s,
        "GRAPH_SIZE": int,
        "PCA": lambda s: s == 'True',
    }
    validation_functions = {
        "MODEL_NAME": lambda s: s in ["expression_vae", "image_expression_vae", "image_expression_pca_vae",
                                      "image_expression_conv_vae", "expression_gnn_vae"],
        "DATASET_NAME": lambda s: s
        in ['visium_mousebrain', 'visium_endometrium'],
        "TILE_SIZE": lambda s: s == 'large' or 32 <= int(s) <= 256,
        "GRAPH_SIZE": lambda s: 3 <= int(s) <= 50,
        "PCA": lambda s: True,
    }
    # Internal invariant: the three tables describe the same set of flags.
    assert set(validate_type_functions.keys()) == set(cast_functions.keys())
    assert set(cast_functions.keys()) == set(validation_functions.keys())

    unknown = [flag for flag in flags if flag not in validation_functions]
    if unknown:
        raise AssertionError(f"unknown flags: {unknown}")

    if cast:
        # Check everything before casting anything, so a failure leaves
        # `flags` untouched.
        for flag, value in flags.items():
            if not isinstance(value, str):
                raise AssertionError(
                    f"flag {flag} must be a string to be cast, got {value!r}"
                )
            if not validate_type_functions[flag](value):
                raise AssertionError(f"malformed value for flag {flag}: {value!r}")
        for flag in list(flags):
            flags[flag] = cast_functions[flag](flags[flag])

    for flag, value in flags.items():
        if not validation_functions[flag](value):
            raise AssertionError(f"invalid value for flag {flag}: {value!r}")
def parse_flags(default: Dict):
    """Return the experiment flags, taken from the environment when available.

    If ``SPATIALMUON_FLAGS`` is not set, ``default`` is validated (without
    casting), serialized as ``KEY=VALUE`` pairs joined by commas into
    ``SPATIALMUON_FLAGS``, and returned as is.

    If ``SPATIALMUON_FLAGS`` is set, it is parsed as comma-separated
    ``KEY=VALUE`` pairs, the values are cast and validated, and the resulting
    dictionary is returned. In that case ``default`` is not consulted: flags
    missing from the environment variable are not filled in from it.

    Args:
        default: Flags to use when the environment does not provide any.

    Returns:
        The dictionary of flags.
    """
    if "SPATIALMUON_FLAGS" not in os.environ:
        validate_and_cast_flags(default, cast=False)
        serialized = ','.join(f'{k}={v}' for k, v in default.items())
        os.environ['SPATIALMUON_FLAGS'] = serialized
        return default

    d = {}
    for flag in os.environ["SPATIALMUON_FLAGS"].split(","):
        # Each entry must look like KEY=VALUE with exactly one '='.
        assert flag.count("=") == 1
        k, _, v = flag.partition("=")
        d[k] = v
    validate_and_cast_flags(d, cast=True)
    return d
def get_splits_indices(n, ratios):
    """Split ``range(n)`` into reproducible train/validation/test index lists.

    The train split gets ``floor(n * ratios[0])`` indices and the validation
    split ``ceil(n * ratios[1])`` of the remaining ones; the test split
    receives whatever is left. Only the first two entries of ``ratios`` are
    read.

    Args:
        n: Total number of items.
        ratios: Sequence whose first two entries are the train and validation
            fractions.

    Returns:
        A dict with keys ``"train"``, ``"validation"`` and ``"test"``, each
        holding a sorted list of indices.
    """
    n_train = math.floor(n * ratios[0])
    n_validation = math.ceil(n * ratios[1])

    train_indices = sorted(reproducible_random_choice(n, n_train))

    # NOTE(review): the order of `remaining` follows set iteration order, so
    # which items end up in the validation split depends on it - confirm this
    # is acceptable.
    remaining = list(set(range(n)).difference(train_indices))
    chosen_positions = np.array(
        reproducible_random_choice(len(remaining), n_validation)
    )
    validation_indices = sorted(np.array(remaining)[chosen_positions].tolist())

    test_indices = sorted(set(remaining).difference(validation_indices))

    # Together the three splits must cover every index.
    covered = set(train_indices) | set(validation_indices) | set(test_indices)
    assert len(covered) == n

    return {
        "train": train_indices,
        "validation": validation_indices,
        "test": test_indices,
    }