-
Notifications
You must be signed in to change notification settings - Fork 21
/
Copy patheval_svhn_mixture.py
147 lines (108 loc) · 5.4 KB
/
eval_svhn_mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import numpy as np
import torch
from EinsumNetwork import Graph, EinsumNetwork
from EinsumNetwork.EinetMixture import EinetMixture
import pickle
import os
import datasets
import utils
from PIL import Image
# ---------------------------------------------------------------------------
# Evaluation settings
# ---------------------------------------------------------------------------

# Device on which the data tensors are created and to which the loaded models
# are moved.
device = 'cpu'

# Directory holding one sub-directory per cluster model, and the directory the
# generated images are written to (created here if missing).
einet_path = '../models/einet/svhn/num_clusters_100'
sample_path = '../samples/svhn'
utils.mkdir_p(sample_path)

# Image geometry used by every reshape between flat and image layout.
height = width = 32

# Number of mixture components, i.e. of per-cluster model files loaded below.
num_clusters = 100

# Model hyper-parameters.
# NOTE(review): none of these are read by the rest of this script (the models
# come from disk) -- presumably they mirror the training configuration;
# confirm before changing or removing them.
poon_domingos_pieces = [4]
num_sums = 40
input_multiplier = 1
structure = 'poon_domingos_vertical'
exponential_family = EinsumNetwork.NormalArray
exponential_family_args = dict(min_var=1e-6, max_var=0.01)

# With a multiplier of 1 there is no input mixing and one input distribution
# per sum; otherwise the inputs are multiplied and mixed in blocks of num_sums.
block_mix_input = None if input_multiplier == 1 else num_sums
num_input_distributions = (
    num_sums if input_multiplier == 1 else num_sums * input_multiplier
)
#######################################################################################
def compute_cluster_means(data, cluster_idx):
    """Return the per-cluster mean of ``data``.

    Args:
        data: NumPy array whose first axis indexes samples (this script works
            with ``(N, height, width, 3)`` images).
        cluster_idx: Integer cluster label for every sample in ``data``.

    Returns:
        ``float32`` array of shape ``(max_label + 1, *data.shape[1:])``.
        Row ``k`` is the mean of all samples labelled ``k``; rows belonging to
        labels that never occur stay zero. With no labels at all the result
        has zero rows.
    """
    cluster_idx = np.asarray(cluster_idx)
    unique_idx = np.unique(cluster_idx)
    # Rows are addressed by label, so size the result by the largest label and
    # not by the number of distinct labels: a gap in the labels (an empty
    # cluster) would otherwise index past the end of the array.
    num_rows = int(unique_idx.max()) + 1 if unique_idx.size else 0
    # Take the per-sample shape from the data itself instead of module globals.
    means = np.zeros((num_rows,) + tuple(data.shape[1:]), dtype=np.float32)
    for k in unique_idx:
        # Convert before averaging so integer pixel data is averaged in float32.
        means[k] = data[cluster_idx == k].astype(np.float32).mean(axis=0)
    return means
def compute_cluster_idx(data, cluster_means):
    """Assign every sample to its nearest cluster mean.

    Args:
        data: NumPy array whose first axis indexes samples; each sample must
            hold as many values as one cluster mean.
        cluster_means: NumPy array with one mean per entry of its first axis.

    Returns:
        ``uint32`` array of length ``len(data)``. Entry ``k`` is the index of
        the mean with the smallest squared Euclidean distance to ``data[k]``
        (the first such index on ties, as returned by ``np.argmin``).
    """
    # Flatten the means once, outside the loop; the feature count is taken
    # from the means themselves rather than from module-level image sizes.
    num_features = int(np.prod(cluster_means.shape[1:]))
    flat_means = cluster_means.reshape(-1, num_features)

    cluster_idx = np.zeros(len(data), dtype=np.uint32)
    # One sample at a time keeps the distance matrix at (num_means, features)
    # instead of materialising all sample/mean pairs at once.
    for k, sample in enumerate(data):
        img = sample.astype(np.float32).reshape(1, num_features)
        cluster_idx[k] = np.argmin(np.sum((flat_means - img) ** 2, 1))
    return cluster_idx
#############################
# Data

print("loading data")
(train_x_all, train_labels,
 test_x_all, test_labels,
 extra_x, extra_labels) = datasets.load_svhn()

# The first 50000 training images stay in the training set (together with the
# extra split); the remaining ones form the validation set.
num_train = 50000
valid_x_all = train_x_all[num_train:, ...]
train_x_all = np.concatenate((train_x_all[:num_train, ...], extra_x), axis=0)


def _to_unit_tensor(images):
    """Return ``images`` as a float32 tensor of shape (N, width*height, 3), divided by 255."""
    images = images.reshape(images.shape[0], height, width, 3)
    pixels = torch.tensor(images, device=device, dtype=torch.float32)
    return pixels.reshape(-1, width * height, 3) / 255.


