diff --git a/pixelvae.py b/pixelvae.py
index cc5de1e..0511b27 100644
--- a/pixelvae.py
+++ b/pixelvae.py
@@ -6,7 +6,7 @@
 import os, sys
 sys.path.append(os.getcwd())

-N_GPUS = 1
+N_GPUS = 4

 try: # This only matters on Ishaan's computer
     import experiment_tools
@@ -86,7 +86,7 @@
     HEIGHT = 28
     WIDTH = 28

-    # These aren't actually (typically) used for one-level models but some parts
+    # These aren't actually used for one-level models but some parts
     # of the code still depend on them being defined.
     LATENT_DIM_1 = 64
     LATENTS1_HEIGHT = 7
@@ -271,6 +271,63 @@
     LATENTS1_WIDTH = 16
     LATENTS1_HEIGHT = 16

+elif SETTINGS=='64px_big_onelevel':
+
+    # two_level uses Enc1/Dec1 for the bottom level, Enc2/Dec2 for the top level
+    # one_level uses EncFull/DecFull for the bottom (and only) level
+    MODE = 'one_level'
+
+    # Whether to treat pixel inputs to the model as real-valued (as in the
+    # original PixelCNN) or discrete (gets better likelihoods).
+    EMBED_INPUTS = True
+
+    # Turn on/off the bottom-level PixelCNN in Dec1/DecFull
+    PIXEL_LEVEL_PIXCNN = True
+    HIGHER_LEVEL_PIXCNN = True
+
+    DIM_EMBED = 16
+    DIM_PIX_1 = 384
+    DIM_0 = 192
+    DIM_1 = 256
+    DIM_2 = 512
+    DIM_3 = 512
+    DIM_4 = 512
+    LATENT_DIM_2 = 512
+
+    ALPHA1_ITERS = 50000
+    ALPHA2_ITERS = 50000
+    KL_PENALTY = 1.0
+    BETA_ITERS = 1000
+
+    # In Dec2, we break each spatial location into N blocks (analogous to channels
+    # in the original PixelCNN) and model each spatial location autoregressively
+    # as P(x)=P(x0)*P(x1|x0)*P(x2|x0,x1)... In my experiments values of N > 1
+    # actually hurt performance. Unsure why; might be a bug.
+    PIX_2_N_BLOCKS = 1
+
+    TIMES = {
+        'test_every': 10000,
+        'stop_after': 400000,
+        'callback_every': 50000
+    }
+    LR = 1e-3
+
+    LR_DECAY_AFTER = 180000
+    LR_DECAY_FACTOR = 0.5
+
+    BATCH_SIZE = 48
+    N_CHANNELS = 3
+    HEIGHT = 64
+    WIDTH = 64
+
+    # These aren't actually used for one-level models but some parts
+    # of the code still depend on them being defined.
+    LATENT_DIM_1 = 64
+    LATENTS1_HEIGHT = 7
+    LATENTS1_WIDTH = 7
+
+
+
 if DATASET == 'mnist_256':
     train_data, dev_data, test_data = lib.mnist_256.load(BATCH_SIZE, BATCH_SIZE)
 elif DATASET == 'lsun_32':
@@ -294,8 +351,8 @@
     all_images = tf.placeholder(tf.int32, shape=[None, N_CHANNELS, HEIGHT, WIDTH], name='all_images')
     all_latents1 = tf.placeholder(tf.float32, shape=[None, LATENT_DIM_1, LATENTS1_HEIGHT, LATENTS1_WIDTH], name='all_latents1')

-    split_images = tf.split(0, len(DEVICES), all_images)
-    split_latents1 = tf.split(0, len(DEVICES), all_latents1)
+    split_images = tf.split(all_images, len(DEVICES), axis=0)
+    split_latents1 = tf.split(all_latents1, len(DEVICES), axis=0)

     tower_cost = []
     tower_outputs1_sample = []
@@ -442,7 +499,7 @@ def Dec1(latents, images):
                     output /= 2

                     # Warning! Because of the masked convolutions it's very important that masked_images comes first in this concat
-                    output = tf.concat(1, [masked_images, output])
+                    output = tf.concat([masked_images, output], axis=1)

                     if WIDTH == 64:
                         output = ResidualBlock('Dec1.Pix2Res', input_dim=2*DIM_0, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
@@ -518,7 +575,7 @@ def Dec2(latents, targets):
                     # Make the variance of output and masked_targets roughly match
                     output /= 2

-                    output = tf.concat(1, [masked_targets, output])
+                    output = tf.concat([masked_targets, output], axis=1)

                     if LATENTS1_WIDTH == 16:
                         output = ResidualBlock('Dec2.Pix2Res', input_dim=2*DIM_2, output_dim=DIM_PIX_2, filter_size=3, mask_type=('b', PIX_2_N_BLOCKS), he_init=True, inputs=output)
@@ -538,64 +595,104 @@ def Dec2(latents, targets):

                 return output

-            # Really only for MNIST. Will require modification for other datasets.
+            # Only for 64px_big_onelevel and MNIST. Needs modification for others.
             def EncFull(images):
                 output = images

-                if EMBED_INPUTS:
-                    output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False)
+                if WIDTH == 64:
+                    if EMBED_INPUTS:
+                        output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_0, filter_size=1, inputs=output, he_init=False)
+                    else:
+                        output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS, output_dim=DIM_0, filter_size=1, inputs=output, he_init=False)
+
+                    output = ResidualBlock('EncFull.Res1', input_dim=DIM_0, output_dim=DIM_0, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res2', input_dim=DIM_0, output_dim=DIM_1, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res3', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res4', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res5', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res6', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res7', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res8', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res9', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res10', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res11', input_dim=DIM_3, output_dim=DIM_4, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res12', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res13', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, resample=None, inputs=output)
+                    output = tf.reshape(output, [-1, 4*4*DIM_4])
+                    output = lib.ops.linear.Linear('EncFull.Output', input_dim=4*4*DIM_4, output_dim=2*LATENT_DIM_2, initialization='glorot', inputs=output)
                 else:
-                    output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False)
-
-                output = ResidualBlock('EncFull.Res1', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output)
-                output = ResidualBlock('EncFull.Res2', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, resample='down', inputs=output)
-                output = ResidualBlock('EncFull.Res3', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output)
-                output = ResidualBlock('EncFull.Res4', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, resample='down', inputs=output)
-                output = ResidualBlock('EncFull.Res5', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
-                output = ResidualBlock('EncFull.Res6', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
+                    if EMBED_INPUTS:
+                        output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False)
+                    else:
+                        output = lib.ops.conv2d.Conv2D('EncFull.Input', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=1, inputs=output, he_init=False)

-                output = tf.reduce_mean(output, reduction_indices=[2,3])
-                output = lib.ops.linear.Linear('EncFull.Output', input_dim=DIM_3, output_dim=2*LATENT_DIM_2, initialization='glorot', inputs=output)
+                    output = ResidualBlock('EncFull.Res1', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res2', input_dim=DIM_1, output_dim=DIM_2, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res3', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res4', input_dim=DIM_2, output_dim=DIM_3, filter_size=3, resample='down', inputs=output)
+                    output = ResidualBlock('EncFull.Res5', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
+                    output = ResidualBlock('EncFull.Res6', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, inputs=output)
+                    output = tf.reduce_mean(output, reduction_indices=[2,3])
+                    output = lib.ops.linear.Linear('EncFull.Output', input_dim=DIM_3, output_dim=2*LATENT_DIM_2, initialization='glorot', inputs=output)

                 return output

-            # Really only for MNIST. Will require modification for other datasets.
+            # Only for 64px_big_onelevel and MNIST. Needs modification for others.
             def DecFull(latents, images):
                 output = tf.clip_by_value(latents, -50., 50.)

-                output = lib.ops.linear.Linear('DecFull.Input', input_dim=LATENT_DIM_2, output_dim=DIM_3, initialization='glorot', inputs=output)
-                output = tf.reshape(tf.tile(tf.reshape(output, [-1, DIM_3, 1]), [1, 1, 49]), [-1, DIM_3, 7, 7])
+                if WIDTH == 64:
+                    output = lib.ops.linear.Linear('DecFull.Input', input_dim=LATENT_DIM_2, output_dim=4*4*DIM_4, initialization='glorot', inputs=output)
+                    output = tf.reshape(output, [-1, DIM_4, 4, 4])
+                    output = ResidualBlock('DecFull.Res2', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res3', input_dim=DIM_4, output_dim=DIM_4, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res4', input_dim=DIM_4, output_dim=DIM_3, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res5', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res6', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res7', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res8', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res9', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res10', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res11', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res12', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res13', input_dim=DIM_1, output_dim=DIM_0, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res14', input_dim=DIM_0, output_dim=DIM_0, filter_size=3, resample=None, he_init=True, inputs=output)
+                else:
+                    output = lib.ops.linear.Linear('DecFull.Input', input_dim=LATENT_DIM_2, output_dim=DIM_3, initialization='glorot', inputs=output)
+                    output = tf.reshape(tf.tile(tf.reshape(output, [-1, DIM_3, 1]), [1, 1, 49]), [-1, DIM_3, 7, 7])
+                    output = ResidualBlock('DecFull.Res2', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res3', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res4', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res5', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res6', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, resample='up', he_init=True, inputs=output)
+                    output = ResidualBlock('DecFull.Res7', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, he_init=True, inputs=output)

-                output = ResidualBlock('DecFull.Res2', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
-                output = ResidualBlock('DecFull.Res3', input_dim=DIM_3, output_dim=DIM_3, filter_size=3, resample=None, he_init=True, inputs=output)
-                output = ResidualBlock('DecFull.Res4', input_dim=DIM_3, output_dim=DIM_2, filter_size=3, resample='up', he_init=True, inputs=output)
-                output = ResidualBlock('DecFull.Res5', input_dim=DIM_2, output_dim=DIM_2, filter_size=3, resample=None, he_init=True, inputs=output)
-                output = ResidualBlock('DecFull.Res6', input_dim=DIM_2, output_dim=DIM_1, filter_size=3, resample='up', he_init=True, inputs=output)
-                output = ResidualBlock('DecFull.Res7', input_dim=DIM_1, output_dim=DIM_1, filter_size=3, resample=None, he_init=True, inputs=output)
+                if WIDTH == 64:
+                    dim = DIM_0
+                else:
+                    dim = DIM_1

                 if PIXEL_LEVEL_PIXCNN:
                     if EMBED_INPUTS:
-                        masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS*DIM_EMBED, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False)
+                        masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS*DIM_EMBED, output_dim=dim, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False)
                     else:
-                        masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS, output_dim=DIM_1, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False)
+                        masked_images = lib.ops.conv2d.Conv2D('DecFull.Pix1', input_dim=N_CHANNELS, output_dim=dim, filter_size=5, inputs=images, mask_type=('a', N_CHANNELS), he_init=False)

                     # Warning! Because of the masked convolutions it's very important that masked_images comes first in this concat
+                    output = tf.concat([masked_images, output], axis=1)

-                    output = tf.concat(1, [masked_images, output])
-
-                    # output = ResidualBlock('DecFull.Pix2Res', input_dim=2*DIM_1, output_dim=DIM_PIX_1, filter_size=1, mask_type=('b', N_CHANNELS), inputs=output)
-
-                    output = ResidualBlock('DecFull.Pix2Res', input_dim=2*DIM_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
+                    output = ResidualBlock('DecFull.Pix2Res', input_dim=2*dim, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
                     output = ResidualBlock('DecFull.Pix3Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
                     output = ResidualBlock('DecFull.Pix4Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
-                    output = ResidualBlock('DecFull.Pix5Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)
+                    if WIDTH != 64:
+                        output = ResidualBlock('DecFull.Pix5Res', input_dim=DIM_PIX_1, output_dim=DIM_PIX_1, filter_size=3, mask_type=('b', N_CHANNELS), inputs=output)

                     output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=DIM_PIX_1, output_dim=256*N_CHANNELS, filter_size=1, mask_type=('b', N_CHANNELS), he_init=False, inputs=output)
                 else:
-                    output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=DIM_1, output_dim=256*N_CHANNELS, filter_size=1, he_init=False, inputs=output)
+                    output = lib.ops.conv2d.Conv2D('Dec1.Out', input_dim=dim, output_dim=256*N_CHANNELS, filter_size=1, he_init=False, inputs=output)

                 return tf.transpose(
                     tf.reshape(output, [-1, 256, N_CHANNELS, HEIGHT, WIDTH]),
@@ -603,7 +700,7 @@ def DecFull(latents, images):
                 )

             def split(mu_and_logsig):
-                mu, logsig = tf.split(1, 2, mu_and_logsig)
+                mu, logsig = tf.split(mu_and_logsig, 2, axis=1)
                 sig = 0.5 * (tf.nn.softsign(logsig)+1)
                 logsig = tf.log(sig)
                 return mu, logsig, sig
@@ -641,8 +738,8 @@ def clamp_logsig_and_sig(logsig, sig):

                 reconst_cost = tf.reduce_mean(
                     tf.nn.sparse_softmax_cross_entropy_with_logits(
-                        tf.reshape(outputs1, [-1, 256]),
-                        tf.reshape(images, [-1])
+                        logits=tf.reshape(outputs1, [-1, 256]),
+                        labels=tf.reshape(images, [-1])
                     )
                 )

@@ -690,8 +787,8 @@ def clamp_logsig_and_sig(logsig, sig):

                 reconst_cost = tf.reduce_mean(
                     tf.nn.sparse_softmax_cross_entropy_with_logits(
-                        tf.reshape(outputs1, [-1, 256]),
-                        tf.reshape(images, [-1])
+                        logits=tf.reshape(outputs1, [-1, 256]),
+                        labels=tf.reshape(images, [-1])
                     )
                 )

@@ -745,11 +842,11 @@ def clamp_logsig_and_sig(logsig, sig):
            tower_outputs1_sample.append(outputs1_sample)

    full_cost = tf.reduce_mean(
-        tf.concat(0, [tf.expand_dims(x, 0) for x in tower_cost]), 0
+        tf.concat([tf.expand_dims(x, 0) for x in tower_cost], axis=0), 0
    )

    if MODE == 'two_level':
-        full_outputs1_sample = tf.concat(0, tower_outputs1_sample)
+        full_outputs1_sample = tf.concat(tower_outputs1_sample, axis=0)

    # Sampling

@@ -758,7 +855,7 @@ def clamp_logsig_and_sig(logsig, sig):
        ch_sym = tf.placeholder(tf.int32, shape=None)
        y_sym = tf.placeholder(tf.int32, shape=None)
        x_sym = tf.placeholder(tf.int32, shape=None)
-        logits = tf.reshape(tf.slice(outputs1, tf.pack([0, ch_sym, y_sym, x_sym, 0]), tf.pack([-1, 1, 1, 1, -1])), [-1, 256])
+        logits = tf.reshape(tf.slice(outputs1, tf.stack([0, ch_sym, y_sym, x_sym, 0]), tf.stack([-1, 1, 1, 1, -1])), [-1, 256])
        dec1_fn_out = tf.multinomial(logits, 1)[:, 0]
        def dec1_fn(_latents, _targets, _ch, _y, _x):
            return session.run(dec1_fn_out, feed_dict={latents1: _latents, images: _targets, ch_sym: _ch, y_sym: _y, x_sym: _x, total_iters: 99999, bn_is_training: False, bn_stats_iter:0})
@@ -813,7 +910,7 @@ def dec2_fn(_latents, _targets):
        ch_sym = tf.placeholder(tf.int32, shape=None)
        y_sym = tf.placeholder(tf.int32, shape=None)
        x_sym = tf.placeholder(tf.int32, shape=None)
-        logits_sym = tf.reshape(tf.slice(full_outputs1_sample, tf.pack([0, ch_sym, y_sym, x_sym, 0]), tf.pack([-1, 1, 1, 1, -1])), [-1, 256])
+        logits_sym = tf.reshape(tf.slice(full_outputs1_sample, tf.stack([0, ch_sym, y_sym, x_sym, 0]), tf.stack([-1, 1, 1, 1, -1])), [-1, 256])

        def dec1_logits_fn(_latents, _targets, _ch, _y, _x):
            return session.run(logits_sym,
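
Note on the recurring change above: most of the edits are the TensorFlow 1.0 API migration. tf.split and tf.concat now take the tensor(s) first and the axis as a keyword argument, tf.pack was renamed to tf.stack, and tf.nn.sparse_softmax_cross_entropy_with_logits requires the logits=/labels= keywords. A minimal sketch of the new call forms, assuming TensorFlow 1.x; the tensors here are dummies, not variables from pixelvae.py:

    import tensorflow as tf

    x = tf.zeros([8, 3, 64, 64])                            # dummy NCHW batch

    halves = tf.split(x, 2, axis=0)                         # was: tf.split(0, 2, x)
    joined = tf.concat(halves, axis=1)                      # was: tf.concat(1, halves)

    begin = tf.stack([0, 0, 0, 0])                          # was: tf.pack([0, 0, 0, 0])

    logits = tf.zeros([10, 256])
    labels = tf.zeros([10], dtype=tf.int32)
    xent = tf.nn.sparse_softmax_cross_entropy_with_logits(  # keyword arguments now required
        logits=logits, labels=labels)

The other substantive change is the new 64px_big_onelevel configuration: for WIDTH == 64, EncFull downsamples 64 -> 32 -> 16 -> 8 -> 4 across its four resample='down' blocks, which is why it reshapes to [-1, 4*4*DIM_4] before the final linear layer, and DecFull mirrors this with four resample='up' blocks that grow a 4x4 feature map back to 64x64.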