added SVHN mixture
smatmo committed Jul 14, 2020
1 parent 629d28d commit b2dcbbb
Showing 3 changed files with 455 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/EinsumNetwork/EinetMixture.py
@@ -0,0 +1,80 @@
import torch
import numpy as np
from scipy.special import logsumexp
from EinsumNetwork.EinsumNetwork import log_likelihoods


class EinetMixture:
"""A simple class for mixtures of Einets, implemented in numpy."""

def __init__(self, p, einets):

if len(p) != len(einets):
raise AssertionError("p and einets must have the same length.")

self.num_components = len(p)

self.p = p
self.einets = einets

num_var = set([e.args.num_var for e in einets])
if len(num_var) != 1:
raise AssertionError("all EiNet components must have the same num_var.")
self.num_var = list(num_var)[0]

num_dims = set([e.args.num_dims for e in einets])
if len(num_dims) != 1:
raise AssertionError("all EiNet components must have the same num_dims.")
self.num_dims = list(num_dims)[0]

def sample(self, N, **kwargs):
samples = np.zeros((N, self.num_var, self.num_dims))
for k in range(N):
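            # draw a component index k ~ Categorical(p) via inverse-CDF sampling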
rand_idx = np.sum(np.random.rand() > np.cumsum(self.p[0:-1]))
samples[k, ...] = self.einets[rand_idx].sample(num_samples=1, **kwargs).cpu().numpy()
return samples

def conditional_sample(self, x, marginalize_idx, **kwargs):
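        # Compute the posterior over mixture components given the observed
        # (non-marginalized) variables, then draw each conditional sample from
        # the most probable component.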
marginalization_backup = []
component_posterior = np.zeros((self.num_components, x.shape[0]))
for einet_counter, einet in enumerate(self.einets):
marginalization_backup.append(einet.get_marginalization_idx())
einet.set_marginalization_idx(marginalize_idx)
lls = einet.forward(x)
lls = lls.sum(1)
component_posterior[einet_counter, :] = lls.detach().cpu().numpy() + np.log(self.p[einet_counter])

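        # normalize in log-space to obtain p(component | observed variables)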
component_posterior = component_posterior - logsumexp(component_posterior, 0, keepdims=True)
component_posterior = np.exp(component_posterior)

samples = np.zeros((x.shape[0], self.num_var, self.num_dims))
for test_idx in range(x.shape[0]):
component_idx = np.argmax(component_posterior[:, test_idx])
sample = self.einets[component_idx].sample(x=x[test_idx:test_idx + 1, :], **kwargs)
samples[test_idx, ...] = sample.squeeze().cpu().numpy()

# restore the original marginalization indices
for einet_counter, einet in enumerate(self.einets):
einet.set_marginalization_idx(marginalization_backup[einet_counter])

return samples

def log_likelihood(self, x, labels=None, batch_size=100):
with torch.no_grad():
idx_batches = torch.arange(0, x.shape[0], dtype=torch.int64, device=x.device).split(batch_size)
ll_total = 0.0
for batch_count, idx in enumerate(idx_batches):
batch_x = x[idx, :]
if labels is not None:
batch_labels = labels[idx]
else:
batch_labels = None

lls = torch.zeros(len(idx), self.num_components, device=x.device)
for einet_count, einet in enumerate(self.einets):
outputs = einet(batch_x)
lls[:, einet_count] = log_likelihoods(outputs, labels=batch_labels).squeeze()
                lls[:, einet_count] += torch.log(torch.tensor(self.p[einet_count]))
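            # mixture log-likelihood: logsumexp over the log-weighted component LLs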
lls = torch.logsumexp(lls, dim=1)
ll_total += lls.sum().item()
return ll_total
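
As a quick illustration of the interface above, a minimal usage sketch; einet_a, einet_b, and test_x are hypothetical stand-ins for two trained EinsumNetwork instances (with matching num_var and num_dims) and a torch tensor of test data, and the mixture weights here are made up:

import numpy as np
from EinsumNetwork.EinetMixture import EinetMixture

# einet_a, einet_b: trained EinsumNetwork models (hypothetical stand-ins)
mixture = EinetMixture(np.array([0.7, 0.3]), [einet_a, einet_b])

samples = mixture.sample(16)  # numpy array of shape (16, num_var, num_dims)
avg_ll = mixture.log_likelihood(test_x) / test_x.shape[0]  # mean log-likelihood per sample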
147 changes: 147 additions & 0 deletions src/eval_svhn_mixture.py
@@ -0,0 +1,147 @@
import numpy as np
import torch
from EinsumNetwork import Graph, EinsumNetwork
from EinsumNetwork.EinetMixture import EinetMixture
import pickle
import os
import datasets
import utils
from PIL import Image



device = 'cpu'

einet_path = '../models/einet/svhn/num_clusters_100'

sample_path = '../samples/svhn'


utils.mkdir_p(sample_path)



height = 32
width = 32
num_clusters = 100

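# structure/leaf configuration of the pre-trained EiNets (for reference; the
# trained models themselves are loaded from disk below)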
poon_domingos_pieces = [4]
num_sums = 40
input_multiplier = 1

structure = 'poon_domingos_vertical'

exponential_family = EinsumNetwork.NormalArray
exponential_family_args = {'min_var': 1e-6, 'max_var': 0.01}

if input_multiplier == 1:
block_mix_input = None
num_input_distributions = num_sums
else:
block_mix_input = num_sums
num_input_distributions = num_sums * input_multiplier

#######################################################################################


def compute_cluster_means(data, cluster_idx):
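    # mean image per cluster; assumes cluster labels are contiguous 0..K-1,
    # since means[k] is indexed directly by the label k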
unique_idx = np.unique(cluster_idx)
means = np.zeros((len(unique_idx), height, width, 3), dtype=np.float32)
for k in unique_idx:
means[k, ...] = np.mean(data[cluster_idx == k, ...].astype(np.float32), 0)
return means


def compute_cluster_idx(data, cluster_means):
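    # assign each image to the nearest cluster mean (squared Euclidean distance)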
cluster_idx = np.zeros(len(data), dtype=np.uint32)
for k in range(len(data)):
img = data[k].astype(np.float32)
        diffs = cluster_means.reshape(-1, height * width * 3) - img.reshape(1, height * width * 3)
        cluster_idx[k] = np.argmin(np.sum(diffs ** 2, 1))
return cluster_idx


#############################
# Data

print("loading data")
train_x_all, train_labels, test_x_all, test_labels, extra_x, extra_labels = datasets.load_svhn()

valid_x_all = train_x_all[50000:, ...]
train_x_all = np.concatenate((train_x_all[0:50000, ...], extra_x), 0)
train_x_all = train_x_all.reshape(train_x_all.shape[0], height, width, 3)
valid_x_all = valid_x_all.reshape(valid_x_all.shape[0], height, width, 3)
test_x_all = test_x_all.reshape(test_x_all.shape[0], height, width, 3)

train_x_all = torch.tensor(train_x_all, device=device, dtype=torch.float32).reshape(-1, width*height, 3) / 255.
valid_x_all = torch.tensor(valid_x_all, device=device, dtype=torch.float32).reshape(-1, width*height, 3) / 255.
test_x_all = torch.tensor(test_x_all, device=device, dtype=torch.float32).reshape(-1, width*height, 3) / 255.
print("done")

_, cluster_idx = pickle.load(open('../auxiliary/svhn/kmeans_{}.pkl'.format(num_clusters), 'rb'))

# make a mixture of EiNets
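# mixture weights proportional to cluster sizes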
p = np.histogram(cluster_idx, num_clusters)[0].astype(np.float32)
p = p / p.sum()

einets = []
for k in range(num_clusters):
print("Load model for cluster {}".format(k))
model_file = os.path.join(einet_path, 'cluster_{}'.format(k), 'einet.mdl')
einets.append(torch.load(model_file).to(device))

mixture = EinetMixture(p, einets)

L = 7
samples = mixture.sample(L**2, std_correction=0.0)
utils.save_image_stack(samples.reshape(-1, height, width, 3),
L, L,
os.path.join(sample_path, 'einet_samples.png'),
margin=2,
margin_gray_val=0.,
frame=2,
frame_gray_val=0.0)
print("Saved samples to {}".format(os.path.join(sample_path, 'einet_samples.png')))


num_reconstructions = 10

rp = np.random.permutation(test_x_all.shape[0])
test_x = test_x_all[rp[0:num_reconstructions], ...]

# Make covered images -- Top
test_x_covered_top = np.reshape(test_x.clone().cpu().numpy(), (num_reconstructions, height, width, 3))
test_x_covered_top[:, 0:round(height/2), ...] = 0.0

# Draw conditional samples for reconstruction -- Top
image_scope = np.array(range(height * width)).reshape(height, width)
marginalize_idx = list(image_scope[0:round(height/2), :].reshape(-1))
rec_samples_top = mixture.conditional_sample(test_x, marginalize_idx, std_correction=0.0)

# Make covered images -- Left
test_x_covered_left = np.reshape(test_x.clone().cpu().numpy(), (num_reconstructions, height, width, 3))
test_x_covered_left[:, :, 0:round(width/2), ...] = 0.0

# Draw conditional samples for reconstruction -- Left
image_scope = np.array(range(height * width)).reshape(height, width)
marginalize_idx = list(image_scope[:, 0:round(width/2)].reshape(-1))
rec_samples_left = mixture.conditional_sample(test_x, marginalize_idx, std_correction=0.0)

reconstruction_stack = np.concatenate((np.reshape(test_x.cpu().numpy(), (num_reconstructions, height, width, 3)),
np.reshape(test_x_covered_top, (num_reconstructions, height, width, 3)),
np.reshape(rec_samples_top, (num_reconstructions, height, width, 3)),
np.reshape(test_x_covered_left, (num_reconstructions, height, width, 3)),
np.reshape(rec_samples_left, (num_reconstructions, height, width, 3))), 0)

reconstruction_stack -= reconstruction_stack.min()
reconstruction_stack /= reconstruction_stack.max()
utils.save_image_stack(reconstruction_stack, 5,
num_reconstructions, os.path.join(sample_path, 'einet_reconstructions.png'),
margin=2,
margin_gray_val=0.,
frame=2,
frame_gray_val=0.0)
print("Saved reconstructions to {}".format(os.path.join(sample_path, 'einet_reconstructions.png')))

print("Compute test log-likelihood...")
ll = mixture.log_likelihood(test_x_all)
print("log-likelihood = {}".format(ll / test_x_all.shape[0]))