# NOTE(review): only test_x_all is read further down in this script; the
# train/valid tensors are still built here exactly as before.
train_x_all = _to_unit_tensor(train_x_all)
valid_x_all = _to_unit_tensor(valid_x_all)
test_x_all = _to_unit_tensor(test_x_all)
print("done")
# Cluster assignment (k-means, judging by the file name); only the second
# element of the pickled pair is used.
# NOTE: pickle.load and torch.load can execute arbitrary code while
# deserialising -- only point these paths at files you trust.
cluster_file = '../auxiliary/svhn/kmeans_{}.pkl'.format(num_clusters)
with open(cluster_file, 'rb') as cluster_fh:
    # The context manager closes the file handle, which a bare open() passed
    # straight to pickle.load never did.
    _, cluster_idx = pickle.load(cluster_fh)

# make a mixture of EiNets
# Mixture weights: normalised histogram of the cluster assignment with
# num_clusters equal-width bins over the observed label range.
# NOTE(review): this equals the per-cluster relative frequency as long as the
# labels span 0..num_clusters-1 -- confirm against the clustering script.
p = np.histogram(cluster_idx, num_clusters)[0].astype(np.float32)
p = p / p.sum()

einets = []
for k in range(num_clusters):
    print("Load model for cluster {}".format(k))
    model_file = os.path.join(einet_path, 'cluster_{}'.format(k), 'einet.mdl')
    # map_location remaps the stored tensors onto `device` while loading, so a
    # checkpoint saved from a GPU can still be read on a CPU-only machine.
    # NOTE(review): recent PyTorch releases default torch.load to
    # weights_only=True; if these files hold whole pickled modules, that
    # default has to be overridden explicitly -- verify with the torch version
    # in use.
    einets.append(torch.load(model_file, map_location=device).to(device))

mixture = EinetMixture(p, einets)
# Draw L*L unconditional samples from the mixture and save them as one image
# via utils.save_image_stack.
L = 7
samples_file = os.path.join(sample_path, 'einet_samples.png')
samples = mixture.sample(L**2, std_correction=0.0)

sample_grid_style = dict(margin=2, margin_gray_val=0., frame=2, frame_gray_val=0.0)
utils.save_image_stack(
    samples.reshape(-1, height, width, 3),
    L,
    L,
    samples_file,
    **sample_grid_style
)
print("Saved samples to {}".format(samples_file))
# Reconstruct randomly chosen test images from a partial view: once with the
# top half hidden and once with the left half hidden.
num_reconstructions = 10
rp = np.random.permutation(test_x_all.shape[0])
test_x = test_x_all[rp[0:num_reconstructions], ...]

stack_shape = (num_reconstructions, height, width, 3)
half_height = round(height / 2)
half_width = round(width / 2)

# Flat pixel index of every (row, column) position; slicing it yields the
# indices passed to conditional_sample as the part to marginalize.
image_scope = np.array(range(height * width)).reshape(height, width)

# -- Top half: zeroed copy for display, then a conditional sample.
test_x_covered_top = np.reshape(test_x.clone().cpu().numpy(), stack_shape)
test_x_covered_top[:, 0:half_height, ...] = 0.0
marginalize_idx = list(image_scope[0:half_height, :].reshape(-1))
rec_samples_top = mixture.conditional_sample(test_x, marginalize_idx, std_correction=0.0)

# -- Left half: same procedure on the columns.
test_x_covered_left = np.reshape(test_x.clone().cpu().numpy(), stack_shape)
test_x_covered_left[:, :, 0:half_width, ...] = 0.0
marginalize_idx = list(image_scope[:, 0:half_width].reshape(-1))
rec_samples_left = mixture.conditional_sample(test_x, marginalize_idx, std_correction=0.0)

# Five groups of images, concatenated in this order: originals, top-covered,
# top reconstruction, left-covered, left reconstruction.
panels = [
    test_x.cpu().numpy(),
    test_x_covered_top,
    rec_samples_top,
    test_x_covered_left,
    rec_samples_left,
]
reconstruction_stack = np.concatenate(
    [np.reshape(panel, stack_shape) for panel in panels], 0)

# Shift the minimum to zero, then divide by the maximum, over the whole stack.
reconstruction_stack -= reconstruction_stack.min()
reconstruction_stack /= reconstruction_stack.max()

reconstructions_file = os.path.join(sample_path, 'einet_reconstructions.png')
utils.save_image_stack(
    reconstruction_stack,
    len(panels),
    num_reconstructions,
    reconstructions_file,
    margin=2,
    margin_gray_val=0.,
    frame=2,
    frame_gray_val=0.0,
)
print("Saved reconstructions to {}".format(reconstructions_file))
# Report the mixture's log-likelihood on the full test set, divided by the
# number of test samples.
print("Compute test log-likelihood...")
num_test_samples = test_x_all.shape[0]
ll = mixture.log_likelihood(test_x_all)
print("log-likelihood = {}".format(ll / num_test_samples))