diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..016093f --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +.DS_Store +pull +push +experiment_tools.py +scribe diff --git a/README.md b/README.md new file mode 100755 index 0000000..9535aa3 --- /dev/null +++ b/README.md @@ -0,0 +1,27 @@ +# PixelVAE + +Code for the models in [PixelVAE: A Latent Variable Model for Natural Images](https://arxiv.org/abs/1611.05013) + +## MNIST + +To train: + +``` +python models/mnist_pixelvae_train.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 +``` + +To evaluate, take the weights of the model with the best validation score from the above training procedure and then run + +``` +python models/mnist_pixelvae_evaluate.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 -w path/to/weights.pkl +``` + +## Other datasets + +To train, evaluate, and generate samples: + +``` +python pixelvae.py +``` + +By default, this runs on real-valued MNIST. You can specify different datasets or model settings within `pixelvae.py`. \ No newline at end of file diff --git a/lib/__init__.py b/lib/__init__.py new file mode 100755 index 0000000..986fcf7 --- /dev/null +++ b/lib/__init__.py @@ -0,0 +1,128 @@ +import numpy +import theano +import theano.tensor as T + +import cPickle as pickle +import math +import time +import locale + +locale.setlocale(locale.LC_ALL, '') + +_params = {} +def param(name, *args, **kwargs): + """ + A wrapper for `theano.shared` which enables parameter sharing in models. + + Creates and returns theano shared variables similarly to `theano.shared`, + except if you try to create a param with the same name as a + previously-created one, `param(...)` will just return the old one instead of + making a new one. + + This constructor also adds a `param` attribute to the shared variables it + creates, so that you can easily search a graph for all params.
+ """ + + if name not in _params: + kwargs['name'] = name + param = theano.shared(*args, **kwargs) + param.param = True + _params[name] = param + return _params[name] + +def delete_params_with_name(name): + to_delete = [p_name for p_name in _params if name in p_name] + for p_name in to_delete: + del _params[p_name] + +def delete_all_params(): + to_delete = [p_name for p_name in _params] + for p_name in to_delete: + del _params[p_name] + +def save_params(path): + param_vals = {} + for name, param in _params.iteritems(): + param_vals[name] = param.get_value() + # print name + + with open(path, 'wb') as f: + pickle.dump(param_vals, f) + +def load_params(path): + with open(path, 'rb') as f: + param_vals = pickle.load(f) + + for name, val in param_vals.iteritems(): + _params[name].set_value(val) + # print name + +def search(node, critereon): + """ + Traverse the Theano graph starting at `node` and return a list of all nodes + which match the `critereon` function. When optimizing a cost function, you + can use this to get a list of all of the trainable params in the graph, like + so: + + `lib.search(cost, lambda x: hasattr(x, "param"))` + """ + + def _search(node, critereon, visited): + if node in visited: + return [] + visited.add(node) + + results = [] + if isinstance(node, T.Apply): + for inp in node.inputs: + results += _search(inp, critereon, visited) + else: # Variable node + if critereon(node): + results.append(node) + if node.owner is not None: + results += _search(node.owner, critereon, visited) + return results + + return _search(node, critereon, set()) + +def floatX(x): + """ + Convert `x` to the numpy type specified in `theano.config.floatX`. 
+ """ + if theano.config.floatX == 'float16': + return numpy.float16(x) + elif theano.config.floatX == 'float32': + return numpy.float32(x) + else: # Theano's default float type is float64 + print "Warning: lib.floatX using float64" + return numpy.float64(x) + +def print_params_info(params): + """Print information about the parameters in the given param set.""" + + params = sorted(params, key=lambda p: p.name) + values = [p.get_value(borrow=True) for p in params] + shapes = [p.shape for p in values] + print "Params for cost:" + for param, value, shape in zip(params, values, shapes): + print "\t{0} ({1})".format( + param.name, + ",".join([str(x) for x in shape]) + ) + + total_param_count = 0 + for shape in shapes: + param_count = 1 + for dim in shape: + param_count *= dim + total_param_count += param_count + print "Total parameter count: {0}".format( + locale.format("%d", total_param_count, grouping=True) + ) + +def print_model_settings(locals_): + print "Model settings:" + all_vars = [(k,v) for (k,v) in locals_.items() if (k.isupper() and k!='T')] + all_vars = sorted(all_vars, key=lambda x: x[0]) + for var_name, var_value in all_vars: + print "\t{}: {}".format(var_name, var_value) \ No newline at end of file diff --git a/lib/debug.py b/lib/debug.py new file mode 100755 index 0000000..69ea172 --- /dev/null +++ b/lib/debug.py @@ -0,0 +1,34 @@ +import numpy as np +import theano +from theano import gof + +class DebugOp(gof.Op): + def __init__(self, name, fn): + super(DebugOp, self).__init__() + self._name = name + self._fn = fn + + def make_node(self, x): + return gof.Apply(self, [x], [x.type()]) + + def perform(self, node, inputs, output_storage): + self._fn(self._name, inputs[0]) + output_storage[0][0] = np.copy(inputs[0]) + + def grad(self, inputs, output_gradients): + return [DebugOp(self._name+'.grad', self._fn)(output_gradients[0])] + +def print_shape(name, x): + def fn(_name, _x): + print "{} shape: {}".format(_name, _x.shape) + return DebugOp(name, fn)(x) + 
+def print_stats(name, x):
+    return x
+    def fn(_name, _x):
+        mean = np.mean(_x)
+        std = np.std(_x)
+        percentiles = np.percentile(_x, [0,25,50,75,100])
+        # percentiles = "skipping"
+        print "{}\tmean:{}\tstd:{}\tpercentiles:{}\t".format(_name, mean, std, percentiles)
+    return DebugOp(name, fn)(x)
\ No newline at end of file
diff --git a/lib/mnist_binarized.py b/lib/mnist_binarized.py
new file mode 100755
index 0000000..0e99418
--- /dev/null
+++ b/lib/mnist_binarized.py
@@ -0,0 +1,61 @@
+from fuel.datasets import BinarizedMNIST
+import numpy as np
+
+from fuel.datasets import BinarizedMNIST
+from fuel.schemes import ShuffledScheme, SequentialScheme
+from fuel.streams import DataStream
+# from fuel.transformers.image import RandomFixedSizeCrop
+
+def _make_stream(stream, bs):
+    """
+    Wrap a Fuel-style `stream` as a zero-argument function; each call starts
+    a new pass over the stream and returns a generator of minibatches.
+
+    stream: object whose `get_epoch_iterator()` yields 1-tuples `(batch,)`.
+    bs: nominal batch size. Kept so existing callers keep working; every
+        yielded array is sized from the batch actually received.
+
+    Yields 1-tuples holding a new float32 array of shape
+    (len(batch), 1, 28, 28).
+    """
+    def new_stream():
+        for (imb,) in stream.get_epoch_iterator():
+            # Allocate per batch. One (bs, ...) buffer reused for the whole
+            # epoch hands every consumer the same array object and, when a
+            # batch is shorter than bs, still carries the previous batch's
+            # rows past the end of the new data.
+            result = np.empty((len(imb), 1, 28, 28), dtype = 'float32')
+            for i, img in enumerate(imb):
+                result[i] = img
+            yield (result,)
+    return new_stream
+
+def load(batch_size, test_batch_size):
+    """
+    Build minibatch generators for binarized MNIST.
+
+    Returns (train, valid, test): zero-argument functions made by
+    `_make_stream`. Train batches come from a ShuffledScheme, valid and test
+    from a SequentialScheme. Train and valid use `batch_size`; test uses
+    `test_batch_size`.
+    """
+    tr_data = BinarizedMNIST(which_sets=('train',))
+    val_data = BinarizedMNIST(which_sets=('valid',))
+    test_data = BinarizedMNIST(which_sets=('test',))
+
+    ntrain = tr_data.num_examples
+    nval = val_data.num_examples
+    ntest = test_data.num_examples
+
+    tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size)
+    tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme)
+
+    te_scheme = SequentialScheme(examples=ntest, batch_size=test_batch_size)
+    te_stream = DataStream(test_data, iteration_scheme=te_scheme)
+
+    val_scheme = SequentialScheme(examples=nval, batch_size=batch_size)
+    val_stream = DataStream(val_data, iteration_scheme=val_scheme)
+
+    return _make_stream(tr_stream, batch_size), \
+           _make_stream(val_stream, batch_size), \
+           _make_stream(te_stream, test_batch_size)
\ No newline at end of file
diff --git a/lib/ops/__init__.py b/lib/ops/__init__.py
new file mode 100755
index 0000000..e69de29
diff --git a/lib/ops/conv2d.py
b/lib/ops/conv2d.py new file mode 100755 index 0000000..f1a2ecc --- /dev/null +++ b/lib/ops/conv2d.py @@ -0,0 +1,123 @@ +import lib +import lib.debug + +import numpy as np +import theano +import theano.tensor as T + +_default_weightnorm = False +def enable_default_weightnorm(): + global _default_weightnorm + _default_weightnorm = True + +def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, mode = 'half', stride=1, weightnorm=None, biases=True): + """ + inputs: tensor of shape (batch size, num channels, height, width) + mask_type: one of None, 'a', 'b', 'hstack_a', 'hstack', 'vstack' + + returns: tensor of shape (batch size, num channels, height, width) + """ + if mask_type is not None: + mask_type, mask_n_channels = mask_type + assert(mode == "half") + + if isinstance(filter_size, int): + filter_size = (filter_size, filter_size) + + #else it is assumed to be a tuple + + def uniform(stdev, size): + return np.random.uniform( + low=-stdev * np.sqrt(3), + high=stdev * np.sqrt(3), + size=size + ).astype(theano.config.floatX) + + fan_in = input_dim * filter_size[0]*filter_size[1] + fan_out = output_dim * filter_size[0]*filter_size[1] + # TODO: shouldn't fan_out be divided by stride + + + if mask_type is not None: # only approximately correct + fan_in /= 2. + fan_out /= 2. 
+ + if he_init: + filters_stdev = np.sqrt(4./(fan_in+fan_out)) + else: # Normalized init (Glorot & Bengio) + filters_stdev = np.sqrt(2./(fan_in+fan_out)) + + filter_values = uniform( + filters_stdev, + (output_dim, input_dim, filter_size[0], filter_size[1]) + ) + + filters = lib.param(name+'.Filters', filter_values) + + if weightnorm==None: + weightnorm = _default_weightnorm + if weightnorm: + norm_values = np.linalg.norm(filter_values.reshape((filter_values.shape[0], -1)), axis=1) + norms = lib.param( + name + '.g', + norm_values + ) + filters = filters * (norms / filters.reshape((filters.shape[0],-1)).norm(2, axis=1)).dimshuffle(0,'x','x','x') + + if mask_type is not None: + mask = np.ones( + (output_dim, input_dim, filter_size[0], filter_size[1]), + dtype=theano.config.floatX + ) + center_row = filter_size[0] // 2 + + center_col = filter_size[1]//2 + + # Mask out future locations + # filter shape is (out_channels, in_channels, height, width) + if center_row == 0: + mask[:, :, :, center_col+1:] = 0. + elif center_col == 0: + mask[:, :, center_row+1:, :] = 0. + else: + mask[:, :, center_row+1:, :] = 0. + mask[:, :, center_row, center_col+1:] = 0. + + # Mask out future channels + for i in xrange(mask_n_channels): + for j in xrange(mask_n_channels): + if ((mask_type=='a' or mask_type == 'hstack_a') and i >= j) or (mask_type=='b' and i > j): + mask[ + j::mask_n_channels, + i::mask_n_channels, + center_row, + center_col + ] = 0. + + if mask_type == 'vstack': + assert(center_col > 0 and center_row > 0) + mask[:, :, center_row, :] = 1. 
+ + # print mask[0,0,:,:] + + + filters = filters * mask + + if biases: + _biases = lib.param( + name+'.Biases', + np.zeros(output_dim, dtype=theano.config.floatX) + ) + + result = T.nnet.conv2d( + inputs, + filters, + border_mode=mode, + filter_flip=False, + subsample=(stride,stride) + ) + + if biases: + result = result + _biases[None, :, None, None] + # result = lib.debug.print_stats(name, result) + return result \ No newline at end of file diff --git a/lib/ops/deconv2d.py b/lib/ops/deconv2d.py new file mode 100755 index 0000000..b87e9f0 --- /dev/null +++ b/lib/ops/deconv2d.py @@ -0,0 +1,110 @@ +import lib +import lib.debug + +import numpy as np +import theano +import theano.tensor as T +from theano.sandbox.cuda.basic_ops import (as_cuda_ndarray_variable, + host_from_gpu, + gpu_contiguous, HostFromGpu, + gpu_alloc_empty) +from theano.sandbox.cuda.dnn import (GpuDnnConvDesc, + GpuDnnConv, + GpuDnnConvGradI, + dnn_conv, + dnn_pool) + +_default_weightnorm = False +def enable_default_weightnorm(): + global _default_weightnorm + _default_weightnorm = True + +def _deconv2d(X, w, subsample=(1, 1), border_mode=(0, 0), conv_mode='conv'): + """ + from Alec (https://github.com/Newmu/dcgan_code/blob/master/lib/ops.py) + sets up dummy convolutional forward pass and uses its grad as deconv + currently only tested/working with same padding + """ + img = gpu_contiguous(X) + kerns = gpu_contiguous(w) + + out = gpu_alloc_empty( + img.shape[0], + kerns.shape[1], + img.shape[2]*subsample[0], + img.shape[3]*subsample[1] + ) + + desc = GpuDnnConvDesc(border_mode=border_mode, subsample=subsample, + conv_mode=conv_mode) + + desc = desc( + out.shape, + kerns.shape + ) + + d_img = GpuDnnConvGradI()(kerns, img, out, desc) + + return d_img + + +def Deconv2D( + name, + input_dim, + output_dim, + filter_size, + inputs, + he_init=True, + weightnorm=None, + ): + """ + inputs: tensor of shape (batch size, num channels, height, width) + returns: tensor of shape (batch size, num channels, 
2*height, 2*width) + """ + def uniform(stdev, size): + return np.random.uniform( + low=-stdev * np.sqrt(3), + high=stdev * np.sqrt(3), + size=size + ).astype(theano.config.floatX) + + filters_stdev = np.sqrt(1./(input_dim * filter_size**2)) + filters_stdev *= 2. # Because of the stride + if he_init: + filters_stdev *= np.sqrt(2.) + + filter_values = uniform( + filters_stdev, + (input_dim, output_dim, filter_size, filter_size) + ) + + filters = lib.param( + name+'.Filters', + filter_values + ) + + if weightnorm==None: + weightnorm = _default_weightnorm + if weightnorm: + norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,2,3))) + norms = lib.param( + name + '.g', + norm_values + ) + filters = filters * (norms / T.sqrt(T.sum(T.sqr(filters), axis=(0,2,3)))).dimshuffle('x',0,'x','x') + + biases = lib.param( + name+'.Biases', + np.zeros(output_dim, dtype=theano.config.floatX) + ) + + pad = (filter_size-1)/2 + result = _deconv2d( + inputs, + filters, + subsample=(2,2), + border_mode=(pad,pad), + ) + result = result + biases[None, :, None, None] + # result = lib.debug.print_stats(name, result) + return result \ No newline at end of file diff --git a/lib/ops/kl_unit_gaussian.py b/lib/ops/kl_unit_gaussian.py new file mode 100755 index 0000000..4b02498 --- /dev/null +++ b/lib/ops/kl_unit_gaussian.py @@ -0,0 +1,9 @@ +import theano.tensor as T + +def kl_unit_gaussian(mu, log_sigma): + """ + KL divergence from a unit Gaussian prior + mean across axis 0 (minibatch), sum across all other axes + based on yaost, via Alec + """ + return -0.5 * (1 + 2 * log_sigma - mu**2 - T.exp(2 * log_sigma)) \ No newline at end of file diff --git a/lib/ops/linear.py b/lib/ops/linear.py new file mode 100755 index 0000000..92fb737 --- /dev/null +++ b/lib/ops/linear.py @@ -0,0 +1,104 @@ +import lib +import lib.debug + +import numpy as np +import theano +import theano.tensor as T + +_default_weightnorm = False +def enable_default_weightnorm(): + global _default_weightnorm + 
_default_weightnorm = True + +def Linear( + name, + input_dim, + output_dim, + inputs, + biases=True, + initialization=None, + weightnorm=None + ): + """ + initialization: None, `lecun`, `he`, `orthogonal`, `("uniform", range)` + """ + + def uniform(stdev, size): + return np.random.uniform( + low=-stdev * np.sqrt(3), + high=stdev * np.sqrt(3), + size=size + ).astype(theano.config.floatX) + + if initialization == 'lecun' or \ + (initialization == None and input_dim != output_dim): + + weight_values = uniform(np.sqrt(1./input_dim), (input_dim, output_dim)) + + elif initialization == 'glorot': + + weight_values = uniform(np.sqrt(2./(input_dim+output_dim)), (input_dim, output_dim)) + + elif initialization == 'he': + + weight_values = uniform(np.sqrt(2./input_dim), (input_dim, output_dim)) + + elif initialization == 'glorot_he': + + weight_values = uniform(np.sqrt(4./(input_dim+output_dim)), (input_dim, output_dim)) + + elif initialization == 'orthogonal' or \ + (initialization == None and input_dim == output_dim): + + # From lasagne + def sample(shape): + if len(shape) < 2: + raise RuntimeError("Only shapes of length 2 or more are " + "supported.") + flat_shape = (shape[0], np.prod(shape[1:])) + # TODO: why normal and not uniform? 
+ a = np.random.normal(0.0, 1.0, flat_shape) + u, _, v = np.linalg.svd(a, full_matrices=False) + # pick the one with the correct shape + q = u if u.shape == flat_shape else v + q = q.reshape(shape) + return q.astype(theano.config.floatX) + weight_values = sample((input_dim, output_dim)) + + elif initialization[0] == 'uniform': + + weight_values = np.random.uniform( + low=-initialization[1], + high=initialization[1], + size=(input_dim, output_dim) + ).astype(theano.config.floatX) + + else: + raise Exception("Invalid initialization!") + + weight = lib.param( + name + '.W', + weight_values + ) + + if weightnorm==None: + weightnorm = _default_weightnorm + if weightnorm: + norm_values = np.linalg.norm(weight_values, axis=0) + norms = lib.param( + name + '.g', + norm_values + ) + + weight = weight * (norms / weight.norm(2, axis=0)).dimshuffle('x', 0) + + result = T.dot(inputs, weight) + + if biases: + result = result + lib.param( + name + '.b', + np.zeros((output_dim,), dtype=theano.config.floatX) + ) + + # result = lib.debug.print_stats(name, result) + return result \ No newline at end of file diff --git a/lib/train_loop.py b/lib/train_loop.py new file mode 100755 index 0000000..7278cdf --- /dev/null +++ b/lib/train_loop.py @@ -0,0 +1,189 @@ +import lib + +import numpy as np +import theano +import theano.tensor as T +import lasagne + +import time +import itertools +import collections + +def train_loop( + inputs, + cost, + train_data, + times, + prints=None, + inject_total_iters=False, + test_data=None, + callback=None, + optimizer=lasagne.updates.adam, + save_params=False, + nan_guard=False + ): + + params = lib.search(cost, lambda x: hasattr(x, 'param')) + lib.print_params_info(params) + + grads = T.grad(cost, wrt=params, disconnected_inputs='warn') + + grads = [T.clip(g, lib.floatX(-1), lib.floatX(1)) for g in grads] + + updates = optimizer(grads, params) + + if prints is None: + prints = [('cost', cost)] + else: + prints = [('cost', cost)] + prints + + print 
"Compiling train function..." + if nan_guard: + from theano.compile.nanguardmode import NanGuardMode + mode = NanGuardMode( + nan_is_error=True, + inf_is_error=True, + big_is_error=True + ) + else: + mode = None + train_fn = theano.function( + inputs, + [p[1] for p in prints], + updates=updates, + on_unused_input='warn', + mode=mode + ) + + print "Compiling eval function..." + eval_fn = theano.function( + inputs, + [p[1] for p in prints], + on_unused_input='warn' + ) + + print "Training!" + + total_iters = 0 + total_seconds = 0. + last_print = 0 + last_gen = 0 + + if len(times) >= 4: + gen_every = times[3] + else: + gen_every = times[1] + + if len(times) >= 5: + early_stop = times[4] + if len(times) >= 6: + early_stop_min = times[5] + else: + early_stop_min = 0 + else: + early_stop = None + early_stop_min = None + + best_test_cost = np.inf + best_test_cost_iter = 0. + + all_outputs = [] + all_stats = [] + for epoch in itertools.count(): + + generator = train_data() + while True: + try: + inputs = generator.next() + except StopIteration: + break + + if inject_total_iters: + inputs = [np.int32(total_iters)] + list(inputs) + + start_time = time.time() + outputs = train_fn(*inputs) + total_seconds += time.time() - start_time + total_iters += 1 + + all_outputs.append(outputs) + + if total_iters == 1: + try: # This only matters on Ishaan's computer + import experiment_tools + experiment_tools.register_crash_notifier() + except ImportError: + pass + + if (times[0]=='iters' and total_iters-last_print == times[1]) or \ + (times[0]=='seconds' and total_seconds-last_print >= times[1]): + + mean_outputs = np.array(all_outputs).mean(axis=0) + + if test_data is not None: + if inject_total_iters: + test_outputs = [ + eval_fn(np.int32(total_iters), *inputs) + for inputs in test_data() + ] + else: + test_outputs = [ + eval_fn(*inputs) + for inputs in test_data() + ] + test_mean_outputs = np.array(test_outputs).mean(axis=0) + + stats = collections.OrderedDict() + stats['epoch'] = 
epoch + stats['iters'] = total_iters + for i,p in enumerate(prints): + stats['train '+p[0]] = mean_outputs[i] + if test_data is not None: + for i,p in enumerate(prints): + stats['test '+p[0]] = test_mean_outputs[i] + stats['secs'] = total_seconds + stats['secs/iter'] = total_seconds / total_iters + + if test_data != None and (stats['test cost'] < best_test_cost or (early_stop_min != None and total_iters <= early_stop_min)): + best_test_cost = stats['test cost'] + best_test_cost_iter = total_iters + + print_str = "" + for k,v in stats.items(): + if isinstance(v, int): + print_str += "{}:{}\t".format(k,v) + else: + print_str += "{}:{:.4f}\t".format(k,v) + print print_str[:-1] # omit the last \t + + all_stats.append(stats) + + all_outputs = [] + last_print += times[1] + + if (times[0]=='iters' and total_iters-last_gen==gen_every) or \ + (times[0]=='seconds' and total_seconds-last_gen >= gen_every): + tag = "iters{}_time{}".format(total_iters, total_seconds) + if callback is not None: + callback(tag) + if save_params: + lib.save_params('params_{}.pkl'.format(tag)) + + last_gen += gen_every + + if (times[0]=='iters' and total_iters == times[2]) or \ + (times[0]=='seconds' and total_seconds >= times[2]) or \ + (test_data != None and early_stop != None and total_iters > (3*early_stop) and (total_iters-best_test_cost_iter) > early_stop): + + if (test_data != None and early_stop != None and total_iters > (3*early_stop) and (total_iters-best_test_cost_iter) > early_stop): + print "Early stop! Best test cost was {} at iter {}".format(best_test_cost, best_test_cost_iter) + + print "Done!" 
+ + try: # This only matters on Ishaan's computer + import experiment_tools + experiment_tools.send_sms("done!") + except ImportError: + pass + + return all_stats \ No newline at end of file diff --git a/mnist_pixelvae_evaluate.py b/mnist_pixelvae_evaluate.py new file mode 100755 index 0000000..70dc607 --- /dev/null +++ b/mnist_pixelvae_evaluate.py @@ -0,0 +1,437 @@ +""" +VAE + Pixel CNN +Ishaan Gulrajani +""" + + +""" +Modified by Kundan Kumar + +Usage: THEANO_FLAGS='mode=FAST_RUN,device=gpu0,floatX=float32,lib.cnmem=.95' python models/mnist_pixelvae_evaluate.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 -w path/to/weights.pkl +""" + +import os, sys +sys.path.append(os.getcwd()) + +import time + +import argparse + +import lib +import lib.train_loop +import lib.mnist_binarized +import lib.ops.kl_unit_gaussian +import lib.ops.conv2d +import lib.ops.deconv2d +import lib.ops.linear + +import numpy as np +import theano +import theano.tensor as T +from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams +import scipy.misc +import lasagne +import pickle + +import functools + + +parser = argparse.ArgumentParser(description='Generating images pixel by pixel') +parser.add_argument('-L','--num_pixel_cnn_layer', required=True, type=int, help='Number of layers to use in pixelCNN') +parser.add_argument('-algo', '--decoder_algorithm', required = True, help="One of 'cond_z_bias', 'upsample_z_no_conv', 'upsample_z_conv', 'upsample_z_conv_tied' 'vae_only'" ) +parser.add_argument('-enc', '--encoder', required = False, default='simple', help="Encoder: 'complecated' or 'simple' " ) +parser.add_argument('-dpx', '--dim_pix', required = False, default=32, type = int ) +parser.add_argument('-fs', '--filter_size', required = False, default=5, type = int ) +parser.add_argument('-ldim', '--latent_dim', required = False, default=64, type = int ) +parser.add_argument('-ait', '--alpha_iters', required = False, default=10000, type = int ) +parser.add_argument('-w', 
'--pre_trained_weights', required = True) + + +args = parser.parse_args() + + +assert args.decoder_algorithm in ['cond_z_bias', 'upsample_z_conv'] + +print args + + + +lib.ops.conv2d.enable_default_weightnorm() +lib.ops.linear.enable_default_weightnorm() + +OUT_DIR = '/Tmp/kumarkun/mnist_pixel_final' + "/num_layers_new2_" + str(args.num_pixel_cnn_layer) + args.decoder_algorithm + "_"+args.encoder + +if not os.path.isdir(OUT_DIR): + os.makedirs(OUT_DIR) + print "Created directory {}".format(OUT_DIR) + +def floatX(num): + if theano.config.floatX == 'float32': + return np.float32(num) + else: + raise Exception("{} type not supported".format(theano.config.floatX)) + + +T.nnet.elu = lambda x: T.switch(x >= floatX(0.), x, T.exp(x) - floatX(1.)) + +DIM_1 = 32 +DIM_2 = 32 +DIM_3 = 64 +DIM_4 = 64 +DIM_PIX = args.dim_pix +PIXEL_CNN_FILTER_SIZE = args.filter_size +PIXEL_CNN_LAYERS = args.num_pixel_cnn_layer + +LATENT_DIM = args.latent_dim +ALPHA_ITERS = args.alpha_iters +VANILLA = False +LR = 1e-3 + +BATCH_SIZE = 100 +N_CHANNELS = 1 +HEIGHT = 28 +WIDTH = 28 + +TEST_BATCH_SIZE = 100 +TIMES = ('iters', 500, 500*400, 500, 400*500, 2*ALPHA_ITERS) + +lib.print_model_settings(locals().copy()) + +theano_srng = RandomStreams(seed=234) + +def PixCNNGate(x): + a = x[:,::2] + b = x[:,1::2] + return T.tanh(a) * T.nnet.sigmoid(b) + +def PixCNN_condGate(x, z, dim, activation= 'tanh', name = ""): + a = x[:,::2] + b = x[:,1::2] + + Z_to_tanh = lib.ops.linear.Linear(name+".tanh", input_dim=LATENT_DIM, output_dim=dim, inputs=z) + Z_to_sigmoid = lib.ops.linear.Linear(name+".sigmoid", input_dim=LATENT_DIM, output_dim=dim, inputs=z) + + a = a + Z_to_tanh[:,:, None, None] + b = b + Z_to_sigmoid[:,:,None, None] + + if activation == 'tanh': + return T.tanh(a) * T.nnet.sigmoid(b) + else: + return T.nnet.elu(a) * T.nnet.sigmoid(b) + +def next_stacks(X_v, X_h, inp_dim, name, + global_conditioning = None, + filter_size = 3, + hstack = 'hstack', + residual = True + ): + zero_pad = T.zeros((X_v.shape[0], 
X_v.shape[1], 1, X_v.shape[3])) + + X_v_padded = T.concatenate([zero_pad, X_v], axis = 2) + + X_v_next = lib.ops.conv2d.Conv2D( + name + ".vstack", + input_dim=inp_dim, + output_dim=2*DIM_PIX, + filter_size=filter_size, + inputs=X_v_padded, + mask_type=('vstack', N_CHANNELS) + ) + + X_v_next_gated = PixCNNGate(X_v_next) + + X_v2h = lib.ops.conv2d.Conv2D( + name + ".v2h", + input_dim=2*DIM_PIX, + output_dim=2*DIM_PIX, + filter_size=(1,1), + inputs=X_v_next[:,:,:-1,:] + ) + + X_h_next = lib.ops.conv2d.Conv2D( + name + '.hstack', + input_dim= inp_dim, + output_dim= 2*DIM_PIX, + filter_size= (1,filter_size), + inputs= X_h, + mask_type=(hstack, N_CHANNELS) + ) + + X_h_next = PixCNNGate(X_h_next + X_v2h) + + X_h_next = lib.ops.conv2d.Conv2D( + name + '.h2h', + input_dim=DIM_PIX, + output_dim=DIM_PIX, + filter_size=(1,1), + inputs= X_h_next + ) + + if residual == True: + X_h_next = X_h_next + X_h + + return X_v_next_gated[:, :, 1:, :], X_h_next + +def next_stacks_gated(X_v, X_h, inp_dim, name, global_conditioning = None, + filter_size = 3, hstack = 'hstack', residual = True): + zero_pad = T.zeros((X_v.shape[0], X_v.shape[1], 1, X_v.shape[3])) + + X_v_padded = T.concatenate([zero_pad, X_v], axis = 2) + + X_v_next = lib.ops.conv2d.Conv2D( + name + ".vstack", + input_dim=inp_dim, + output_dim=2*DIM_PIX, + filter_size=filter_size, + inputs=X_v_padded, + mask_type=('vstack', N_CHANNELS) + ) + X_v_next_gated = PixCNN_condGate(X_v_next, global_conditioning, DIM_PIX, + name = name + ".vstack.conditional") + + X_v2h = lib.ops.conv2d.Conv2D( + name + ".v2h", + input_dim=2*DIM_PIX, + output_dim=2*DIM_PIX, + filter_size=(1,1), + inputs=X_v_next[:,:,:-1,:] + ) + + + X_h_next = lib.ops.conv2d.Conv2D( + name + '.hstack', + input_dim= inp_dim, + output_dim= 2*DIM_PIX, + filter_size= (1,filter_size), + inputs= X_h, + mask_type=(hstack, N_CHANNELS) + ) + + X_h_next = PixCNN_condGate(X_h_next + X_v2h, global_conditioning, DIM_PIX, name = name + ".hstack.conditional") + + X_h_next = 
lib.ops.conv2d.Conv2D( + name + '.h2h', + input_dim=DIM_PIX, + output_dim=DIM_PIX, + filter_size=(1,1), + inputs= X_h_next + ) + + if residual: + X_h_next = X_h_next + X_h + + return X_v_next_gated[:, :, 1:, :], X_h_next + + + +def Encoder(inputs): + + output = inputs + + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.1', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.2', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, inputs=output, stride=2)) + + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.3', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.4', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, inputs=output, stride=2)) + + # Pad from 7x7 to 8x8 + padded = T.zeros((output.shape[0], output.shape[1], 8, 8), dtype='float32') + output = T.inc_subtensor(padded[:,:,:7,:7], output) + + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.5', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.6', input_dim=DIM_3, output_dim=DIM_4, filter_size=3, inputs=output, stride=2)) + + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.7', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.8', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output)) + + output = output.reshape((output.shape[0], -1)) + output = lib.ops.linear.Linear('Enc.Out', input_dim=4*4*DIM_4, output_dim=2*LATENT_DIM, inputs=output) + return output[:, ::2], output[:, 1::2] + + +def Decoder_no_blind(latents, images): + output = latents + + output = lib.ops.linear.Linear('Dec.Inp', input_dim=LATENT_DIM, output_dim=4*4*DIM_4, inputs=output) + output = T.nnet.relu(output.reshape((output.shape[0], DIM_4, 4, 4))) + + output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.1', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output)) + output = 
T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.2', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output)) + + output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.3', input_dim=DIM_4, output_dim=DIM_3, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D( 'Dec.4', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, inputs=output)) + + # Cut from 8x8 to 7x7 + output = output[:,:,:7,:7] + + output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.5', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D( 'Dec.6', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, inputs=output)) + + output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.7', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, inputs=output)) + output = T.nnet.relu(lib.ops.conv2d.Conv2D( 'Dec.8', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, inputs=output)) + + skip_outputs = [] + + images_with_latent = T.concatenate([images, output], axis=1) + + X_v, X_h = next_stacks(images_with_latent, images_with_latent, N_CHANNELS + DIM_1, "Dec.PixInput", filter_size = 7, hstack = "hstack_a", residual = False) + + for i in xrange(PIXEL_CNN_LAYERS): + X_v, X_h = next_stacks(X_v, X_h, DIM_PIX, "Dec.Pix"+str(i+1), filter_size = 3) + + + output = PixCNNGate(lib.ops.conv2d.Conv2D('Dec.PixOut1', input_dim=DIM_PIX, output_dim=2*DIM_1, filter_size=1, inputs=X_h)) + + output = PixCNNGate(lib.ops.conv2d.Conv2D('Dec.PixOut2', input_dim=DIM_1, output_dim=2*DIM_1, filter_size=1, inputs=output)) + + output = lib.ops.conv2d.Conv2D('Dec.PixOut3', input_dim=DIM_1, output_dim=N_CHANNELS, filter_size=1, inputs=output, he_init=False) + + return output + + +def Decoder_no_blind_conditioned_on_z(latents, images): + output = latents + + X_v, X_h = next_stacks_gated( + images, images, N_CHANNELS, "Dec.PixInput", + global_conditioning = latents, filter_size = 7, + hstack = "hstack_a", residual = False + ) + + for i in xrange(PIXEL_CNN_LAYERS): + X_v, X_h = 
next_stacks_gated(X_v, X_h, DIM_PIX, "Dec.Pix"+str(i+1), global_conditioning = latents, filter_size = PIXEL_CNN_FILTER_SIZE) + + + output = lib.ops.conv2d.Conv2D('Dec.PixOut1', input_dim=DIM_PIX, output_dim=2*DIM_PIX, filter_size=1, inputs=X_h) + output = PixCNN_condGate(output, latents, DIM_PIX, name='Dec.PixOut1.cond' ) + output = lib.ops.conv2d.Conv2D('Dec.PixOut2', input_dim=DIM_PIX, output_dim=2*DIM_PIX, filter_size=1, inputs=output) + output = PixCNN_condGate(output, latents, DIM_PIX, name='Dec.PixOut2.cond' ) + + output = lib.ops.conv2d.Conv2D('Dec.PixOut3', input_dim=DIM_PIX, output_dim=N_CHANNELS, filter_size=1, inputs=output, he_init=False) + + return output + +def binarize(images): + """ + Stochastically binarize values in [0, 1] by treating them as p-values of + a Bernoulli distribution. + """ + return ( + np.random.uniform(size=images.shape) < images + ).astype(theano.config.floatX) + + + +if args.decoder_algorithm == 'cond_z_bias': + decode_algo = Decoder_no_blind_conditioned_on_z +elif args.decoder_algorithm == 'upsample_z_conv': + decode_algo = Decoder_no_blind +else: + assert False, "you should never be here!!" 
encoder = Encoder

total_iters = T.iscalar('total_iters')
images = T.tensor4('images')  # shape: (batch size, n channels, height, width)

mu, log_sigma = encoder(images)

if VANILLA:
    # Plain autoencoder path: use the encoder mean directly, no sampling.
    latents = mu
else:
    # Reparameterised sample: latents = mu + eps * exp(log_sigma).
    eps = T.cast(theano_srng.normal(mu.shape), theano.config.floatX)
    latents = mu + (eps * T.exp(log_sigma))

# Theano bug: NaNs unless I pass 2D tensors to binary_crossentropy
reconst_cost = T.nnet.binary_crossentropy(
    T.nnet.sigmoid(
        decode_algo(latents, images).reshape((-1, N_CHANNELS*HEIGHT*WIDTH))
    ),
    images.reshape((-1, N_CHANNELS*HEIGHT*WIDTH))
).sum(axis=1)

reg_cost = lib.ops.kl_unit_gaussian.kl_unit_gaussian(
    mu,
    log_sigma
).sum(axis=1)

# Weight on the KL term: grows linearly with total_iters and is capped at 1
# once total_iters reaches ALPHA_ITERS.
alpha = T.minimum(
    1,
    T.cast(total_iters, theano.config.floatX) / lib.floatX(ALPHA_ITERS)
)

if VANILLA:
    cost = reconst_cost
else:
    cost = reconst_cost + (alpha * reg_cost)

sample_fn_latents = T.matrix('sample_fn_latents')
sample_fn = theano.function(
    [sample_fn_latents, images],
    T.nnet.sigmoid(decode_algo(sample_fn_latents, images)),
    on_unused_input='warn'
)

eval_fn = theano.function(
    [images, total_iters],
    cost.mean()
)

train_data, dev_data, test_data = lib.mnist_binarized.load(
    BATCH_SIZE,
    TEST_BATCH_SIZE
)


#############################################
##############Importance Sampling###########
log2pi = T.constant(np.log(2*np.pi).astype(theano.config.floatX))

# Number of importance samples; also the denominator used by log_mean_exp.
k_ = 10


def log_mean_exp(x, axis=1):
    """
    Return log(sum(exp(x)) / k_), computed with the max subtracted first for
    numerical stability.

    NOTE: `axis` is accepted for interface compatibility but is NOT used.
    Both the max and the sum reduce over every element of `x`, and the
    normaliser is the module-level `k_`, so the result is a true
    log-mean-exp only when `x` holds exactly `k_` values.
    """
    m = T.max(x, keepdims=True)
    return m + T.log(T.sum(T.exp(x - m), keepdims=True)) - T.log(k_)


def log_lik(samples, mean, log_sigma):
    """
    Log-density of `samples` under a diagonal Gaussian with the given `mean`
    and `log_sigma`, summed over axis 1:

        -D/2 * log(2*pi) - 1/2 * sum(((samples - mean) / exp(log_sigma))**2
                                     + 2 * log_sigma)

    where D is `samples.shape[1]`.
    """
    # Cast to floatX (not a hard-coded 'float32') so the term matches the
    # dtype of `log2pi` and the rest of the graph.
    n_dims = T.cast(samples.shape[1], theano.config.floatX)
    normaliser = -log2pi * n_dims / 2
    quadratic = T.sum(
        T.sqr((samples - mean) / T.exp(log_sigma)) + 2*log_sigma,
        axis=1
    ) / 2
    return normaliser - quadratic


vae_bound = reconst_cost + reg_cost
log_lik_latent_prior = log_lik(latents, 0., 0.)
+log_lik_latent_posterior = log_lik(latents, mu, log_sigma) +loglikelihood_normal = log_lik_latent_prior - reconst_cost - log_lik_latent_posterior + +loglikelihood = -log_mean_exp(loglikelihood_normal) +lik_fn = theano.function( + [images], + [loglikelihood, vae_bound, reconst_cost, reg_cost, log_lik_latent_prior, log_lik_latent_posterior, loglikelihood_normal] +) + + + +def compute_importance_weighted_likelihood(): + i = 0 + total_lik = [] + total_lik_bound = [] + for (images,) in test_data(): + for im in images: + batch_ = np.tile(im, [k_, 1, 1, 1]) + res = lik_fn(batch_) + total_lik_bound.append(res[1].mean()) + + total_lik.append(res[0]) + i += 1 + + print "Importance weighted likelihood", np.mean(total_lik) + print "normal likelihood", np.mean(total_lik_bound) + +print("Loading parameters...") + +lib.load_params(args.pre_trained_weights) + +print("Computing Log-likelihood..") +compute_importance_weighted_likelihood() + + diff --git a/mnist_pixelvae_train.py b/mnist_pixelvae_train.py new file mode 100755 index 0000000..771cf05 --- /dev/null +++ b/mnist_pixelvae_train.py @@ -0,0 +1,454 @@ +""" +VAE + Pixel CNN +Ishaan Gulrajani +""" + + +""" +Modified by Kundan Kumar + +Usage: THEANO_FLAGS='mode=FAST_RUN,device=gpu0,floatX=float32,lib.cnmem=.95' python models/mnist_pixelvae_train.py -L 12 -fs 5 -algo cond_z_bias -dpx 16 -ldim 16 +""" + +import os, sys +sys.path.append(os.getcwd()) + +import time + +import argparse + +import lib +import lib.train_loop +import lib.mnist_binarized +import lib.ops.kl_unit_gaussian +import lib.ops.conv2d +import lib.ops.deconv2d +import lib.ops.linear + +import numpy as np +import theano +import theano.tensor as T +from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams +import scipy.misc +import lasagne +import pickle + +import functools + + +parser = argparse.ArgumentParser(description='Generating images pixel by pixel') +parser.add_argument('-L','--num_pixel_cnn_layer', required=True, type=int, help='Number of layers 
to use in pixelCNN') +parser.add_argument('-algo', '--decoder_algorithm', required = True, help="One of 'cond_z_bias', 'upsample_z_no_conv', 'upsample_z_conv', 'upsample_z_conv_tied' 'vae_only'" ) +parser.add_argument('-enc', '--encoder', required = False, default='simple', help="Encoder: 'complecated' or 'simple' " ) +parser.add_argument('-dpx', '--dim_pix', required = False, default=32, type = int ) +parser.add_argument('-fs', '--filter_size', required = False, default=5, type = int ) +parser.add_argument('-ldim', '--latent_dim', required = False, default=64, type = int ) +parser.add_argument('-ait', '--alpha_iters', required = False, default=10000, type = int ) +parser.add_argument('-o', '--out_dir', required = False, default=None ) + + +args = parser.parse_args() + + +assert args.decoder_algorithm in ['cond_z_bias', 'upsample_z_conv'] + +print args + + + +lib.ops.conv2d.enable_default_weightnorm() +lib.ops.linear.enable_default_weightnorm() + +if args.out_dir is None: + OUT_DIR_PREFIX = '/Tmp/kumarkun/mnist_pixel_final' +else: + OUT_DIR_PREFIX = args.out_dir + +OUT_DIR = OUT_DIR_PREFIX + "/num_layers_new3_" + str(args.num_pixel_cnn_layer) + args.decoder_algorithm + "_"+args.encoder + +if not os.path.isdir(OUT_DIR): + os.makedirs(OUT_DIR) + print "Created directory {}".format(OUT_DIR) + +def floatX(num): + if theano.config.floatX == 'float32': + return np.float32(num) + else: + raise Exception("{} type not supported".format(theano.config.floatX)) + + +T.nnet.elu = lambda x: T.switch(x >= floatX(0.), x, T.exp(x) - floatX(1.)) + +DIM_1 = 32 +DIM_2 = 32 +DIM_3 = 64 +DIM_4 = 64 +DIM_PIX = args.dim_pix +PIXEL_CNN_FILTER_SIZE = args.filter_size +PIXEL_CNN_LAYERS = args.num_pixel_cnn_layer + +LATENT_DIM = args.latent_dim +ALPHA_ITERS = args.alpha_iters +VANILLA = False +LR = 1e-3 + +BATCH_SIZE = 100 +N_CHANNELS = 1 +HEIGHT = 28 +WIDTH = 28 + +TEST_BATCH_SIZE = 100 +TIMES = ('iters', 500, 500*400, 500, 400*500, 2*ALPHA_ITERS) + +lib.print_model_settings(locals().copy()) 
theano_srng = RandomStreams(seed=234)

np.random.seed(123)


def PixCNNGate(x):
    """
    Gated activation: split the channel axis (axis 1) of `x` into its
    even-indexed and odd-indexed channels and return tanh(even) * sigmoid(odd).
    The output therefore has half as many channels as `x`.
    """
    a = x[:, ::2]
    b = x[:, 1::2]
    return T.tanh(a) * T.nnet.sigmoid(b)


def PixCNN_condGate(x, z, dim, activation='tanh', name=""):
    """
    Gated activation conditioned on a latent vector.

    `x` is split into even/odd channels as in `PixCNNGate`. Two linear
    projections of `z` (LATENT_DIM -> `dim`, named `name + ".tanh"` and
    `name + ".sigmoid"`) are broadcast over the two trailing axes and added
    to the respective halves before the non-linearities are applied.

    `activation`: 'tanh' selects tanh for the first half; any other value
    falls through to `T.nnet.elu`.
    """
    a = x[:, ::2]
    b = x[:, 1::2]

    Z_to_tanh = lib.ops.linear.Linear(name+".tanh", input_dim=LATENT_DIM, output_dim=dim, inputs=z)
    Z_to_sigmoid = lib.ops.linear.Linear(name+".sigmoid", input_dim=LATENT_DIM, output_dim=dim, inputs=z)

    # Add the per-channel biases derived from z at every spatial position.
    a = a + Z_to_tanh[:, :, None, None]
    b = b + Z_to_sigmoid[:, :, None, None]

    if activation == 'tanh':
        return T.tanh(a) * T.nnet.sigmoid(b)
    else:
        return T.nnet.elu(a) * T.nnet.sigmoid(b)


def next_stacks(X_v, X_h, inp_dim, name,
                global_conditioning=None,
                filter_size=3,
                hstack='hstack',
                residual=True):
    """
    One unconditioned gated PixelCNN layer operating on a vertical stack
    (`X_v`) and a horizontal stack (`X_h`).

    `inp_dim` is the channel count of both inputs; `name` prefixes every
    parameter created here. `hstack` is the mask name passed to the
    horizontal convolution. `global_conditioning` is accepted so the
    signature matches `next_stacks_gated`, but it is not used.

    Returns `(X_v_next, X_h_next)`, the updated vertical and horizontal stacks.
    """
    # Prepend one row of zeros along axis 2 before the vertical convolution.
    zero_pad = T.zeros((X_v.shape[0], X_v.shape[1], 1, X_v.shape[3]))

    X_v_padded = T.concatenate([zero_pad, X_v], axis=2)

    X_v_next = lib.ops.conv2d.Conv2D(
        name + ".vstack",
        input_dim=inp_dim,
        output_dim=2*DIM_PIX,
        filter_size=filter_size,
        inputs=X_v_padded,
        mask_type=('vstack', N_CHANNELS)
    )

    X_v_next_gated = PixCNNGate(X_v_next)

    # The horizontal stack receives the vertical features with the last row
    # dropped; the gated vertical output below drops the first row instead.
    X_v2h = lib.ops.conv2d.Conv2D(
        name + ".v2h",
        input_dim=2*DIM_PIX,
        output_dim=2*DIM_PIX,
        filter_size=(1, 1),
        inputs=X_v_next[:, :, :-1, :]
    )

    X_h_next = lib.ops.conv2d.Conv2D(
        name + '.hstack',
        input_dim=inp_dim,
        output_dim=2*DIM_PIX,
        filter_size=(1, filter_size),
        inputs=X_h,
        mask_type=(hstack, N_CHANNELS)
    )

    X_h_next = PixCNNGate(X_h_next + X_v2h)

    X_h_next = lib.ops.conv2d.Conv2D(
        name + '.h2h',
        input_dim=DIM_PIX,
        output_dim=DIM_PIX,
        filter_size=(1, 1),
        inputs=X_h_next
    )

    # Plain truthiness test, consistent with next_stacks_gated.
    if residual:
        X_h_next = X_h_next + X_h

    return X_v_next_gated[:, :, 1:, :], X_h_next


def next_stacks_gated(X_v, X_h, inp_dim, name, global_conditioning=None,
                      filter_size=3, hstack='hstack', residual=True):
    """
    One gated PixelCNN layer whose gates are conditioned on
    `global_conditioning`.

    Same structure as `next_stacks`, except that both gates are
    `PixCNN_condGate` calls which receive `global_conditioning`. Callers in
    this file always pass the latents; the `None` default is forwarded
    unchanged to `PixCNN_condGate`.

    Returns `(X_v_next, X_h_next)`.
    """
    zero_pad = T.zeros((X_v.shape[0], X_v.shape[1], 1, X_v.shape[3]))

    X_v_padded = T.concatenate([zero_pad, X_v], axis=2)

    X_v_next = lib.ops.conv2d.Conv2D(
        name + ".vstack",
        input_dim=inp_dim,
        output_dim=2*DIM_PIX,
        filter_size=filter_size,
        inputs=X_v_padded,
        mask_type=('vstack', N_CHANNELS)
    )
    X_v_next_gated = PixCNN_condGate(X_v_next, global_conditioning, DIM_PIX,
                                     name=name + ".vstack.conditional")

    X_v2h = lib.ops.conv2d.Conv2D(
        name + ".v2h",
        input_dim=2*DIM_PIX,
        output_dim=2*DIM_PIX,
        filter_size=(1, 1),
        inputs=X_v_next[:, :, :-1, :]
    )

    X_h_next = lib.ops.conv2d.Conv2D(
        name + '.hstack',
        input_dim=inp_dim,
        output_dim=2*DIM_PIX,
        filter_size=(1, filter_size),
        inputs=X_h,
        mask_type=(hstack, N_CHANNELS)
    )

    X_h_next = PixCNN_condGate(X_h_next + X_v2h, global_conditioning, DIM_PIX, name=name + ".hstack.conditional")

    X_h_next = lib.ops.conv2d.Conv2D(
        name + '.h2h',
        input_dim=DIM_PIX,
        output_dim=DIM_PIX,
        filter_size=(1, 1),
        inputs=X_h_next
    )

    if residual:
        X_h_next = X_h_next + X_h

    return X_v_next_gated[:, :, 1:, :], X_h_next


def Encoder(inputs):
    """
    Convolutional encoder.

    Applies eight ReLU convolutions (three of them with stride 2), flattens
    the result and maps it through a linear layer to 2*LATENT_DIM outputs.
    Returns the even-indexed and odd-indexed output columns as a pair; the
    caller unpacks them as `(mu, log_sigma)`.
    """
    output = inputs

    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.1', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.2', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, inputs=output, stride=2))

    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.3', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.4', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, inputs=output, stride=2))

    # Pad from 7x7 to 8x8
    # Use floatX rather than a hard-coded 'float32' so the buffer follows
    # the configured graph dtype.
    padded = T.zeros((output.shape[0], output.shape[1], 8, 8), dtype=theano.config.floatX)
    output = T.inc_subtensor(padded[:, :, :7, :7], output)

    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.5', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.6', input_dim=DIM_3, output_dim=DIM_4, filter_size=3, inputs=output, stride=2))

    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.7', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Enc.8', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output))

    output = output.reshape((output.shape[0], -1))
    output = lib.ops.linear.Linear('Enc.Out', input_dim=4*4*DIM_4, output_dim=2*LATENT_DIM, inputs=output)
    return output[:, ::2], output[:, 1::2]


def Decoder_no_blind(latents, images):
    """
    Decoder that upsamples the latents with (de)convolutions, concatenates
    the resulting feature map with `images` along the channel axis, and runs
    the result through a stack of `next_stacks` PixelCNN layers.

    Returns the N_CHANNELS-channel output of the final 1x1 convolution; no
    sigmoid is applied here.
    """
    output = latents

    output = lib.ops.linear.Linear('Dec.Inp', input_dim=LATENT_DIM, output_dim=4*4*DIM_4, inputs=output)
    output = T.nnet.relu(output.reshape((output.shape[0], DIM_4, 4, 4)))

    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.1', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.2', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, inputs=output))

    output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.3', input_dim=DIM_4, output_dim=DIM_3, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.4', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, inputs=output))

    # Cut from 8x8 to 7x7
    output = output[:, :, :7, :7]

    output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.5', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.6', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, inputs=output))

    output = T.nnet.relu(lib.ops.deconv2d.Deconv2D('Dec.7', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, inputs=output))
    output = T.nnet.relu(lib.ops.conv2d.Conv2D('Dec.8', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, inputs=output))

    # Images come first in the concatenation, followed by the decoded
    # latent features.
    images_with_latent = T.concatenate([images, output], axis=1)

    X_v, X_h = next_stacks(images_with_latent, images_with_latent, N_CHANNELS + DIM_1, "Dec.PixInput", filter_size=7, hstack="hstack_a", residual=False)

    for i in xrange(PIXEL_CNN_LAYERS):
        X_v, X_h = next_stacks(X_v, X_h, DIM_PIX, "Dec.Pix"+str(i+1), filter_size=PIXEL_CNN_FILTER_SIZE)

    output = PixCNNGate(lib.ops.conv2d.Conv2D('Dec.PixOut1', input_dim=DIM_PIX, output_dim=2*DIM_1, filter_size=1, inputs=X_h))

    output = PixCNNGate(lib.ops.conv2d.Conv2D('Dec.PixOut2', input_dim=DIM_1, output_dim=2*DIM_1, filter_size=1, inputs=output))

    output = lib.ops.conv2d.Conv2D('Dec.PixOut3', input_dim=DIM_1, output_dim=N_CHANNELS, filter_size=1, inputs=output, he_init=False)

    return output


def Decoder_no_blind_conditioned_on_z(latents, images):
    """
    Decoder in which the latents enter only through the conditional gates:
    `images` feed a stack of `next_stacks_gated` layers and every gate
    receives `latents` as its conditioning input.

    Returns the N_CHANNELS-channel output of the final 1x1 convolution; no
    sigmoid is applied here.
    """
    X_v, X_h = next_stacks_gated(
        images, images, N_CHANNELS, "Dec.PixInput",
        global_conditioning=latents, filter_size=7,
        hstack="hstack_a", residual=False
    )

    for i in xrange(PIXEL_CNN_LAYERS):
        X_v, X_h = next_stacks_gated(X_v, X_h, DIM_PIX, "Dec.Pix"+str(i+1), global_conditioning=latents, filter_size=PIXEL_CNN_FILTER_SIZE)

    output = lib.ops.conv2d.Conv2D('Dec.PixOut1', input_dim=DIM_PIX, output_dim=2*DIM_PIX, filter_size=1, inputs=X_h)
    output = PixCNN_condGate(output, latents, DIM_PIX, name='Dec.PixOut1.cond')
    output = lib.ops.conv2d.Conv2D('Dec.PixOut2', input_dim=DIM_PIX, output_dim=2*DIM_PIX, filter_size=1, inputs=output)
    output = PixCNN_condGate(output, latents, DIM_PIX, name='Dec.PixOut2.cond')

    output = lib.ops.conv2d.Conv2D('Dec.PixOut3', input_dim=DIM_PIX, output_dim=N_CHANNELS, filter_size=1, inputs=output, he_init=False)

    return output


def binarize(images):
    """
    Stochastically binarize values in [0, 1] by treating each value as the
    success probability of an independent Bernoulli draw.

    Returns an array of the same shape with entries 0 or 1, cast to
    `theano.config.floatX`. Draws come from numpy's global RNG.
    """
    return (
        np.random.uniform(size=images.shape) < images
    ).astype(theano.config.floatX)


if args.decoder_algorithm == 'cond_z_bias':
    decode_algo = Decoder_no_blind_conditioned_on_z
elif args.decoder_algorithm == 'upsample_z_conv':
    decode_algo = Decoder_no_blind
else:
    # Raise instead of `assert False`: asserts are stripped under -O, which
    # would leave `decode_algo` undefined and fail later with a NameError.
    raise ValueError(
        "Unsupported decoder algorithm: {!r}".format(args.decoder_algorithm)
    )
+ + +encoder = Encoder + +total_iters = T.iscalar('total_iters') +images = T.tensor4('images') # shape: (batch size, n channels, height, width) + +mu, log_sigma = encoder(images) + +if VANILLA: + latents = mu +else: + eps = T.cast(theano_srng.normal(mu.shape), theano.config.floatX) + latents = mu + (eps * T.exp(log_sigma)) + +# Theano bug: NaNs unless I pass 2D tensors to binary_crossentropy +reconst_cost = T.nnet.binary_crossentropy( + T.nnet.sigmoid( + decode_algo(latents, images).reshape((-1, N_CHANNELS*HEIGHT*WIDTH)) + ), + images.reshape((-1, N_CHANNELS*HEIGHT*WIDTH)) +).sum(axis=1) + +reg_cost = lib.ops.kl_unit_gaussian.kl_unit_gaussian( + mu, + log_sigma +).sum(axis=1) + +alpha = T.minimum( + 1, + T.cast(total_iters, theano.config.floatX) / lib.floatX(ALPHA_ITERS) +) + +if VANILLA: + cost = reconst_cost +else: + cost = reconst_cost + (alpha * reg_cost) + +sample_fn_latents = T.matrix('sample_fn_latents') +sample_fn = theano.function( + [sample_fn_latents, images], + T.nnet.sigmoid(decode_algo(sample_fn_latents, images)), + on_unused_input='warn' +) + +eval_fn = theano.function( + [images, total_iters], + cost.mean() +) + +train_data, dev_data, test_data = lib.mnist_binarized.load( + BATCH_SIZE, + TEST_BATCH_SIZE +) + + +def generate_and_save_samples(tag): + + lib.save_params(os.path.join(OUT_DIR, tag + "_params.pkl")) + + def save_images(images, filename, i = None): + """images.shape: (batch, n channels, height, width)""" + if i is not None: + new_tag = "{}_{}".format(tag, i) + else: + new_tag = tag + + images = images.reshape((10,10,28,28)) + + images = images.transpose(1,2,0,3) + images = images.reshape((10*28, 10*28)) + + image = scipy.misc.toimage(images, cmin=0.0, cmax=1.0) + image.save('{}/{}_{}.jpg'.format(OUT_DIR, filename, new_tag)) + + latents = np.random.normal(size=(100, LATENT_DIM)) + + latents = latents.astype(theano.config.floatX) + + samples = np.zeros( + (100, N_CHANNELS, HEIGHT, WIDTH), + dtype=theano.config.floatX + ) + + next_sample = 
samples.copy() + + t0 = time.time() + for j in xrange(HEIGHT): + for k in xrange(WIDTH): + for i in xrange(N_CHANNELS): + samples_p_value = sample_fn(latents, next_sample) + next_sample[:, i, j, k] = binarize(samples_p_value)[:, i, j, k] + samples[:, i, j, k] = samples_p_value[:, i, j, k] + + t1 = time.time() + print("Time taken for generation {:.4f}".format(t1 - t0)) + + save_images(samples_p_value, 'samples') + + +print("Training") + +lib.train_loop.train_loop( + inputs=[total_iters, images], + inject_total_iters=True, + cost=cost.mean(), + prints=[ + ('alpha', alpha), + ('reconst', reconst_cost.mean()), + ('reg', reg_cost.mean()) + ], + optimizer=functools.partial(lasagne.updates.adam, learning_rate=LR), + train_data=train_data, + test_data=dev_data, + callback=generate_and_save_samples, + times=TIMES +) diff --git a/pixelvae.py b/pixelvae.py new file mode 100644 index 0000000..cc5de1e --- /dev/null +++ b/pixelvae.py @@ -0,0 +1,936 @@ +""" +PixelVAE: A Latent Variable Model for Natural Images +Ishaan Gulrajani, Kundan Kumar, Faruk Ahmed, Adrien Ali Taiga, Francesco Visin, David Vazquez, Aaron Courville +""" + +import os, sys +sys.path.append(os.getcwd()) + +N_GPUS = 1 + +try: # This only matters on Ishaan's computer + import experiment_tools + experiment_tools.wait_for_gpu(tf=True, n_gpus=N_GPUS) +except ImportError: + pass + +import tflib as lib +import tflib.train_loop_2 +import tflib.ops.kl_unit_gaussian +import tflib.ops.kl_gaussian_gaussian +import tflib.ops.conv2d +import tflib.ops.linear +import tflib.ops.batchnorm +import tflib.ops.embedding + +import tflib.lsun_bedrooms +import tflib.mnist_256 +import tflib.small_imagenet + +import numpy as np +import tensorflow as tf +import scipy.misc +from scipy.misc import imsave + +import time +import functools + +DATASET = 'mnist_256' # mnist_256, lsun_32, lsun_64, imagenet_64 +SETTINGS = 'mnist_256' # mnist_256, 32px_small, 32px_big, 64px_small, 64px_big + +if SETTINGS == 'mnist_256': + # two_level uses Enc1/Dec1 
for the bottom level, Enc2/Dec2 for the top level + # one_level uses EncFull/DecFull for the bottom (and only) level + MODE = 'one_level' + + # Whether to treat pixel inputs to the model as real-valued (as in the + # original PixelCNN) or discrete (gets better likelihoods). + EMBED_INPUTS = True + + # Turn on/off the bottom-level PixelCNN in Dec1/DecFull + PIXEL_LEVEL_PIXCNN = True + HIGHER_LEVEL_PIXCNN = True + + DIM_EMBED = 16 + DIM_PIX_1 = 32 + DIM_1 = 16 + DIM_2 = 32 + DIM_3 = 32 + DIM_4 = 64 + LATENT_DIM_2 = 128 + + ALPHA1_ITERS = 5000 + ALPHA2_ITERS = 5000 + KL_PENALTY = 1.0 + BETA_ITERS = 1000 + + # In Dec2, we break each spatial location into N blocks (analogous to channels + # in the original PixelCNN) and model each spatial location autoregressively + # as P(x)=P(x0)*P(x1|x0)*P(x2|x0,x1)... In my experiments values of N > 1 + # actually hurt performance. Unsure why; might be a bug. + PIX_2_N_BLOCKS = 1 + + TIMES = { + 'test_every': 2*500, + 'stop_after': 500*500, + 'callback_every': 10*500 + } + + LR = 1e-3 + + LR_DECAY_AFTER = TIMES['stop_after'] + LR_DECAY_FACTOR = 1. + + BATCH_SIZE = 100 + N_CHANNELS = 1 + HEIGHT = 28 + WIDTH = 28 + + # These aren't actually (typically) used for one-level models but some parts + # of the code still depend on them being defined. 
+ LATENT_DIM_1 = 64 + LATENTS1_HEIGHT = 7 + LATENTS1_WIDTH = 7 + +elif SETTINGS == '32px_small': + MODE = 'two_level' + + EMBED_INPUTS = True + + PIXEL_LEVEL_PIXCNN = True + HIGHER_LEVEL_PIXCNN = True + + DIM_EMBED = 16 + DIM_PIX_1 = 128 + DIM_1 = 64 + DIM_2 = 128 + DIM_3 = 256 + LATENT_DIM_1 = 64 + DIM_PIX_2 = 512 + DIM_4 = 512 + LATENT_DIM_2 = 512 + + ALPHA1_ITERS = 2000 + ALPHA2_ITERS = 5000 + KL_PENALTY = 1.00 + BETA_ITERS = 1000 + + PIX_2_N_BLOCKS = 1 + + TIMES = { + 'test_every': 1000, + 'stop_after': 200000, + 'callback_every': 20000 + } + + LR = 1e-3 + + LR_DECAY_AFTER = 180000 + LR_DECAY_FACTOR = 1e-1 + + BATCH_SIZE = 64 + N_CHANNELS = 3 + HEIGHT = 32 + WIDTH = 32 + + LATENTS1_HEIGHT = 8 + LATENTS1_WIDTH = 8 + +elif SETTINGS == '32px_big': + + MODE = 'two_level' + + EMBED_INPUTS = False + + PIXEL_LEVEL_PIXCNN = True + HIGHER_LEVEL_PIXCNN = True + + DIM_EMBED = 16 + DIM_PIX_1 = 256 + DIM_1 = 128 + DIM_2 = 256 + DIM_3 = 512 + LATENT_DIM_1 = 128 + DIM_PIX_2 = 512 + DIM_4 = 512 + LATENT_DIM_2 = 512 + + ALPHA1_ITERS = 2000 + ALPHA2_ITERS = 5000 + KL_PENALTY = 1.00 + BETA_ITERS = 1000 + + PIX_2_N_BLOCKS = 1 + + TIMES = { + 'test_every': 1000, + 'stop_after': 300000, + 'callback_every': 20000 + } + + VANILLA = False + LR = 1e-3 + + LR_DECAY_AFTER = 300000 + LR_DECAY_FACTOR = 1e-1 + + BATCH_SIZE = 64 + N_CHANNELS = 3 + HEIGHT = 32 + WIDTH = 32 + LATENTS1_HEIGHT = 8 + LATENTS1_WIDTH = 8 + +elif SETTINGS == '64px_small': + MODE = 'two_level' + + EMBED_INPUTS = True + + PIXEL_LEVEL_PIXCNN = True + HIGHER_LEVEL_PIXCNN = True + + DIM_EMBED = 16 + DIM_PIX_1 = 128 + DIM_0 = 64 + DIM_1 = 64 + DIM_2 = 128 + LATENT_DIM_1 = 64 + DIM_PIX_2 = 256 + DIM_3 = 256 + DIM_4 = 512 + LATENT_DIM_2 = 512 + + PIX_2_N_BLOCKS = 1 + + TIMES = { + 'test_every': 10000, + 'stop_after': 200000, + 'callback_every': 50000 + } + + VANILLA = False + LR = 1e-3 + + LR_DECAY_AFTER = 180000 + LR_DECAY_FACTOR = .1 + + ALPHA1_ITERS = 2000 + ALPHA2_ITERS = 10000 + KL_PENALTY = 1.0 + BETA_ITERS = 1000 + + 
BATCH_SIZE = 64 + N_CHANNELS = 3 + HEIGHT = 64 + WIDTH = 64 + LATENTS1_WIDTH = 16 + LATENTS1_HEIGHT = 16 + +elif SETTINGS == '64px_big': + MODE = 'two_level' + + EMBED_INPUTS = True + + PIXEL_LEVEL_PIXCNN = True + HIGHER_LEVEL_PIXCNN = True + + DIM_EMBED = 16 + DIM_PIX_1 = 384 + DIM_0 = 192 + DIM_1 = 256 + DIM_2 = 512 + LATENT_DIM_1 = 64 + DIM_PIX_2 = 512 + DIM_3 = 512 + DIM_4 = 512 + LATENT_DIM_2 = 512 + + PIX_2_N_BLOCKS = 1 + + TIMES = { + 'test_every': 10000, + 'stop_after': 400000, + 'callback_every': 50000 + } + + VANILLA = False + LR = 1e-3 + + LR_DECAY_AFTER = 180000 + LR_DECAY_FACTOR = .5 + + ALPHA1_ITERS = 1000 + ALPHA2_ITERS = 10000 + KL_PENALTY = 1.00 + BETA_ITERS = 500 + + BATCH_SIZE = 48 + N_CHANNELS = 3 + HEIGHT = 64 + WIDTH = 64 + LATENTS1_WIDTH = 16 + LATENTS1_HEIGHT = 16 + +if DATASET == 'mnist_256': + train_data, dev_data, test_data = lib.mnist_256.load(BATCH_SIZE, BATCH_SIZE) +elif DATASET == 'lsun_32': + train_data, dev_data = lib.lsun_bedrooms.load(BATCH_SIZE, downsample=True) +elif DATASET == 'lsun_64': + train_data, dev_data = lib.lsun_bedrooms.load(BATCH_SIZE, downsample=False) +elif DATASET == 'imagenet_64': + train_data, dev_data = lib.small_imagenet.load(BATCH_SIZE) + +lib.print_model_settings(locals().copy()) + +DEVICES = ['/gpu:{}'.format(i) for i in xrange(N_GPUS)] + +lib.ops.conv2d.enable_default_weightnorm() +lib.ops.linear.enable_default_weightnorm() + +with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as session: + bn_is_training = tf.placeholder(tf.bool, shape=None, name='bn_is_training') + bn_stats_iter = tf.placeholder(tf.int32, shape=None, name='bn_stats_iter') + total_iters = tf.placeholder(tf.int32, shape=None, name='total_iters') + all_images = tf.placeholder(tf.int32, shape=[None, N_CHANNELS, HEIGHT, WIDTH], name='all_images') + all_latents1 = tf.placeholder(tf.float32, shape=[None, LATENT_DIM_1, LATENTS1_HEIGHT, LATENTS1_WIDTH], name='all_latents1') + + split_images = tf.split(0, len(DEVICES), all_images) + 
split_latents1 = tf.split(0, len(DEVICES), all_latents1) + + tower_cost = [] + tower_outputs1_sample = [] + + for device_index, (device, images, latents1_sample) in enumerate(zip(DEVICES, split_images, split_latents1)): + with tf.device(device): + + def nonlinearity(x): + return tf.nn.elu(x) + + def pixcnn_gated_nonlinearity(a, b): + return tf.sigmoid(a) * tf.tanh(b) + + def SubpixelConv2D(*args, **kwargs): + kwargs['output_dim'] = 4*kwargs['output_dim'] + output = lib.ops.conv2d.Conv2D(*args, **kwargs) + output = tf.transpose(output, [0,2,3,1]) + output = tf.depth_to_space(output, 2) + output = tf.transpose(output, [0,3,1,2]) + return output + + def ResidualBlock(name, input_dim, output_dim, inputs, filter_size, mask_type=None, resample=None, he_init=True): + """ + resample: None, 'down', or 'up' + """ + if mask_type != None and resample != None: + raise Exception('Unsupported configuration') + + if resample=='down': + conv_shortcut = functools.partial(lib.ops.conv2d.Conv2D, stride=2) + conv_1 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=input_dim) + conv_2 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=output_dim, stride=2) + elif resample=='up': + conv_shortcut = SubpixelConv2D + conv_1 = functools.partial(SubpixelConv2D, input_dim=input_dim, output_dim=output_dim) + conv_2 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim) + elif resample==None: + conv_shortcut = lib.ops.conv2d.Conv2D + conv_1 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=input_dim, output_dim=output_dim) + conv_2 = functools.partial(lib.ops.conv2d.Conv2D, input_dim=output_dim, output_dim=output_dim) + else: + raise Exception('invalid resample value') + + if output_dim==input_dim and resample==None: + shortcut = inputs # Identity skip-connection + else: + shortcut = conv_shortcut(name+'.Shortcut', input_dim=input_dim, output_dim=output_dim, filter_size=1, mask_type=mask_type, he_init=False, 
biases=True, inputs=inputs) + + output = inputs + if mask_type == None: + output = nonlinearity(output) + output = conv_1(name+'.Conv1', filter_size=filter_size, mask_type=mask_type, inputs=output, he_init=he_init, weightnorm=False) + output = nonlinearity(output) + output = conv_2(name+'.Conv2', filter_size=filter_size, mask_type=mask_type, inputs=output, he_init=he_init, weightnorm=False, biases=False) + if device_index == 0: + output = lib.ops.batchnorm.Batchnorm(name+'.BN', [0,2,3], output, bn_is_training, bn_stats_iter) + else: + output = lib.ops.batchnorm.Batchnorm(name+'.BN', [0,2,3], output, bn_is_training, bn_stats_iter, update_moving_stats=False) + else: + output = nonlinearity(output) + output_a = conv_1(name+'.Conv1A', filter_size=filter_size, mask_type=mask_type, inputs=output, he_init=he_init) + output_b = conv_1(name+'.Conv1B', filter_size=filter_size, mask_type=mask_type, inputs=output, he_init=he_init) + output = pixcnn_gated_nonlinearity(output_a, output_b) + output = conv_2(name+'.Conv2', filter_size=filter_size, mask_type=mask_type, inputs=output, he_init=he_init) + + return shortcut + output + + def Enc1(images): + output = images + + if WIDTH == 64: + if EMBED_INPUTS: + output = lib.ops.conv2d.Conv2D('Enc1.Input', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_0, filter_size=1, inputs=output, he_init=False) + output = ResidualBlock('Enc1.InputRes0', input_dim=DIM_0, output_dim=DIM_0, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.InputRes', input_dim=DIM_0, output_dim=DIM_1, filter_size=3, resample='down', inputs=output) + else: + output = lib.ops.conv2d.Conv2D('Enc1.Input', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False) + output = ResidualBlock('Enc1.InputRes', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample='down', inputs=output) + else: + if EMBED_INPUTS: + output = lib.ops.conv2d.Conv2D('Enc1.Input', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, 
filter_size=1, inputs=output, he_init=False) + else: + output = lib.ops.conv2d.Conv2D('Enc1.Input', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False) + + + output = ResidualBlock('Enc1.Res1Pre', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res1Pre2', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res1', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, resample='down', inputs=output) + if LATENTS1_WIDTH == 16: + output = ResidualBlock('Enc1.Res4Pre', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res4', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res4Post', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + mu_and_sigma = lib.ops.conv2d.Conv2D('Enc1.Out', input_dim=DIM_2, output_dim=2*LATENT_DIM_1, filter_size=1, inputs=output, he_init=False) + else: + output = ResidualBlock('Enc1.Res2Pre', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res2Pre2', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res2', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, resample='down', inputs=output) + output = ResidualBlock('Enc1.Res3Pre', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res3Pre2', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Enc1.Res3Pre3', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + mu_and_sigma = lib.ops.conv2d.Conv2D('Enc1.Out', input_dim=DIM_3, output_dim=2*LATENT_DIM_1, filter_size=1, inputs=output, he_init=False) + + return mu_and_sigma, output + + def 
Dec1(latents, images): + output = tf.clip_by_value(latents, -50., 50.) + + if LATENTS1_WIDTH == 16: + output = lib.ops.conv2d.Conv2D('Dec1.Input', input_dim=LATENT_DIM_1, output_dim=DIM_2, filter_size=1, inputs=output, he_init=False) + output = ResidualBlock('Dec1.Res1A', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res1B', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res1C', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + else: + output = lib.ops.conv2d.Conv2D('Dec1.Input', input_dim=LATENT_DIM_1, output_dim=DIM_3, filter_size=1, inputs=output, he_init=False) + output = ResidualBlock('Dec1.Res1', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res1Post', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res1Post2', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res2', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, resample='up', inputs=output) + output = ResidualBlock('Dec1.Res2Post', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res2Post2', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output) + + output = ResidualBlock('Dec1.Res3', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, resample='up', inputs=output) + output = ResidualBlock('Dec1.Res3Post', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output) + output = ResidualBlock('Dec1.Res3Post2', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output) + + if WIDTH == 64: + output = ResidualBlock('Dec1.Res4', input_dim=DIM_1, output_dim=DIM_0, filter_size=3, resample='up', inputs=output) + output = 
ResidualBlock('Dec1.Res4Post', input_dim=DIM_0, output_dim=DIM_0, filter_size=3, resample=None, inputs=output) + + if PIXEL_LEVEL_PIXCNN: + + if WIDTH == 64: + if EMBED_INPUTS: + masked_images = lib.ops.conv2d.Conv2D('Dec1.Pix1', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_0, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + else: + masked_images = lib.ops.conv2d.Conv2D('Dec1.Pix1', input_dim=N_CHANNELS, output_dim=DIM_0, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + else: + if EMBED_INPUTS: + masked_images = lib.ops.conv2d.Conv2D('Dec1.Pix1', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + else: + masked_images = lib.ops.conv2d.Conv2D('Dec1.Pix1', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + + # Make the variance of output and masked_images (roughly) match + output /= 2 + + # Warning! 
Because of the masked convolutions it's very important that masked_images comes first in this concat + output = tf.concat(1, [masked_images, output]) + + if WIDTH == 64: + output = ResidualBlock('Dec1.Pix2Res', input_dim=2*DIM_0, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output) + output = ResidualBlock('Dec1.Pix3Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output) + output = ResidualBlock('Dec1.Pix4Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output) + else: + output = ResidualBlock('Dec1.Pix2Res', input_dim=2*DIM_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output) + output = ResidualBlock('Dec1.Pix3Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output) + + output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=DIM_PIX_1, output_dim=256*N_CHANNELS, filter_size=1, mask_type=('b', N_CHANNELS), he_init=False, inputs=output) + + else: + + if WIDTH == 64: + output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=DIM_0, output_dim=256*N_CHANNELS, filter_size=1, he_init=False, inputs=output) + else: + output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=DIM_1, output_dim=256*N_CHANNELS, filter_size=1, he_init=False, inputs=output) + + return tf.transpose( + tf.reshape(output, [-1, 256, N_CHANNELS, HEIGHT, WIDTH]), + [0,2,3,4,1] + ) + + def Enc2(h1): + output = h1 + + if LATENTS1_WIDTH == 16: + output = ResidualBlock('Enc2.Res0', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('Enc2.Res1Pre', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('Enc2.Res1Pre2', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('Enc2.Res1', input_dim=DIM_2, 
def Dec2(latents, targets):
    """
    Top-level decoder: turns second-level latents into the parameters of the
    prior over first-level latents.

    latents: second-level latent vectors (LATENT_DIM_2 values per example);
        clipped to [-50, 50] before use.
    targets: first-level latents; only read when HIGHER_LEVEL_PIXCNN is set,
        where they enter through a type-'a' masked convolution.

    returns: a feature map with 2*LATENT_DIM_1 channels, which the caller
        passes to `split` to obtain the prior's mu / logsig.
    """
    # The feature width at the top of the residual stack depends on whether
    # the first-level latents are 16 wide (one extra upsampling stage) or not.
    top_dim = DIM_3
    stack = [
        ('Dec2.Res1Pre',   DIM_4, DIM_4, None),
        ('Dec2.Res1',      DIM_4, DIM_4, None),
        ('Dec2.Res1Post',  DIM_4, DIM_4, None),
        ('Dec2.Res3',      DIM_4, DIM_3, 'up'),
        ('Dec2.Res3Post',  DIM_3, DIM_3, None),
        ('Dec2.Res3Post2', DIM_3, DIM_3, None),
        ('Dec2.Res3Post3', DIM_3, DIM_3, None),
    ]
    if LATENTS1_WIDTH == 16:
        top_dim = DIM_2
        stack += [
            ('Dec2.Res3Post5', DIM_3, DIM_2, 'up'),
            ('Dec2.Res3Post6', DIM_2, DIM_2, None),
            ('Dec2.Res3Post7', DIM_2, DIM_2, None),
            ('Dec2.Res3Post8', DIM_2, DIM_2, None),
        ]

    net = tf.clip_by_value(latents, -50., 50.)
    net = lib.ops.linear.Linear('Dec2.Input', input_dim=LATENT_DIM_2, output_dim=4*4*DIM_4, inputs=net)
    net = tf.reshape(net, [-1, DIM_4, 4, 4])

    for block_name, dim_in, dim_out, resample in stack:
        net = ResidualBlock(block_name, input_dim=dim_in, output_dim=dim_out, filter_size=3, resample=resample, he_init=True, inputs=net)

    mask_b = ('b', PIX_2_N_BLOCKS)

    if not HIGHER_LEVEL_PIXCNN:
        # No autoregressive stack: project straight to the prior parameters.
        return lib.ops.conv2d.Conv2D('Dec2.Out', input_dim=top_dim, output_dim=2*LATENT_DIM_1, filter_size=1, mask_type=mask_b, he_init=False, inputs=net)

    masked_targets = lib.ops.conv2d.Conv2D('Dec2.Pix1', input_dim=LATENT_DIM_1, output_dim=top_dim, filter_size=5, mask_type=('a', PIX_2_N_BLOCKS), he_init=False, inputs=targets)

    # Make the variance of the decoder features and masked_targets roughly match
    net /= 2

    # masked_targets goes first in the channel concat, as in the original.
    net = tf.concat(1, [masked_targets, net])

    net = ResidualBlock('Dec2.Pix2Res', input_dim=2*top_dim, output_dim=DIM_PIX_2, filter_size=3, mask_type=mask_b, he_init=True, inputs=net)
    net = ResidualBlock('Dec2.Pix3Res', input_dim=DIM_PIX_2, output_dim=DIM_PIX_2, filter_size=3, mask_type=mask_b, he_init=True, inputs=net)
    net = ResidualBlock('Dec2.Pix4Res', input_dim=DIM_PIX_2, output_dim=DIM_PIX_2, filter_size=1, mask_type=mask_b, he_init=True, inputs=net)

    return lib.ops.conv2d.Conv2D('Dec2.Out', input_dim=DIM_PIX_2, output_dim=2*LATENT_DIM_1, filter_size=1, mask_type=mask_b, he_init=False, inputs=net)
# Really only for MNIST. Will require modification for other datasets.
def EncFull(images):
    """
    Encoder used in 'one_level' mode.

    images: either the embedded images (N_CHANNELS*DIM_EMBED channels, when
        EMBED_INPUTS is set) or the scaled images (N_CHANNELS channels).

    returns: a (batch, 2*LATENT_DIM_2) tensor; the caller feeds it to `split`
        to get the posterior mu / logsig.
    """
    in_channels = N_CHANNELS*DIM_EMBED if EMBED_INPUTS else N_CHANNELS

    features = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=in_channels, output_dim=DIM_1, filter_size=1, inputs=images, he_init=False)

    # (name, input_dim, output_dim, resample) for each residual block, in order.
    residual_stack = (
        ('EncFull.Res1', DIM_1, DIM_1, None),
        ('EncFull.Res2', DIM_1, DIM_2, 'down'),
        ('EncFull.Res3', DIM_2, DIM_2, None),
        ('EncFull.Res4', DIM_2, DIM_3, 'down'),
        ('EncFull.Res5', DIM_3, DIM_3, None),
        ('EncFull.Res6', DIM_3, DIM_3, None),
    )
    for block_name, dim_in, dim_out, resample in residual_stack:
        features = ResidualBlock(block_name, input_dim=dim_in, output_dim=dim_out, filter_size=3, resample=resample, inputs=features)

    # Global average pool over the two spatial axes, then a linear read-out.
    pooled = tf.reduce_mean(features, reduction_indices=[2,3])
    return lib.ops.linear.Linear('EncFull.Output', input_dim=DIM_3, output_dim=2*LATENT_DIM_2, initialization='glorot', inputs=pooled)
+ + output = lib.ops.linear.Linear('DecFull.Input', input_dim=LATENT_DIM_2, output_dim=DIM_3, initialization='glorot', inputs=output) + output = tf.reshape(tf.tile(tf.reshape(output, [-1, DIM_3, 1]), [1, 1, 49]), [-1, DIM_3, 7, 7]) + + output = ResidualBlock('DecFull.Res2', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('DecFull.Res3', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('DecFull.Res4', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, resample='up', he_init=True, inputs=output) + output = ResidualBlock('DecFull.Res5', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output) + output = ResidualBlock('DecFull.Res6', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, resample='up', he_init=True, inputs=output) + output = ResidualBlock('DecFull.Res7', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, he_init=True, inputs=output) + + if PIXEL_LEVEL_PIXCNN: + + if EMBED_INPUTS: + masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + else: + masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False) + + # Warning! 
def split(mu_and_logsig):
    """
    Cut an encoder/decoder output in half along axis 1 and turn the second
    half into a bounded standard deviation.

    returns: (mu, logsig, sig) where sig = 0.5 * (softsign(raw) + 1), which
        lies strictly between 0 and 1, and logsig = log(sig).
    """
    # Legacy tf.split argument order: (split_dim, num_split, value).
    halves = tf.split(1, 2, mu_and_logsig)
    mu = halves[0]
    raw_logsig = halves[1]

    sig = 0.5 * (tf.nn.softsign(raw_logsig)+1)
    return mu, tf.log(sig), sig
+ if EMBED_INPUTS: + embedded_images = lib.ops.embedding.Embedding('Embedding', 256, DIM_EMBED, images) + embedded_images = tf.transpose(embedded_images, [0,4,1,2,3]) + embedded_images = tf.reshape(embedded_images, [-1, DIM_EMBED*N_CHANNELS, HEIGHT, WIDTH]) + + if MODE == 'one_level': + + # Layer 1 + + if EMBED_INPUTS: + mu_and_logsig1 = EncFull(embedded_images) + else: + mu_and_logsig1 = EncFull(scaled_images) + mu1, logsig1, sig1 = split(mu_and_logsig1) + + eps = tf.random_normal(tf.shape(mu1)) + latents1 = mu1 + (eps * sig1) + + if EMBED_INPUTS: + outputs1 = DecFull(latents1, embedded_images) + else: + outputs1 = DecFull(latents1, scaled_images) + + reconst_cost = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + tf.reshape(outputs1, [-1, 256]), + tf.reshape(images, [-1]) + ) + ) + + # Assembly + + # An alpha of exactly 0 can sometimes cause inf/nan values, so we're + # careful to avoid it. + alpha = tf.minimum(1., tf.cast(total_iters+1, 'float32') / ALPHA1_ITERS) * KL_PENALTY + + kl_cost_1 = tf.reduce_mean( + lib.ops.kl_unit_gaussian.kl_unit_gaussian( + mu1, + logsig1, + sig1 + ) + ) + + kl_cost_1 *= float(LATENT_DIM_2) / (N_CHANNELS * WIDTH * HEIGHT) + + cost = reconst_cost + (alpha * kl_cost_1) + + elif MODE == 'two_level': + # Layer 1 + + if EMBED_INPUTS: + mu_and_logsig1, h1 = Enc1(embedded_images) + else: + mu_and_logsig1, h1 = Enc1(scaled_images) + mu1, logsig1, sig1 = split(mu_and_logsig1) + + if mu1.get_shape().as_list()[2] != LATENTS1_HEIGHT: + raise Exception("LATENTS1_HEIGHT doesn't match mu1 shape!") + if mu1.get_shape().as_list()[3] != LATENTS1_WIDTH: + raise Exception("LATENTS1_WIDTH doesn't match mu1 shape!") + + eps = tf.random_normal(tf.shape(mu1)) + latents1 = mu1 + (eps * sig1) + + if EMBED_INPUTS: + outputs1 = Dec1(latents1, embedded_images) + outputs1_sample = Dec1(latents1_sample, embedded_images) + else: + outputs1 = Dec1(latents1, scaled_images) + outputs1_sample = Dec1(latents1_sample, scaled_images) + + reconst_cost = 
tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + tf.reshape(outputs1, [-1, 256]), + tf.reshape(images, [-1]) + ) + ) + + # Layer 2 + + mu_and_logsig2 = Enc2(h1) + mu2, logsig2, sig2 = split(mu_and_logsig2) + + eps = tf.random_normal(tf.shape(mu2)) + latents2 = mu2 + (eps * sig2) + + outputs2 = Dec2(latents2, latents1) + + mu1_prior, logsig1_prior, sig1_prior = split(outputs2) + logsig1_prior, sig1_prior = clamp_logsig_and_sig(logsig1_prior, sig1_prior) + mu1_prior = 2. * tf.nn.softsign(mu1_prior / 2.) + + # Assembly + + # An alpha of exactly 0 can sometimes cause inf/nan values, so we're + # careful to avoid it. + alpha1 = tf.minimum(1., tf.cast(total_iters+1, 'float32') / ALPHA1_ITERS) * KL_PENALTY + alpha2 = tf.minimum(1., tf.cast(total_iters+1, 'float32') / ALPHA2_ITERS) * alpha1# * KL_PENALTY + + kl_cost_1 = tf.reduce_mean( + lib.ops.kl_gaussian_gaussian.kl_gaussian_gaussian( + mu1, + logsig1, + sig1, + mu1_prior, + logsig1_prior, + sig1_prior + ) + ) + + kl_cost_2 = tf.reduce_mean( + lib.ops.kl_unit_gaussian.kl_unit_gaussian( + mu2, + logsig2, + sig2 + ) + ) + + kl_cost_1 *= float(LATENT_DIM_1 * LATENTS1_WIDTH * LATENTS1_HEIGHT) / (N_CHANNELS * WIDTH * HEIGHT) + kl_cost_2 *= float(LATENT_DIM_2) / (N_CHANNELS * WIDTH * HEIGHT) + + cost = reconst_cost + (alpha1 * kl_cost_1) + (alpha2 * kl_cost_2) + + tower_cost.append(cost) + if MODE == 'two_level': + tower_outputs1_sample.append(outputs1_sample) + + full_cost = tf.reduce_mean( + tf.concat(0, [tf.expand_dims(x, 0) for x in tower_cost]), 0 + ) + + if MODE == 'two_level': + full_outputs1_sample = tf.concat(0, tower_outputs1_sample) + + # Sampling + + if MODE == 'one_level': + + ch_sym = tf.placeholder(tf.int32, shape=None) + y_sym = tf.placeholder(tf.int32, shape=None) + x_sym = tf.placeholder(tf.int32, shape=None) + logits = tf.reshape(tf.slice(outputs1, tf.pack([0, ch_sym, y_sym, x_sym, 0]), tf.pack([-1, 1, 1, 1, -1])), [-1, 256]) + dec1_fn_out = tf.multinomial(logits, 1)[:, 0] + def 
dec1_fn(_latents, _targets, _ch, _y, _x): + return session.run(dec1_fn_out, feed_dict={latents1: _latents, images: _targets, ch_sym: _ch, y_sym: _y, x_sym: _x, total_iters: 99999, bn_is_training: False, bn_stats_iter:0}) + + def enc_fn(_images): + return session.run(latents1, feed_dict={images: _images, total_iters: 99999, bn_is_training: False, bn_stats_iter:0}) + + sample_fn_latents1 = np.random.normal(size=(8, LATENT_DIM_2)).astype('float32') + + def generate_and_save_samples(tag): + def color_grid_vis(X, nh, nw, save_path): + # from github.com/Newmu + X = X.transpose(0,2,3,1) + h, w = X[0].shape[:2] + img = np.zeros((h*nh, w*nw, 3)) + for n, x in enumerate(X): + j = n/nw + i = n%nw + img[j*h:j*h+h, i*w:i*w+w, :] = x + imsave(save_path, img) + + latents1_copied = np.zeros((64, LATENT_DIM_2), dtype='float32') + for i in xrange(8): + latents1_copied[i::8] = sample_fn_latents1 + + samples = np.zeros( + (64, N_CHANNELS, HEIGHT, WIDTH), + dtype='int32' + ) + + print "Generating samples" + for y in xrange(HEIGHT): + for x in xrange(WIDTH): + for ch in xrange(N_CHANNELS): + next_sample = dec1_fn(latents1_copied, samples, ch, y, x) + samples[:,ch,y,x] = next_sample + + print "Saving samples" + color_grid_vis( + samples, + 8, + 8, + 'samples_{}.png'.format(tag) + ) + + + elif MODE == 'two_level': + + def dec2_fn(_latents, _targets): + return session.run([mu1_prior, logsig1_prior], feed_dict={latents2: _latents, latents1: _targets, total_iters: 99999, bn_is_training: False, bn_stats_iter: 0}) + + ch_sym = tf.placeholder(tf.int32, shape=None) + y_sym = tf.placeholder(tf.int32, shape=None) + x_sym = tf.placeholder(tf.int32, shape=None) + logits_sym = tf.reshape(tf.slice(full_outputs1_sample, tf.pack([0, ch_sym, y_sym, x_sym, 0]), tf.pack([-1, 1, 1, 1, -1])), [-1, 256]) + + def dec1_logits_fn(_latents, _targets, _ch, _y, _x): + return session.run(logits_sym, + feed_dict={all_latents1: _latents, + all_images: _targets, + ch_sym: _ch, + y_sym: _y, + x_sym: _x, + total_iters: 
99999, + bn_is_training: False, + bn_stats_iter: 0}) + + N_SAMPLES = BATCH_SIZE + if N_SAMPLES % N_GPUS != 0: + raise Exception("N_SAMPLES must be divisible by N_GPUS") + HOLD_Z2_CONSTANT = False + HOLD_EPSILON_1_CONSTANT = False + HOLD_EPSILON_PIXELS_CONSTANT = False + + # Draw z2 from N(0,I) + z2 = np.random.normal(size=(N_SAMPLES, LATENT_DIM_2)).astype('float32') + if HOLD_Z2_CONSTANT: + z2[:] = z2[0][None] + + # Draw epsilon_1 from N(0,I) + epsilon_1 = np.random.normal(size=(N_SAMPLES, LATENT_DIM_1, LATENTS1_HEIGHT, LATENTS1_WIDTH)).astype('float32') + if HOLD_EPSILON_1_CONSTANT: + epsilon_1[:] = epsilon_1[0][None] + + # Draw epsilon_pixels from U[0,1] + epsilon_pixels = np.random.uniform(size=(N_SAMPLES, N_CHANNELS, HEIGHT, WIDTH)) + if HOLD_EPSILON_PIXELS_CONSTANT: + epsilon_pixels[:] = epsilon_pixels[0][None] + + + def generate_and_save_samples(tag): + # Draw z1 autoregressively using z2 and epsilon1 + print "Generating z1" + z1 = np.zeros((N_SAMPLES, LATENT_DIM_1, LATENTS1_HEIGHT, LATENTS1_WIDTH), dtype='float32') + for y in xrange(LATENTS1_HEIGHT): + for x in xrange(LATENTS1_WIDTH): + z1_prior_mu, z1_prior_logsig = dec2_fn(z2, z1) + z1[:,:,y,x] = z1_prior_mu[:,:,y,x] + np.exp(z1_prior_logsig[:,:,y,x]) * epsilon_1[:,:,y,x] + + # Draw pixels (the images) autoregressively using z1 and epsilon_x + print "Generating pixels" + pixels = np.zeros((N_SAMPLES, N_CHANNELS, HEIGHT, WIDTH)).astype('int32') + for y in xrange(HEIGHT): + for x in xrange(WIDTH): + for ch in xrange(N_CHANNELS): + # start_time = time.time() + logits = dec1_logits_fn(z1, pixels, ch, y, x) + probs = np.exp(logits - np.max(logits, axis=-1, keepdims=True)) + probs = probs / np.sum(probs, axis=-1, keepdims=True) + cdf = np.cumsum(probs, axis=-1) + pixels[:,ch,y,x] = np.argmax(cdf >= epsilon_pixels[:,ch,y,x,None], axis=-1) + # print time.time() - start_time + + # Save them + def color_grid_vis(X, nh, nw, save_path): + # from github.com/Newmu + X = X.transpose(0,2,3,1) + h, w = X[0].shape[:2] + img 
"""Print per-interval means of selected metrics from train_output.ndjson."""
import json
import numpy as np

INTERVAL = 1000
LABELS = ['train cost', 'train kl1', 'train kl2']

# Read the log once and collect every tracked metric, instead of re-reading
# the whole file for each label.
vals_by_label = dict((label, []) for label in LABELS)

with open('train_output.ndjson') as f:
    for line in f:
        # strip() rather than line[:-1]: the last line may lack a trailing
        # newline, in which case [:-1] would cut off the closing brace.
        line = line.strip()
        if not line:
            continue
        record = json.loads(line)
        for label in LABELS:
            if label in record:
                vals_by_label[label].append(record[label])

for label in LABELS:
    print("==============================")
    print(label)

    vals = vals_by_label[label]
    for i in range(0, len(vals), INTERVAL):
        print("{}-{}\t{}".format(i, min(len(vals), i+INTERVAL), np.mean(vals[i:i+INTERVAL])))
def param(name, *args, **kwargs):
    """
    A wrapper for `tf.Variable` which enables parameter sharing in models.

    The first call with a given `name` builds a `tf.Variable` from the
    remaining arguments (with `name=name` forced), marks it with a
    `param = True` attribute and stores it in the module-level registry.
    Every later call with the same `name` ignores its other arguments and
    returns the variable created the first time.
    """
    try:
        return _params[name]
    except KeyError:
        pass

    kwargs['name'] = name
    variable = tf.Variable(*args, **kwargs)
    variable.param = True
    _params[name] = variable
    return variable
def print_model_settings(locals_):
    """
    Print every all-uppercase entry of `locals_`, sorted by name.

    locals_: a name -> value mapping (typically `locals().copy()`).

    The key 'T' is skipped (presumably a module alias rather than a
    setting). Single-argument parenthesized `print` is used so the output is
    the same on Python 2 and the function also parses on Python 3.
    """
    print("Model settings:")
    settings = sorted(
        ((k, v) for (k, v) in locals_.items() if k.isupper() and k != 'T'),
        key=lambda item: item[0]
    )
    for var_name, var_value in settings:
        print("\t{}: {}".format(var_name, var_value))
fuel.transformers.image import RandomFixedSizeCrop + +PATH = '/home/ishaan/data/lsun_bedrooms_2727000_64px.hdf5' + +from scipy.misc import imsave +def color_grid_vis(X, nh, nw, save_path): + # from github.com/Newmu + X = X.transpose(0,2,3,1) + h, w = X[0].shape[:2] + img = np.zeros((h*nh, w*nw, 3)) + for n, x in enumerate(X): + j = n/nw + i = n%nw + img[j*h:j*h+h, i*w:i*w+w, :] = x + imsave(save_path, img) + + +def _make_stream(stream, bs, downsample): + def new_stream(): + if downsample: + result = np.empty((bs, 32, 32, 3), dtype='int32') + else: + result = np.empty((bs, 64, 64, 3), dtype='int32') + for (imb,) in stream.get_epoch_iterator(): + for i, img in enumerate(imb): + if downsample: + a = img[:64:2, :64:2, :] + b = img[:64:2, 1:64:2, :] + c = img[1:64:2, :64:2, :] + d = img[1:64:2, 1:64:2, :] + result[i] = a + result[i] += b + result[i] += c + result[i] += d + result[i] /= 4 + # print (a+b+c+d).dtype + # raise Exception() + # result[i] = (a+b+c+d)/4 + else: + result[i] = img[:64, :64, :] + # print "warning overfit mode" + # color_grid_vis(result.transpose(0,3,1,2)[:,:3,:,:], 2, 2, 'reals.png') + # while True: + yield (result.transpose(0,3,1,2),) + # yield (result.transpose(0,3,1,2)[:,:3,:,:],) + return new_stream + +def load(batch_size=128, downsample=True): + tr_data = H5PYDataset(PATH, which_sets=('train',)) + te_data = H5PYDataset(PATH, which_sets=('valid',)) + + ntrain = tr_data.num_examples + # ntest = te_data.num_examples + nval = te_data.num_examples + + # print "ntrain {}, nval {}".format(ntrain, nval) + + tr_scheme = ShuffledScheme(examples=ntrain, batch_size=batch_size) + tr_stream = DataStream(tr_data, iteration_scheme=tr_scheme) + + # te_scheme = SequentialScheme(examples=ntest, batch_size=batch_size) + # te_stream = DataStream(te_data, iteration_scheme=te_scheme) + + val_scheme = SequentialScheme(examples=nval, batch_size=batch_size) + val_stream = DataStream(tr_data, iteration_scheme=val_scheme) + + return _make_stream(tr_stream, batch_size, 
def mnist_generator(data, batch_size, n_labelled):
    """
    Build an epoch generator over an (images, targets) pair.

    data: tuple `(images, targets)`; images are reshaped to rows of 784
        values, so the number of examples must be a multiple of
        `batch_size` (otherwise the reshape raises).
    batch_size: number of examples per yielded batch.
    n_labelled: if not None, the first `n_labelled` examples (before any
        shuffling) are flagged as labelled and a per-example 0/1 flag array
        is yielded with every batch.

    returns: `get_epoch`, a function that reshuffles the data and returns a
        generator of `(images, targets)` or `(images, targets, labelled)`
        batches. Every yielded array is a copy.
    """
    images, targets = data

    # astype() copies, so the in-place shuffles below never touch `data`.
    images = images.astype('float32')
    targets = targets.astype('int32')
    if n_labelled is not None:
        labelled = numpy.zeros(len(images), dtype='int32')
        labelled[:n_labelled] = 1

    def get_epoch():
        # Replaying the same RNG state before each shuffle applies the same
        # permutation to all arrays, keeping them aligned.
        rng_state = numpy.random.get_state()
        numpy.random.shuffle(images)
        numpy.random.set_state(rng_state)
        numpy.random.shuffle(targets)

        if n_labelled is not None:
            numpy.random.set_state(rng_state)
            numpy.random.shuffle(labelled)

        image_batches = images.reshape(-1, batch_size, 784)
        target_batches = targets.reshape(-1, batch_size)

        if n_labelled is not None:
            labelled_batches = labelled.reshape(-1, batch_size)

            # Yield the flags of THIS batch. The previous code yielded the
            # whole `labelled` array with every batch and never used
            # `labelled_batches`.
            for image_batch, target_batch, labelled_batch in zip(image_batches, target_batches, labelled_batches):
                yield (numpy.copy(image_batch), numpy.copy(target_batch), numpy.copy(labelled_batch))

        else:

            for image_batch, target_batch in zip(image_batches, target_batches):
                yield (numpy.copy(image_batch), numpy.copy(target_batch))

    return get_epoch
+ urllib.urlretrieve(url, filepath) + + with gzip.open('/tmp/mnist.pkl.gz', 'rb') as f: + train_data, dev_data, test_data = pickle.load(f) + + return ( + mnist_generator(train_data, batch_size, n_labelled), + mnist_generator(dev_data, test_batch_size, n_labelled), + mnist_generator(test_data, test_batch_size, n_labelled) + ) \ No newline at end of file diff --git a/tflib/mnist_256.py b/tflib/mnist_256.py new file mode 100644 index 0000000..5e5146e --- /dev/null +++ b/tflib/mnist_256.py @@ -0,0 +1,32 @@ +import tflib.mnist + +import numpy as np + +def discretize(x): + return (x*(256-1e-8)).astype('int32') + +def binarized_generator(generator, include_targets=False, n_labelled=None): + def get_epoch(): + for data in generator(): + if n_labelled is not None: + images, targets, labelled = data + else: + images, targets = data + images = images.reshape((-1, 1, 28, 28)) + images = discretize(images) + if include_targets: + if n_labelled is not None: + yield (images, targets, labelled) + else: + yield (images, targets) + else: + yield (images,) + return get_epoch + +def load(batch_size, test_batch_size, include_targets=False, n_labelled=None): + train_gen, dev_gen, test_gen = tflib.mnist.load(batch_size, test_batch_size, n_labelled) + return ( + binarized_generator(train_gen, include_targets=include_targets, n_labelled=n_labelled), + binarized_generator(dev_gen, include_targets=include_targets, n_labelled=n_labelled), + binarized_generator(test_gen, include_targets=include_targets, n_labelled=n_labelled) + ) \ No newline at end of file diff --git a/tflib/ops/__init__.py b/tflib/ops/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tflib/ops/batchnorm.py b/tflib/ops/batchnorm.py new file mode 100644 index 0000000..c90d76d --- /dev/null +++ b/tflib/ops/batchnorm.py @@ -0,0 +1,78 @@ +import tflib as lib + +import numpy as np +import tensorflow as tf + +def Batchnorm(name, axes, inputs, is_training=None, stats_iter=None, update_moving_stats=True): + if axes 
== [0,2,3]: + # Old (working but pretty slow) implementation: + ########## + + # inputs = tf.transpose(inputs, [0,2,3,1]) + + # mean, var = tf.nn.moments(inputs, [0,1,2], keep_dims=False) + # offset = lib.param(name+'.offset', np.zeros(mean.get_shape()[-1], dtype='float32')) + # scale = lib.param(name+'.scale', np.ones(var.get_shape()[-1], dtype='float32')) + # result = tf.nn.batch_normalization(inputs, mean, var, offset, scale, 1e-4) + + # return tf.transpose(result, [0,3,1,2]) + + # New (super fast but untested) implementation: + offset = lib.param(name+'.offset', np.zeros(inputs.get_shape()[1], dtype='float32')) + scale = lib.param(name+'.scale', np.ones(inputs.get_shape()[1], dtype='float32')) + + moving_mean = lib.param(name+'.moving_mean', np.zeros(inputs.get_shape()[1], dtype='float32'), trainable=False) + moving_variance = lib.param(name+'.moving_variance', np.ones(inputs.get_shape()[1], dtype='float32'), trainable=False) + + def _fused_batch_norm_training(): + return tf.nn.fused_batch_norm(inputs, scale, offset, epsilon=1e-2, data_format='NCHW') + def _fused_batch_norm_inference(): + # Version which blends in the current item's statistics + batch_size = tf.cast(tf.shape(inputs)[0], 'float32') + mean, var = tf.nn.moments(inputs, [2,3], keep_dims=True) + mean = ((1./batch_size)*mean) + (((batch_size-1.)/batch_size)*moving_mean)[None,:,None,None] + var = ((1./batch_size)*var) + (((batch_size-1.)/batch_size)*moving_variance)[None,:,None,None] + return tf.nn.batch_normalization(inputs, mean, var, offset[None,:,None,None], scale[None,:,None,None], 1e-2), mean, var + + # Standard version + # return tf.nn.fused_batch_norm( + # inputs, + # scale, + # offset, + # epsilon=1e-2, + # mean=moving_mean, + # variance=moving_variance, + # is_training=False, + # data_format='NCHW' + # ) + + if is_training is None: + raise Exception('no is_training') + outputs, batch_mean, batch_var = _fused_batch_norm_training() + else: + outputs, batch_mean, batch_var = 
_default_weightnorm = False
def enable_default_weightnorm():
    """Make `Conv2D` calls that pass `weightnorm=None` use weight normalization."""
    global _default_weightnorm
    _default_weightnorm = True

def Conv2D(name, input_dim, output_dim, filter_size, inputs, he_init=True, mask_type=None, stride=1, weightnorm=None, biases=True, gain=1.):
    """
    2D convolution with 'SAME' padding on NCHW inputs, with optional filter
    masking and weight normalization.

    name: prefix for the parameters (`.Filters`, `.g`, `.Biases`) and name scope
    input_dim, output_dim: number of input / output channels
    filter_size: height and width of the (square) filter
    inputs: tensor of shape (batch size, num channels, height, width)
    he_init: if True the init variance is 4/(fan_in+fan_out), else
        2/(fan_in+fan_out) (normalized init, Glorot & Bengio)
    mask_type: None, or a pair `(type, n_channels)` where `type` is 'a' or 'b'.
        Filter taps below the center row and to the right of the center in
        the center row are zeroed. At the center tap, with channels split
        into `n_channels` interleaved groups, input group i -> output group j
        is zeroed when i >= j ('a') or i > j ('b').
    stride: spatial stride, applied to both height and width
    weightnorm: True / False, or None to use the module-wide default
    biases: if True, add a per-output-channel bias
    gain: multiplier applied to the initial filter values

    returns: tensor of shape (batch size, num channels, height, width)
    """
    with tf.name_scope(name):
        # filter shape is (height, width, input channels, output channels)
        filter_shape = (filter_size, filter_size, input_dim, output_dim)

        if mask_type is not None:
            mask_kind, mask_n_channels = mask_type

            mask = np.ones(filter_shape, dtype='float32')
            center = filter_size // 2

            # Mask out future locations
            mask[center+1:, :, :, :] = 0.
            mask[center, center+1:, :, :] = 0.

            # Mask out future channels
            for i in range(mask_n_channels):
                for j in range(mask_n_channels):
                    if (mask_kind == 'a' and i >= j) or (mask_kind == 'b' and i > j):
                        mask[
                            center,
                            center,
                            i::mask_n_channels,
                            j::mask_n_channels
                        ] = 0.

        fan_in = input_dim * filter_size**2
        # float() keeps this a true division under Python 2 as well
        fan_out = output_dim * filter_size**2 / float(stride**2)

        if mask_type is not None: # only approximately correct
            fan_in /= 2.
            fan_out /= 2.

        if he_init:
            filters_stdev = np.sqrt(4./(fan_in+fan_out))
        else: # Normalized init (Glorot & Bengio)
            filters_stdev = np.sqrt(2./(fan_in+fan_out))

        # A uniform distribution on [-s*sqrt(3), s*sqrt(3)] has standard deviation s.
        limit = filters_stdev * np.sqrt(3)
        filter_values = np.random.uniform(
            low=-limit,
            high=limit,
            size=filter_shape
        ).astype('float32')
        filter_values *= gain

        filters = lib.param(name+'.Filters', filter_values)

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            # One norm per output channel, initialised to the norm of the
            # initial filters so the layer starts out unchanged.
            norm_values = np.sqrt(np.sum(np.square(filter_values), axis=(0,1,2)))
            target_norms = lib.param(
                name + '.g',
                norm_values
            )
            with tf.name_scope('weightnorm'):
                norms = tf.sqrt(tf.reduce_sum(tf.square(filters), reduction_indices=[0,1,2]))
                filters = filters * (target_norms / norms)

        if mask_type is not None:
            with tf.name_scope('filter_mask'):
                filters = filters * mask

        result = tf.nn.conv2d(
            input=inputs,
            filter=filters,
            strides=[1, 1, stride, stride],
            padding='SAME',
            data_format='NCHW'
        )

        if biases:
            _biases = lib.param(
                name+'.Biases',
                np.zeros(output_dim, dtype='float32')
            )

            result = tf.nn.bias_add(result, _biases, data_format='NCHW')

        return result
_default_weightnorm = False
def enable_default_weightnorm():
    """Make `Linear` calls that pass `weightnorm=None` use weight normalization."""
    global _default_weightnorm
    _default_weightnorm = True

def disable_default_weightnorm():
    """Make `Linear` calls that pass `weightnorm=None` skip weight normalization."""
    global _default_weightnorm
    _default_weightnorm = False

def Linear(
        name,
        input_dim,
        output_dim,
        inputs,
        biases=True,
        initialization=None,
        weightnorm=None,
        gain=1.
        ):
    """
    Affine layer: `inputs . W (+ b)`, applied over the last axis of `inputs`.

    name: prefix for the parameters (`.W`, `.g`, `.b`) and name scope
    input_dim, output_dim: size of the last axis before / after the layer
    inputs: tensor whose last axis has size `input_dim`; inputs that are not
        rank 2 are flattened to 2D for the matmul and reshaped back
    biases: if True, add a bias vector of size `output_dim`
    initialization: one of
        None or `lecun`      uniform with variance 1/input_dim
        `glorot`             uniform with variance 2/(input_dim+output_dim)
        `he`                 uniform with variance 2/input_dim
        `glorot_he`          uniform with variance 4/(input_dim+output_dim)
        `orthogonal`         orthogonal matrix from the SVD of a Gaussian draw
        `("uniform", range)` uniform on [-range, range]
    weightnorm: True / False, or None to use the module-wide default
    gain: multiplier applied to the initial weight values

    raises: Exception('Invalid initialization!') for any other `initialization`
    """
    with tf.name_scope(name):

        def uniform(stdev, size):
            # A uniform distribution on [-s*sqrt(3), s*sqrt(3)] has standard deviation s.
            return np.random.uniform(
                low=-stdev * np.sqrt(3),
                high=stdev * np.sqrt(3),
                size=size
            ).astype('float32')

        weight_shape = (input_dim, output_dim)

        # None always means lecun: orthogonal init is never picked
        # implicitly because it is too slow.
        if initialization is None or initialization == 'lecun':
            weight_values = uniform(np.sqrt(1./input_dim), weight_shape)

        elif initialization == 'glorot':
            weight_values = uniform(np.sqrt(2./(input_dim+output_dim)), weight_shape)

        elif initialization == 'he':
            weight_values = uniform(np.sqrt(2./input_dim), weight_shape)

        elif initialization == 'glorot_he':
            weight_values = uniform(np.sqrt(4./(input_dim+output_dim)), weight_shape)

        elif initialization == 'orthogonal':

            # From lasagne
            def sample(shape):
                if len(shape) < 2:
                    raise RuntimeError("Only shapes of length 2 or more are "
                                       "supported.")
                flat_shape = (shape[0], np.prod(shape[1:]))
                # TODO: why normal and not uniform?
                a = np.random.normal(0.0, 1.0, flat_shape)
                u, _, v = np.linalg.svd(a, full_matrices=False)
                # pick the one with the correct shape
                q = u if u.shape == flat_shape else v
                q = q.reshape(shape)
                return q.astype('float32')
            weight_values = sample(weight_shape)

        elif (isinstance(initialization, (tuple, list))
              and len(initialization) >= 2
              and initialization[0] == 'uniform'):

            weight_values = np.random.uniform(
                low=-initialization[1],
                high=initialization[1],
                size=weight_shape
            ).astype('float32')

        else:
            # Also reached for values that cannot be indexed (e.g. numbers),
            # which used to escape as TypeError / IndexError.
            raise Exception('Invalid initialization!')

        weight_values *= gain

        weight = lib.param(
            name + '.W',
            weight_values
        )

        if weightnorm is None:
            weightnorm = _default_weightnorm
        if weightnorm:
            # One norm per output unit, initialised to the norm of the initial
            # weights so the layer starts out unchanged.
            norm_values = np.sqrt(np.sum(np.square(weight_values), axis=0))

            target_norms = lib.param(
                name + '.g',
                norm_values
            )

            with tf.name_scope('weightnorm'):
                norms = tf.sqrt(tf.reduce_sum(tf.square(weight), reduction_indices=[0]))
                weight = weight * (target_norms / norms)

        if inputs.get_shape().ndims == 2:
            result = tf.matmul(inputs, weight)
        else:
            # Collapse all leading axes, multiply, then restore them with the
            # last axis replaced by output_dim.
            reshaped_inputs = tf.reshape(inputs, [-1, input_dim])
            result = tf.matmul(reshaped_inputs, weight)
            result = tf.reshape(result, tf.pack(tf.unpack(tf.shape(inputs))[:-1] + [output_dim]))

        if biases:
            result = tf.nn.bias_add(
                result,
                lib.param(
                    name + '.b',
                    np.zeros((output_dim,), dtype='float32')
                )
            )

        return result
PARAMS_FILE = 'params.ckpt'
TRAIN_LOOP_FILE = 'train_loop.pkl'
TRAIN_OUTPUT_FILE = 'train_output.ndjson'

def train_loop(
    session,
    inputs,
    cost,
    train_data,
    stop_after,
    prints=[],
    test_data=None,
    test_every=None,
    callback=None,
    callback_every=None,
    inject_iteration=False,
    bn_vars=None,
    bn_stats_iters=1000,
    before_test=None,
    optimizer=tf.train.AdamOptimizer(),
    save_every=1000,
    save_checkpoints=False
    ):
    """
    Minimize `cost` with gradient-norm clipping, logging every iteration and
    periodically testing, calling back and saving.

    If TRAIN_LOOP_FILE exists in the working directory, the loop state is
    unpickled from it, the parameters are restored from PARAMS_FILE and the
    training generator is advanced to the saved iteration. Otherwise all
    variables are initialized.

    session: session used for every `run` call
    inputs: symbolic inputs, fed position by position from each batch
    cost: scalar to minimize; always logged as 'cost'
    train_data: zero-argument callable returning a fresh iterator over the
        batches of one epoch
    stop_after: iteration count at which to save and return
    prints: extra `(name, tensor)` pairs to evaluate and log (never mutated)
    test_data: like `train_data`, for the test set; None disables testing
    test_every: test when `iteration % test_every == test_every - 1`
    callback: called with the tag "iter<N>"; None disables it
    callback_every: same schedule as `test_every`, for `callback`
    inject_iteration: if True, `np.int32(iteration)` is prepended to every
        batch before it is fed
    bn_vars: None, or an indexable whose element 0 is fed True (training and
        statistics passes) / False (evaluation) and whose element 1 is fed the
        statistics-pass index (0 elsewhere)
        -- presumably the batchnorm is_training / stats_iter inputs; confirm
    bn_stats_iters: number of training batches run before each test/callback
        to refresh the batchnorm statistics (only when `bn_vars` is given)
    before_test: unused
    optimizer: optimizer providing `compute_gradients` / `apply_gradients`
    save_every: flush logs (and checkpoint) when
        `iteration % save_every == save_every - 1`
    save_checkpoints: if True, also save parameters and loop state when flushing
    """

    prints = [('cost', cost)] + prints

    grads_and_vars = optimizer.compute_gradients(
        cost,
        colocate_gradients_with_ops=True
    )

    print("Params:")
    total_param_count = 0
    for g, v in grads_and_vars:
        shape = v.get_shape()
        shape_str = ",".join([str(x) for x in shape])

        param_count = 1
        for dim in shape:
            param_count *= int(dim)
        total_param_count += param_count

        # Identity test: `==` may be overloaded on tensors.
        if g is None:
            print("\t{} ({}) [no grad!]".format(v.name, shape_str))
        else:
            print("\t{} ({})".format(v.name, shape_str))
    print("Total param count: {}".format(
        locale.format("%d", total_param_count, grouping=True)
    ))

    grads = [g for g, v in grads_and_vars]
    variables = [v for g, v in grads_and_vars]

    global_norm = tf.global_norm(grads)
    prints = prints + [('gradnorm', global_norm)]

    # Clip to a global norm of 5, reusing the norm computed above.
    grads, global_norm = tf.clip_by_global_norm(grads, 5.0, use_norm=global_norm)
    train_op = optimizer.apply_gradients(list(zip(grads, variables)))

    print_ops = [p[1] for p in prints]

    def train_fn(input_vals):
        feed_dict = {sym: real for sym, real in zip(inputs, input_vals)}
        if bn_vars is not None:
            feed_dict[bn_vars[0]] = True
            feed_dict[bn_vars[1]] = 0
        # Drop the train_op result; only the logged values are returned.
        return session.run(
            print_ops + [train_op],
            feed_dict=feed_dict
        )[:-1]

    def bn_stats_fn(input_vals, iter_):
        feed_dict = {sym: real for sym, real in zip(inputs, input_vals)}
        feed_dict[bn_vars[0]] = True
        feed_dict[bn_vars[1]] = iter_
        return session.run(print_ops, feed_dict=feed_dict)

    def eval_fn(input_vals):
        feed_dict = {sym: real for sym, real in zip(inputs, input_vals)}
        if bn_vars is not None:
            feed_dict[bn_vars[0]] = False
            feed_dict[bn_vars[1]] = 0
        return session.run(print_ops, feed_dict=feed_dict)

    # Loop state; this dict is what gets pickled to TRAIN_LOOP_FILE.
    _vars = {
        'epoch': 0,
        'iteration': 0,
        'seconds': 0.,
        'last_callback': 0,
        'last_test': 0
    }

    train_generator = train_data()

    saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)

    if os.path.isfile(TRAIN_LOOP_FILE):
        print("Resuming interrupted train loop session")
        # Binary mode: pickle data is bytes, not text.
        with open(TRAIN_LOOP_FILE, 'rb') as f:
            _vars = pickle.load(f)
        saver.restore(session, os.getcwd()+"/"+PARAMS_FILE)

        print("Fast-fowarding dataset generator")
        dataset_iters = 0
        while dataset_iters < _vars['iteration']:
            try:
                next(train_generator)
            except StopIteration:
                train_generator = train_data()
                next(train_generator)
            dataset_iters += 1
    else:
        print("Initializing variables...")
        session.run(tf.initialize_all_variables())
        print("done!")

    # One-element list so the nested functions can rebind the buffer.
    train_output_entries = [[]]

    def with_iteration(batch):
        """Return `batch` as a list, prefixed by the iteration when injecting."""
        if inject_iteration:
            return [np.int32(_vars['iteration'])] + list(batch)
        return list(batch)

    def log(outputs, test, _vars, extra_things_to_print):
        entry = collections.OrderedDict()
        for key in ['epoch', 'iteration', 'seconds']:
            entry[key] = _vars[key]
        prefix = 'test ' if test else 'train '
        for i, p in enumerate(prints):
            entry[prefix+p[0]] = outputs[i]

        train_output_entries[0].append(entry)

        to_print = list(entry.items())
        to_print.extend(extra_things_to_print)
        print_str = ""
        for k, v in to_print:
            if isinstance(v, int):
                print_str += "{}:{}\t".format(k, v)
            else:
                print_str += "{}:{:.4f}\t".format(k, v)
        print(print_str[:-1]) # omit the last \t

    def save_train_output_and_params(iteration):
        print("Saving things...")

        # NOTE(review): the loop-state dump is assumed to sit inside this
        # `if` with the checkpoint (the original indentation was lost);
        # resuming needs both files, so they are written together. Confirm.
        if save_checkpoints:
            # Saving weights takes a while. There's a risk of interruption during
            # this time, leaving the weights file corrupt. Oh well.

            start_time = time.time()
            saver.save(session, PARAMS_FILE)
            print("saver.save time: {}".format(time.time() - start_time))

            start_time = time.time()
            with open(TRAIN_LOOP_FILE, 'wb') as f:
                pickle.dump(_vars, f)
            print("_vars pickle dump time: {}".format(time.time() - start_time))

        start_time = time.time()
        with open(TRAIN_OUTPUT_FILE, 'a') as f:
            for entry in train_output_entries[0]:
                for k, v in entry.items():
                    # numpy scalars are not JSON serializable
                    if isinstance(v, np.generic):
                        entry[k] = v.item()
                f.write(json.dumps(entry) + "\n")
        print("ndjson write time: {}".format(time.time() - start_time))

        train_output_entries[0] = []

    while True:

        if _vars['iteration'] == stop_after:
            save_train_output_and_params(_vars['iteration'])

            print("Done!")

            try: # This only matters on Ishaan's computer
                import experiment_tools
                experiment_tools.send_sms("done!")
            except ImportError:
                pass

            break

        data_load_start_time = time.time()
        try:
            input_vals = next(train_generator)
        except StopIteration:
            # New epoch: restart and take its first batch. (An extra next()
            # here used to throw away one batch per epoch.)
            train_generator = train_data()
            input_vals = next(train_generator)
            _vars['epoch'] += 1
        data_load_time = time.time() - data_load_start_time

        if inject_iteration:
            input_vals = with_iteration(input_vals)

        start_time = time.time()
        outputs = train_fn(input_vals)
        run_time = time.time() - start_time

        _vars['seconds'] += run_time
        _vars['iteration'] += 1

        log(outputs, False, _vars, [('iter time', run_time), ('data time', data_load_time)])

        iteration = _vars['iteration']
        run_test = (test_data is not None) and iteration % test_every == (test_every-1)
        run_callback = (callback is not None) and iteration % callback_every == (callback_every-1)

        if (run_test or run_callback) and bn_vars is not None:
            # If using batchnorm, run over a bunch of training data first to
            # make the running-average stats good.
            _train_gen = train_data()
            for i in range(bn_stats_iters):
                try:
                    stats_batch = next(_train_gen)
                except StopIteration:
                    _train_gen = train_data()
                    stats_batch = next(_train_gen)
                bn_stats_fn(with_iteration(stats_batch), i)

        if run_test:
            test_outputs = [
                eval_fn(with_iteration(test_vals))
                for test_vals in test_data()
            ]
            mean_test_outputs = np.array(test_outputs).mean(axis=0)

            log(mean_test_outputs, True, _vars, [])

        if run_callback:
            tag = "iter{}".format(iteration)
            callback(tag)

        if iteration % save_every == (save_every-1):
            save_train_output_and_params(iteration)