# utils.py
import importlib
import io
import logging
import os
import shutil
import sys
import uuid
import h5py
import matplotlib.pyplot as plt
import numpy as np
import scipy.sparse as sparse
import torch
from PIL import Image
from sklearn.decomposition import PCA
plt.ioff()
plt.switch_backend('agg')
def save_checkpoint(state, is_best, checkpoint_dir, logger=None):
"""Saves model and training parameters at '{checkpoint_dir}/last_checkpoint.pytorch'.
    If is_best==True, saves '{checkpoint_dir}/best_checkpoint.pytorch' as well.
Args:
state (dict): contains model's state_dict, optimizer's state_dict, epoch
and best evaluation metric value so far
is_best (bool): if True state contains the best model seen so far
        checkpoint_dir (string): directory where the checkpoints are to be saved
"""
def log_info(message):
if logger is not None:
logger.info(message)
if not os.path.exists(checkpoint_dir):
log_info(
f"Checkpoint directory does not exists. Creating {checkpoint_dir}")
os.mkdir(checkpoint_dir)
last_file_path = os.path.join(checkpoint_dir, 'last_checkpoint.pytorch')
log_info(f"Saving last checkpoint to '{last_file_path}'")
torch.save(state, last_file_path)
if is_best:
best_file_path = os.path.join(checkpoint_dir, 'best_checkpoint.pytorch')
log_info(f"Saving best checkpoint to '{best_file_path}'")
shutil.copyfile(last_file_path, best_file_path)
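# Example usage (an illustrative sketch, not part of the original module; `model`,
# `optimizer`, `epoch`, `eval_score` and `best_eval_score` are assumed to come from
# the surrounding training loop):
#
#   is_best = eval_score > best_eval_score
#   save_checkpoint({
#       'epoch': epoch + 1,
#       'model_state_dict': model.state_dict(),
#       'optimizer_state_dict': optimizer.state_dict(),
#       'best_eval_score': max(eval_score, best_eval_score)
#   }, is_best, checkpoint_dir='checkpoints')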
def load_checkpoint(checkpoint_path, model, optimizer=None):
"""Loads model and training parameters from a given checkpoint_path
If optimizer is provided, loads optimizer's state_dict of as well.
Args:
checkpoint_path (string): path to the checkpoint to be loaded
model (torch.nn.Module): model into which the parameters are to be copied
optimizer (torch.optim.Optimizer) optional: optimizer instance into
which the parameters are to be copied
Returns:
state
"""
if not os.path.exists(checkpoint_path):
raise IOError(f"Checkpoint '{checkpoint_path}' does not exist")
state = torch.load(checkpoint_path)
model.load_state_dict(state['model_state_dict'])
if optimizer is not None:
optimizer.load_state_dict(state['optimizer_state_dict'])
return state
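# Example usage (sketch; assumes `model` and `optimizer` were built with the same
# configuration as the checkpointed ones):
#
#   state = load_checkpoint('checkpoints/last_checkpoint.pytorch', model, optimizer)
#   start_epoch = state['epoch']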
def save_network_output(output_path, output, logger=None):
if logger is not None:
logger.info(f'Saving network output to: {output_path}...')
output = output.detach().cpu()[0]
with h5py.File(output_path, 'w') as f:
f.create_dataset('predictions', data=output, compression='gzip')
def get_logger(name, level=logging.INFO):
logger = logging.getLogger(name)
logger.setLevel(level)
# Logging to console
stream_handler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
'%(asctime)s [%(threadName)s] %(levelname)s %(name)s - %(message)s')
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
return logger
def get_number_of_learnable_parameters(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
return sum([np.prod(p.size()) for p in model_parameters])
class RunningAverage:
"""Computes and stores the average
"""
def __init__(self):
self.count = 0
self.sum = 0
self.avg = 0
def update(self, value, n=1):
self.count += n
self.sum += value * n
self.avg = self.sum / self.count
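# Example usage (sketch; `train_loader`, `model` and `criterion` are assumed).
# Passing the batch size as `n` keeps the average exact even when the last batch
# is smaller than the rest:
#
#   avg_loss = RunningAverage()
#   for input, target in train_loader:
#       loss = criterion(model(input), target)
#       avg_loss.update(loss.item(), n=input.size(0))
#   print(f'epoch loss: {avg_loss.avg}')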
def find_maximum_patch_size(model, device):
"""Tries to find the biggest patch size that can be send to GPU for inference
without throwing CUDA out of memory"""
logger = get_logger('PatchFinder')
in_channels = model.in_channels
patch_shapes = [(64, 128, 128), (96, 128, 128),
(64, 160, 160), (96, 160, 160),
(64, 192, 192), (96, 192, 192)]
for shape in patch_shapes:
        # generate a random patch of the given size, one per input channel
        patch = np.random.randn(in_channels, *shape).astype('float32')
        # add the batch dimension: (1, C, D, H, W)
        patch = torch.from_numpy(patch).unsqueeze(0).to(device)
logger.info(f"Current patch size: {shape}")
model(patch)
def unpad(patch, index, shape, pad_width=4):
"""
    Remove `pad_width` voxels around the edges of a given patch (unless the edge
    borders the full volume) and return it together with the correspondingly
    adjusted index into the volume.
"""
def _new_slices(slicing, max_size):
if slicing.start == 0:
p_start = 0
i_start = 0
else:
p_start = pad_width
i_start = slicing.start + pad_width
if slicing.stop == max_size:
p_stop = None
i_stop = max_size
else:
p_stop = -pad_width
i_stop = slicing.stop - pad_width
return slice(p_start, p_stop), slice(i_start, i_stop)
D, H, W = shape
i_c, i_z, i_y, i_x = index
p_c = slice(0, patch.shape[0])
p_z, i_z = _new_slices(i_z, D)
p_y, i_y = _new_slices(i_y, H)
p_x, i_x = _new_slices(i_x, W)
patch_index = (p_c, p_z, p_y, p_x)
index = (i_c, i_z, i_y, i_x)
return patch[patch_index], index
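# Example (sketch): trimming the halo from a (1, 32, 80, 80) patch taken from a
# (64, 160, 160) volume. Edges that touch the volume border keep their outer
# voxels, so the stitched prediction still covers the whole volume:
#
#   index = (slice(0, 1), slice(32, 64), slice(0, 80), slice(80, 160))
#   patch = np.zeros((1, 32, 80, 80), dtype='float32')
#   trimmed, new_index = unpad(patch, index, shape=(64, 160, 160))
#   # trimmed.shape == (1, 28, 76, 76)
#   # new_index == (slice(0, 1), slice(36, 64), slice(0, 76), slice(84, 160))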
def create_feature_maps(init_channel_number, number_of_fmaps):
return [init_channel_number * 2 ** k for k in range(number_of_fmaps)]
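# Example: create_feature_maps(64, 4) == [64, 128, 256, 512], i.e. the number of
# feature maps doubles at every level of the encoder path.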
# Code taken from https://github.com/cremi/cremi_python
def adapted_rand(seg, gt, all_stats=False):
"""Compute Adapted Rand error as defined by the SNEMI3D contest [1]
Formula is given as 1 - the maximal F-score of the Rand index
(excluding the zero component of the original labels). Adapted
from the SNEMI3D MATLAB script, hence the strange style.
Parameters
----------
seg : np.ndarray
the segmentation to score, where each value is the label at that point
gt : np.ndarray, same shape as seg
the groundtruth to score against, where each value is a label
all_stats : boolean, optional
whether to also return precision and recall as a 3-tuple with rand_error
Returns
-------
are : float
The adapted Rand error; equal to $1 - \frac{2pr}{p + r}$,
where $p$ and $r$ are the precision and recall described below.
prec : float, optional
The adapted Rand precision. (Only returned when `all_stats` is ``True``.)
rec : float, optional
The adapted Rand recall. (Only returned when `all_stats` is ``True``.)
References
----------
[1]: http://brainiac2.mit.edu/SNEMI3D/evaluation
"""
# just to prevent division by 0
epsilon = 1e-6
# segA is truth, segB is query
segA = np.ravel(gt)
segB = np.ravel(seg)
n = segA.size
n_labels_A = np.amax(segA) + 1
n_labels_B = np.amax(segB) + 1
ones_data = np.ones(n)
p_ij = sparse.csr_matrix((ones_data, (segA[:], segB[:])), shape=(n_labels_A, n_labels_B))
a = p_ij[1:n_labels_A, :]
b = p_ij[1:n_labels_A, 1:n_labels_B]
c = p_ij[1:n_labels_A, 0].todense()
d = b.multiply(b)
a_i = np.array(a.sum(1))
b_i = np.array(b.sum(0))
sumA = np.sum(a_i * a_i)
sumB = np.sum(b_i * b_i) + (np.sum(c) / n)
sumAB = np.sum(d) + (np.sum(c) / n)
precision = sumAB / max(sumB, epsilon)
recall = sumAB / max(sumA, epsilon)
fScore = 2.0 * precision * recall / max(precision + recall, epsilon)
are = 1.0 - fScore
if all_stats:
return are, precision, recall
else:
return are
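# Example (sketch): a perfect segmentation scores (close to) 0, while merging the
# two ground-truth objects into one increases the error:
#
#   gt = np.array([1, 1, 2, 2])
#   adapted_rand(gt, gt)                       # -> 0.0
#   adapted_rand(np.array([1, 1, 1, 1]), gt)   # -> ~0.33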
class _TensorboardFormatter:
"""
    Tensorboard formatters convert a given batch of images (be it the input/output of the network or the target
    segmentation) to a series of images that can be displayed in tensorboard. This is the parent class for all
    tensorboard formatters, which ensures that the returned images are in the 'CHW' format.
"""
def __init__(self, **kwargs):
pass
def __call__(self, name, batch):
"""
Transform a batch to a series of tuples of the form (tag, img), where `tag` corresponds to the image tag
and `img` is the image itself.
Args:
name (str): one of 'inputs'/'targets'/'predictions'
batch (torch.tensor): 4D or 5D torch tensor
"""
def _check_img(tag_img):
tag, img = tag_img
assert img.ndim == 2 or img.ndim == 3, 'Only 2D (HW) and 3D (CHW) images are accepted for display'
if img.ndim == 2:
img = np.expand_dims(img, axis=0)
else:
C = img.shape[0]
assert C == 1 or C == 3, 'Only (1, H, W) or (3, H, W) images are supported'
return tag, img
assert name in ['inputs', 'targets', 'predictions']
tagged_images = self.process_batch(name, batch)
return list(map(_check_img, tagged_images))
def process_batch(self, name, batch):
raise NotImplementedError
class DefaultTensorboardFormatter(_TensorboardFormatter):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def process_batch(self, name, batch):
tag_template = '{}/batch_{}/channel_{}/slice_{}'
tagged_images = []
if batch.ndim == 5:
# NCDHW
slice_idx = batch.shape[2] // 2 # get the middle slice
for batch_idx in range(batch.shape[0]):
for channel_idx in range(batch.shape[1]):
tag = tag_template.format(name, batch_idx, channel_idx, slice_idx)
img = batch[batch_idx, channel_idx, slice_idx, ...]
tagged_images.append((tag, self._normalize_img(img)))
else:
            # batch has no channel dim: NDHW
slice_idx = batch.shape[1] // 2 # get the middle slice
for batch_idx in range(batch.shape[0]):
tag = tag_template.format(name, batch_idx, 0, slice_idx)
img = batch[batch_idx, slice_idx, ...]
tagged_images.append((tag, self._normalize_img(img)))
return tagged_images
@staticmethod
def _normalize_img(img):
return np.nan_to_num((img - np.min(img)) / np.ptp(img))
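# Example usage (sketch): the formatter is typically fed numpy arrays (e.g.
# `tensor.detach().cpu().numpy()`) and yields one middle z-slice per
# (batch, channel) pair:
#
#   formatter = DefaultTensorboardFormatter()
#   batch = np.random.rand(2, 1, 32, 64, 64).astype('float32')
#   for tag, img in formatter('inputs', batch):
#       print(tag, img.shape)  # e.g. inputs/batch_0/channel_0/slice_16 (1, 64, 64)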
class EmbeddingsTensorboardFormatter(DefaultTensorboardFormatter):
def __init__(self, plot_variance=False, **kwargs):
super().__init__(**kwargs)
self.plot_variance = plot_variance
def process_batch(self, name, batch):
if name == 'inputs':
assert batch.ndim == 5
# skip coordinate channels and take only the first 'raw' channel
batch = batch[:, 0, ...]
return super().process_batch(name, batch)
elif name == 'predictions':
return self._embeddings_to_rgb(batch)
else:
return super().process_batch(name, batch)
def _embeddings_to_rgb(self, batch):
assert batch.ndim == 5
tag_template = 'embeddings/batch_{}/slice_{}'
tagged_images = []
slice_idx = batch.shape[2] // 2 # get the middle slice
for batch_idx in range(batch.shape[0]):
tag = tag_template.format(batch_idx, slice_idx)
img = batch[batch_idx, :, slice_idx, ...] # CHW
rgb_img = self._pca_project(img)
tagged_images.append((tag, rgb_img))
if self.plot_variance:
cum_explained_variance_img = self._plot_cum_explained_variance(img)
tagged_images.append((f'cumulative_explained_variance/batch_{batch_idx}', cum_explained_variance_img))
return tagged_images
    def _pca_project(self, embeddings):
        assert embeddings.ndim == 3
        # reshape (C, H, W) -> (C, H * W) and transpose
        flattened_embeddings = embeddings.reshape(embeddings.shape[0], -1).transpose()
        # init PCA with 3 principal components: one for each RGB channel
        pca = PCA(n_components=3)
        # fit the model with embeddings and apply the dimensionality reduction
        flattened_embeddings = pca.fit_transform(flattened_embeddings)
        # reshape back to original
        shape = list(embeddings.shape)
        shape[0] = 3
        img = flattened_embeddings.transpose().reshape(shape)
        # normalize to [0, 255]
        img = 255 * (img - np.min(img)) / np.ptp(img)
        return img.astype('uint8')

    def _plot_cum_explained_variance(self, embeddings):
        # reshape (C, H, W) -> (C, H * W) and transpose
        flattened_embeddings = embeddings.reshape(embeddings.shape[0], -1).transpose()
        # fit PCA to the data
        pca = PCA().fit(flattened_embeddings)
        plt.figure()
        # plot cumulative explained variance ratio
        plt.plot(np.cumsum(pca.explained_variance_ratio_))
        plt.xlabel('number of components')
        plt.ylabel('cumulative explained variance')
        buf = io.BytesIO()
        plt.savefig(buf, format='jpeg')
        buf.seek(0)
        img = np.asarray(Image.open(buf)).transpose(2, 0, 1)
        plt.close('all')
        return img
def get_tensorboard_formatter(config):
if config is None:
return DefaultTensorboardFormatter()
class_name = config['name']
m = importlib.import_module('unet3d.utils')
clazz = getattr(m, class_name)
return clazz(**config)
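# Example (sketch): `config` typically comes from the training YAML; the 'name' key
# selects the formatter class and the remaining keys are forwarded as keyword
# arguments:
#
#   formatter = get_tensorboard_formatter(
#       {'name': 'EmbeddingsTensorboardFormatter', 'plot_variance': True})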
def expand_as_one_hot(input, C, ignore_index=None):
"""
Converts NxDxHxW label image to NxCxDxHxW, where each label gets converted to its corresponding one-hot vector
:param input: 4D input image (NxDxHxW)
:param C: number of channels/labels
:param ignore_index: ignore index to be kept during the expansion
:return: 5D output image (NxCxDxHxW)
"""
assert input.dim() == 4
# expand the input tensor to Nx1xDxHxW before scattering
input = input.unsqueeze(1)
# create result tensor shape (NxCxDxHxW)
shape = list(input.size())
shape[1] = C
if ignore_index is not None:
# create ignore_index mask for the result
mask = input.expand(shape) == ignore_index
# clone the src_res tensor and zero out ignore_index in the input
input = input.clone()
input[input == ignore_index] = 0
# scatter to get the one-hot tensor
result = torch.zeros(shape).to(input.device).scatter_(1, input, 1)
# bring back the ignore_index in the result
result[mask] = ignore_index
return result
else:
# scatter to get the one-hot tensor
return torch.zeros(shape).to(input.device).scatter_(1, input, 1)
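# Example (sketch): expanding a 1x1x2x2 label volume with C=3 classes:
#
#   labels = torch.tensor([[[[0, 1], [2, 1]]]])  # NxDxHxW, dtype int64
#   one_hot = expand_as_one_hot(labels, C=3)     # shape (1, 3, 1, 2, 2)
#   # one_hot[0, :, 0, 0, 0] -> tensor([1., 0., 0.]), i.e. the voxel labelled 0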
def plot_segm(segm, ground_truth, plots_dir='.'):
"""
    Saves the predicted and ground truth segmentations as PNG files (one per channel).
:param segm: 4D ndarray (CDHW)
:param ground_truth: 4D ndarray (CDHW)
:param plots_dir: directory where to save the plots
"""
assert segm.ndim == 4
if ground_truth.ndim == 3:
stacked = [ground_truth for _ in range(segm.shape[0])]
ground_truth = np.stack(stacked)
assert ground_truth.ndim == 4
f, axarr = plt.subplots(1, 2)
for seg, gt in zip(segm, ground_truth):
mid_z = seg.shape[0] // 2
axarr[0].imshow(seg[mid_z], cmap='prism')
axarr[0].set_title('Predicted segmentation')
axarr[1].imshow(gt[mid_z], cmap='prism')
axarr[1].set_title('Ground truth segmentation')
file_name = f'segm_{str(uuid.uuid4())[:8]}.png'
plt.savefig(os.path.join(plots_dir, file_name))
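# Example usage (sketch): saving a mid-slice comparison for a 2-channel prediction;
# a 3D ground truth is repeated across the prediction channels:
#
#   segm = np.random.randint(0, 5, size=(2, 32, 64, 64))
#   gt = np.random.randint(0, 5, size=(32, 64, 64))
#   plot_segm(segm, gt, plots_dir='/tmp')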