-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy pathutils.py
59 lines (44 loc) · 1.82 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import numpy as np
import os
import torch
import errno
from PIL import Image
def mkdir_p(path):
    """Create directory ``path`` and any missing parents, like ``mkdir -p``.

    Does nothing if ``path`` already exists as a directory.

    Raises:
        OSError: if the directory cannot be created -- e.g. ``path`` exists but
            is not a directory (``FileExistsError``) or permission is denied.
    """
    # exist_ok=True replaces the manual EEXIST/isdir check; os.makedirs still
    # raises FileExistsError when ``path`` exists as a non-directory.
    os.makedirs(path, exist_ok=True)
def one_hot(x, K, dtype=torch.float):
    """Return a one-hot encoding of the index tensor ``x``.

    The result has shape ``x.shape + (K,)``, dtype ``dtype`` and lives on the
    same device as ``x``. Along the new trailing axis, position ``x[...]`` is 1
    and all others are 0.

    ``x`` is presumably an int64 tensor of class indices in ``[0, K)``
    (``scatter_`` requires int64 indices) -- confirm at call sites.
    """
    with torch.no_grad():
        encoded = torch.zeros(tuple(x.shape) + (K,), dtype=dtype, device=x.device)
        # Add a trailing axis so each index addresses one slot of the new K-axis.
        index = x.unsqueeze(-1)
        encoded.scatter_(-1, index, 1)
    return encoded
def save_image_stack(samples, num_rows, num_columns, filename, margin=5, margin_gray_val=1., frame=0, frame_gray_val=0.0):
    """Save a stack of images as one tiled image file.

    All images are min-max normalised jointly to [0, 1]. They are then laid out
    row-major on a ``num_rows`` x ``num_columns`` grid. Neighbouring tiles are
    separated by ``margin`` pixels of value ``margin_gray_val``. The whole grid
    is surrounded by a border of ``frame`` pixels of value ``frame_gray_val``.
    The result is written with PIL's ``Image.save``.

    Args:
        samples: array of shape (N, H, W) (grayscale, replicated to RGB) or
            (N, H, W, C) with C == 3 (or 1, broadcast to 3). N must be at least
            ``num_rows * num_columns``. The caller's array is not modified.
        num_rows: number of tile rows.
        num_columns: number of tile columns.
        filename: output path passed to ``Image.save``.
        margin: gap in pixels between adjacent tiles.
        margin_gray_val: fill value for the gaps, on the [0, 1] scale.
        frame: width in pixels of the outer border (0 for none).
        frame_gray_val: fill value for the border, on the [0, 1] scale.
    """
    # Work on a float copy. Normalising the caller's array in place mutated it,
    # and in-place true division fails for integer inputs.
    samples = np.array(samples, dtype=np.float64)
    if samples.ndim == 3:
        # Grayscale -> RGB by replicating the single channel.
        samples = np.stack((samples,) * 3, -1)
    height, width = samples.shape[1], samples.shape[2]

    # Joint min-max normalisation. Guard the division so that a constant stack
    # yields zeros instead of NaNs.
    samples -= samples.min()
    peak = samples.max()
    if peak > 0:
        samples /= peak

    canvas = np.full(
        (num_rows * height + (num_rows - 1) * margin,
         num_columns * width + (num_columns - 1) * margin,
         3),
        margin_gray_val,
        dtype=np.float64,
    )
    for r in range(num_rows):
        top = r * (height + margin)
        for c in range(num_columns):
            left = c * (width + margin)
            canvas[top:top + height, left:left + width, :] = samples[r * num_columns + c]

    framed = np.pad(
        canvas,
        ((frame, frame), (frame, frame), (0, 0)),
        mode="constant",
        constant_values=frame_gray_val,
    )
    # Clip so that out-of-range margin/frame values saturate instead of
    # wrapping around in the uint8 cast.
    pixels = np.clip(np.round(framed * 255.), 0, 255).astype(np.uint8)
    Image.fromarray(pixels).save(filename)
def sample_matrix_categorical(p):
    """Draw one category index per row of ``p`` by inverse-CDF sampling.

    ``p`` is expected to be 2-D (rows x categories). Each row is treated as
    the probabilities of a Categorical distribution and is assumed to sum to 1
    (not checked). Returns a 1-D int64 tensor holding one sampled column index
    per row.
    """
    with torch.no_grad():
        # Cumulative probabilities up to, but excluding, the last column. The
        # last category is what remains once every boundary is passed.
        boundaries = p[:, :-1].cumsum(dim=-1)
        # One uniform draw per row, kept 2-D so it broadcasts over the boundaries.
        u = torch.rand((boundaries.shape[0], 1), device=boundaries.device)
        # The number of boundaries the draw exceeds is the sampled index.
        sampled = (u > boundaries).sum(dim=-1).long()
    return sampled