forked from ericjang/draw
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Eric Jang
committed
Feb 22, 2016
1 parent
a043942
commit 7b3ec04
Showing
7 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# draw | ||
|
||
TensorFlow implementation of [DRAW: A Recurrent Neural Network For Image Generation](http://arxiv.org/pdf/1502.04623.pdf) on the MNIST generation task. | ||
|
||
For a gentle walkthrough of the paper and implementation, see the writeup here: [https://evjang/articles/draw](http://evjang/articles/draw). | ||
|
||
| With Attention | Without Attention | | ||
| ------------- | ------------- | | ||
| ![AttnGIF](img/mnist_attn.gif) | ![NoAttnGIF](img/mnist_noattn.gif) | | ||
|
||
Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible. | ||
|
||
## Usage | ||
|
||
`python draw.py --data_dir=/tmp/draw` downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to `/tmp/draw/draw_data.npy` | ||
|
||
You can visualize the results by running the script `python plot_data.py <prefix> <output_data>` | ||
|
||
For example, | ||
|
||
`python plot_data.py noattn /tmp/draw/draw_data.npy` | ||
|
||
To run training without attention, do: | ||
|
||
`python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False` | ||
|
||
## Restoring from Pre-trained Model | ||
|
||
Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in `draw.py` and editing the path to your checkpoint file as needed. Save electricity! | ||
|
||
```python | ||
saver.restore(sess, "/tmp/draw/drawmodel.ckpt") | ||
``` | ||
|
||
This git repository contains the following pre-trained files in the `data/` folder: | ||
|
||
| Filename | Description | | ||
| ------------- | ------------- | | ||
| draw_data_attn.npy | Training outputs for DRAW with attention | | ||
| drawmodel_attn.ckpt | Saved weights for DRAW with attention | | ||
| draw_data_noattn.npy | Training outputs for DRAW without attention | | ||
| drawmodel_noattn.ckpt | Saved weights for DRAW without attention | | ||
|
||
These were trained for 10000 iterations with minibatch size=100 on a GTX 970 GPU. | ||
|
||
## Useful Resources | ||
|
||
- https://github.com/vivanov879/draw | ||
- https://github.com/jbornschein/draw | ||
- https://github.com/ikostrikov/TensorFlow-VAE-GAN-DRAW (wish I had found this earlier) | ||
- [Video Lecture on Variational Autoencoders and Image Generation]( https://www.youtube.com/watch?v=P78QYjWh5sM&list=PLE6Wd9FR--EfW8dtjAuPoTuPcqmOV53Fu&index=3) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
#!/usr/bin/env python | ||
|
||
"""" | ||
Simple implementation of http://arxiv.org/pdf/1502.04623v2.pdf in TensorFlow | ||
Example Usage: | ||
python draw.py --data_dir=/tmp/draw --read_attn=True --write_attn=True | ||
Author: Eric Jang | ||
""" | ||
|
||
import tensorflow as tf | ||
from tensorflow.models.rnn.rnn_cell import LSTMCell | ||
from tensorflow.examples.tutorials import mnist | ||
import numpy as np | ||
import os | ||
|
||
# Command-line flags: output/data directory and the two attention switches.
tf.flags.DEFINE_string("data_dir", "", "")
tf.flags.DEFINE_boolean("read_attn", True, "enable attention for reader")
tf.flags.DEFINE_boolean("write_attn", True, "enable attention for writer")
FLAGS = tf.flags.FLAGS

## MODEL PARAMETERS ##

A, B = 28, 28  # image width,height
img_size = B * A  # the canvas size
enc_size = 256  # number of hidden units / output size in LSTM
dec_size = 256
read_n = 5  # read glimpse grid width/height
write_n = 5  # write glimpse grid width/height

# The reader emits two images' worth of features (x and the error image),
# either as N x N glimpses or as full images.
if FLAGS.read_attn:
    read_size = 2 * read_n * read_n
else:
    read_size = 2 * img_size

# The writer emits either an N x N patch or a full image.
if FLAGS.write_attn:
    write_size = write_n * write_n
else:
    write_size = img_size

z_size = 10  # QSampler output size
T = 10  # MNIST generation sequence length
batch_size = 100  # training minibatch size
train_iters = 10000
learning_rate = 1e-3  # learning rate for optimizer
eps = 1e-8  # epsilon for numerical stability
|
||
## BUILD MODEL ## | ||
|
||
# Passed as `reuse=` to every tf.variable_scope in this file. It is None while
# the first time step is built and is set to True by the unrolling loop, so
# later steps reuse the same variables.
DO_SHARE = None

# input (batch_size * img_size)
x = tf.placeholder(tf.float32, shape=(batch_size, img_size))
# Qsampler noise
e = tf.random_normal((batch_size, z_size), mean=0, stddev=1)
# encoder Op; its input is cat(read, h_dec_prev)
lstm_enc = LSTMCell(enc_size, read_size + dec_size)
# decoder Op; its input is the sampled latent z
lstm_dec = LSTMCell(dec_size, z_size)
|
||
def linear(x, output_dim):
    """
    Affine transformation x*w + b.

    Fetches variables "w" and "b" with tf.get_variable, so the caller is
    expected to wrap the call in the variable scope that names them.

    x: 2-D tensor, assumed (batch_size, num_features)
    output_dim: number of output features
    returns: tensor produced by matmul(x, w) + b
    """
    num_features = x.get_shape()[1]
    weights = tf.get_variable("w", [num_features, output_dim])
    bias = tf.get_variable(
        "b", [output_dim], initializer=tf.constant_initializer(0.0))
    return tf.matmul(x, weights) + bias
|
||
def filterbank(gx, gy, sigma2, delta, N):
    """
    Build the horizontal and vertical Gaussian filterbanks of an N x N
    attention grid.

    gx, gy: grid centre coordinates (presumably batch x 1, as produced by
            attn_window -- confirm against caller)
    sigma2: variance of every Gaussian filter
    delta:  stride between neighbouring filter centres
    N:      number of filters per axis
    returns: (Fx, Fy) with Fx of shape batch x N x A and Fy of shape
             batch x N x B, each filter normalised to sum to 1 over its
             pixel axis
    """
    grid_i = tf.reshape(tf.cast(tf.range(N), tf.float32), [1, -1])
    # Use float division so Python 2 does not truncate N / 2 for odd N.
    # NOTE(review): grid_i is 0-based here; if eq 19/20 of the paper assume a
    # 1-based index the grid is shifted by one stride -- confirm.
    mu_x = gx + (grid_i - N / 2.0 - 0.5) * delta  # eq 19
    mu_y = gy + (grid_i - N / 2.0 - 0.5) * delta  # eq 20
    a = tf.reshape(tf.cast(tf.range(A), tf.float32), [1, 1, -1])
    b = tf.reshape(tf.cast(tf.range(B), tf.float32), [1, 1, -1])
    mu_x = tf.reshape(mu_x, [-1, N, 1])
    mu_y = tf.reshape(mu_y, [-1, N, 1])
    sigma2 = tf.reshape(sigma2, [-1, 1, 1])
    # Gaussian kernel exp(-(a - mu)^2 / (2*sigma2)). The squared distance is
    # divided by 2*sigma2; the previous code squared the whole quotient,
    # which evaluated exp(-(a - mu)^2 / (4*sigma2^2)) instead.
    Fx = tf.exp(-tf.square(a - mu_x) / (2 * sigma2))  # batch x N x A
    Fy = tf.exp(-tf.square(b - mu_y) / (2 * sigma2))  # batch x N x B
    # normalize, sum over A and B dims; eps guards against division by zero
    Fx = Fx / tf.maximum(tf.reduce_sum(Fx, 2, keep_dims=True), eps)
    Fy = Fy / tf.maximum(tf.reduce_sum(Fy, 2, keep_dims=True), eps)
    return Fx, Fy
|
||
def attn_window(scope, h_dec, N):
    """
    Map the decoder output to the parameters of an N x N attention window.

    scope: variable scope name holding the linear layer's weights
    h_dec: decoder output used to predict the five window parameters
    N:     grid size; must be greater than 1 because the stride is scaled
           by 1/(N-1)
    returns: tuple (Fx, Fy, gamma) -- the filterbanks from filterbank() plus
             the intensity exp(log_gamma)
    """
    with tf.variable_scope(scope, reuse=DO_SHARE):
        params = linear(h_dec, 5)
    gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1, 5, params)
    # Float literals keep these scale factors exact under Python 2, where
    # (A+1)/2 and (max(A,B)-1)/(N-1) would otherwise be truncated integers.
    gx = (A + 1) / 2.0 * (gx_ + 1)
    gy = (B + 1) / 2.0 * (gy_ + 1)
    sigma2 = tf.exp(log_sigma2)
    delta = (max(A, B) - 1) / float(N - 1) * tf.exp(log_delta)
    return filterbank(gx, gy, sigma2, delta, N) + (tf.exp(log_gamma),)
|
||
## READ ## | ||
def read_no_attn(x, x_hat, h_dec_prev):
    """
    Reader without attention: concatenates the image and the error image
    along the feature axis. h_dec_prev is accepted only so the signature
    matches read_attn; it is not used.
    """
    features = [x, x_hat]
    return tf.concat(1, features)
|
||
def read_attn(x, x_hat, h_dec_prev):
    """
    Reader with attention: extracts a read_n x read_n glimpse from both the
    image and the error image, using one window predicted from h_dec_prev.

    returns: cat(glimpse(x), glimpse(x_hat)) along the feature axis
    """
    Fx, Fy, gamma = attn_window("read", h_dec_prev, read_n)

    def glimpse_of(image):
        # Fy * image * Fx^T, flattened and scaled by the intensity gamma
        Fx_t = tf.transpose(Fx, perm=[0, 2, 1])
        image = tf.reshape(image, [-1, B, A])
        patch = tf.batch_matmul(Fy, tf.batch_matmul(image, Fx_t))
        patch = tf.reshape(patch, [-1, read_n * read_n])
        return patch * tf.reshape(gamma, [-1, 1])

    x_glimpse = glimpse_of(x)  # batch x (read_n*read_n)
    x_hat_glimpse = glimpse_of(x_hat)
    return tf.concat(1, [x_glimpse, x_hat_glimpse])  # concat along feature axis
|
||
read = read_attn if FLAGS.read_attn else read_no_attn | ||
|
||
## ENCODE ## | ||
def encode(state, input):
    """
    Run one step of the encoder LSTM inside the "encoder" variable scope.

    state: previous encoder state
    input: cat(read, h_dec_prev)
    returns: whatever lstm_enc returns, i.e. (output, new_state)
    """
    with tf.variable_scope("encoder", reuse=DO_SHARE):
        step_result = lstm_enc(input, state)
    return step_result
|
||
## Q-SAMPLER (VARIATIONAL AUTOENCODER) ## | ||
|
||
def sampleQ(h_enc):
    """
    Draw Zt ~ normrnd(mu, sigma) with the reparameterization trick:
    z = mu + sigma * e, where e is the module-level noise tensor.

    mu and logsigma are linear functions of h_enc, each (batch, z_size).
    returns: (z, mu, logsigma, sigma)

    NOTE(review): every time step uses the same `e` tensor -- confirm that
    sharing one noise op across steps is intended.
    """
    with tf.variable_scope("mu", reuse=DO_SHARE):
        mu = linear(h_enc, z_size)
    with tf.variable_scope("sigma", reuse=DO_SHARE):
        logsigma = linear(h_enc, z_size)
        sigma = tf.exp(logsigma)
    z = mu + sigma * e
    return (z, mu, logsigma, sigma)
|
||
## DECODER ## | ||
def decode(state, input):
    """
    Run one step of the decoder LSTM inside the "decoder" variable scope.

    state: previous decoder state
    input: sampled latent z
    returns: whatever lstm_dec returns, i.e. (output, new_state)
    """
    with tf.variable_scope("decoder", reuse=DO_SHARE):
        step_result = lstm_dec(input, state)
    return step_result
|
||
## WRITER ## | ||
def write_no_attn(h_dec):
    """
    Writer without attention: a single linear layer from the decoder output
    to a full img_size canvas update.
    """
    with tf.variable_scope("write", reuse=DO_SHARE):
        canvas_update = linear(h_dec, img_size)
    return canvas_update
|
||
def write_attn(h_dec):
    """
    Writer with attention: predicts a write_n x write_n patch from h_dec and
    projects it onto the canvas as Fy^T * patch * Fx, scaled by 1/gamma.

    returns: canvas update of shape (batch_size, B*A)
    """
    with tf.variable_scope("writeW", reuse=DO_SHARE):
        patch = linear(h_dec, write_size)  # batch x (write_n*write_n)
    patch = tf.reshape(patch, [batch_size, write_n, write_n])
    Fx, Fy, gamma = attn_window("write", h_dec, write_n)
    Fy_t = tf.transpose(Fy, perm=[0, 2, 1])
    canvas_update = tf.batch_matmul(Fy_t, tf.batch_matmul(patch, Fx))
    canvas_update = tf.reshape(canvas_update, [batch_size, B * A])
    inv_gamma = tf.reshape(1.0 / gamma, [-1, 1])
    return canvas_update * inv_gamma
|
||
write=write_attn if FLAGS.write_attn else write_no_attn | ||
|
||
## STATE VARIABLES ## | ||
|
||
# Per-time-step outputs, filled in by the unrolling loop below.
cs = [0] * T  # sequence of canvases
# gaussian params generated by SampleQ. We will need these for computing loss.
mus = [0] * T
logsigmas = [0] * T
sigmas = [0] * T

# initial states
h_dec_prev = tf.zeros((batch_size, dec_size))
enc_state = lstm_enc.zero_state(batch_size, tf.float32)
dec_state = lstm_dec.zero_state(batch_size, tf.float32)

## DRAW MODEL ##

# construct the unrolled computational graph
for t in range(T):
    if t == 0:
        c_prev = tf.zeros((batch_size, img_size))  # blank initial canvas
    else:
        c_prev = cs[t - 1]
    x_hat = x - tf.sigmoid(c_prev)  # error image
    r = read(x, x_hat, h_dec_prev)
    enc_input = tf.concat(1, [r, h_dec_prev])
    h_enc, enc_state = encode(enc_state, enc_input)
    z, mus[t], logsigmas[t], sigmas[t] = sampleQ(h_enc)
    h_dec, dec_state = decode(dec_state, z)
    cs[t] = c_prev + write(h_dec)  # store results
    h_dec_prev = h_dec
    DO_SHARE = True  # from now on, share variables
|
||
## LOSS FUNCTION ## | ||
|
||
def binary_crossentropy(t, o):
    """
    Element-wise binary cross-entropy between targets t and outputs o.
    eps is added inside both logs to avoid log(0).
    """
    on_term = t * tf.log(o + eps)
    off_term = (1.0 - t) * tf.log(1.0 - o + eps)
    return -(on_term + off_term)
|
||
# The reconstruction term is reduced to a single scalar (not one value per
# item in the minibatch): the final canvas is squashed with a sigmoid and
# compared against the input.
x_recons = tf.nn.sigmoid(cs[-1])

# after computing binary cross entropy, sum across features then take the mean of those sums across minibatches
Lx = tf.reduce_sum(binary_crossentropy(x, x_recons), 1)  # reconstruction term
Lx = tf.reduce_mean(Lx)

# Latent loss: 0.5 * sum_t(mu^2 + sigma^2 - log sigma^2) - T/2.
# The T/2 constant is applied once over the whole sequence, i.e. 0.5 per time
# step. The previous code subtracted T/2 at every step, which shifted the
# reported Lz down by (T*T - T)/2. Gradients are unaffected by the constant.
kl_terms = [0] * T
for t in range(T):
    mu2 = tf.square(mus[t])
    sigma2 = tf.square(sigmas[t])
    logsigma = logsigmas[t]
    kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2 * logsigma, 1) - .5  # each kl term is (1xminibatch)
KL = tf.add_n(kl_terms)  # this is 1xminibatch, corresponding to summing kl_terms from 1:T
Lz = tf.reduce_mean(KL)  # average over minibatches

cost = Lx + Lz
|
||
## OPTIMIZER ## | ||
|
||
optimizer = tf.train.AdamOptimizer(learning_rate, beta1=0.5)
grads = optimizer.compute_gradients(cost)
# Clip each gradient tensor to norm 5; variables without a gradient (None)
# are passed through untouched.
grads = [
    (g, v) if g is None else (tf.clip_by_norm(g, 5), v)
    for g, v in grads
]
train_op = optimizer.apply_gradients(grads)
|
||
## RUN TRAINING ## | ||
|
||
# MNIST is stored in (and read from) <data_dir>/mnist.
data_directory = os.path.join(FLAGS.data_dir, "mnist")
if not os.path.exists(data_directory):
    os.makedirs(data_directory)
# NOTE(review): the original comment calls this "binarized (0-1) mnist data";
# confirm what read_data_sets actually returns.
train_data = mnist.input_data.read_data_sets(data_directory, one_hot=True).train

# Both loss terms are fetched together with the training op on every step.
fetches = [Lx, Lz, train_op]
Lxs = [0] * train_iters
Lzs = [0] * train_iters

sess = tf.InteractiveSession()

saver = tf.train.Saver()  # saves variables learned during training
tf.initialize_all_variables().run()
#saver.restore(sess, "/tmp/draw/drawmodel.ckpt") # to restore from model, uncomment this line

for i in range(train_iters):
    xtrain, _ = train_data.next_batch(batch_size)  # xtrain is (batch_size x img_size)
    feed_dict = {x: xtrain}
    Lxs[i], Lzs[i], _ = sess.run(fetches, feed_dict)
    if i % 100 == 0:
        print("iter=%d : Lx: %f Lz: %f" % (i, Lxs[i], Lzs[i]))
|
||
## TRAINING FINISHED ## | ||
|
||
# generate some examples by evaluating every canvas on the last training batch
canvases = np.array(sess.run(cs, feed_dict))  # T x batch x img_size

# Canvases and both loss histories go into one .npy file.
out_file = os.path.join(FLAGS.data_dir, "draw_data.npy")
np.save(out_file, [canvases, Lxs, Lzs])
print("Outputs saved in file: %s" % out_file)

ckpt_file = os.path.join(FLAGS.data_dir, "drawmodel.ckpt")
saved_path = saver.save(sess, ckpt_file)
print("Model saved in file: %s" % saved_path)

sess.close()

print('Done drawing! Have a nice day! :)')
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# takes data saved by DRAW model and generates animations | ||
# example usage: python plot_data.py noattn /tmp/draw/draw_data.npy | ||
|
||
import matplotlib | ||
import sys | ||
import numpy as np | ||
|
||
interactive=False # set to False if you want to write images to file | ||
|
||
if not interactive: | ||
matplotlib.use('Agg') # Force matplotlib to not use any Xwindows backend. | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def xrecons_grid(X,B,A):
    """
    plots canvas for single time step
    X is x_recons, (batch_size x img_size)
    assumes features = BxA images
    batch is assumed to be a square number

    Returns a 2-D array in which the batch is tiled as an N x N grid
    (N = sqrt(batch_size)); each tile is surrounded by a 1-pixel border
    of value 0.5.
    Raises ValueError if batch_size is not a perfect square.
    """
    padsize=1
    padval=.5
    ph=B+2*padsize
    pw=A+2*padsize
    batch_size=X.shape[0]
    # round() rather than truncate so floating-point error in sqrt cannot
    # turn an exact square into N-1
    N=int(round(np.sqrt(batch_size)))
    if N*N!=batch_size:
        # fail with a clear message instead of an opaque reshape error
        raise ValueError("batch size %d is not a square number" % batch_size)
    X=X.reshape((N,N,B,A))
    img=np.full((N*ph,N*pw),padval)
    for i in range(N):
        for j in range(N):
            startr=i*ph+padsize
            endr=startr+B
            startc=j*pw+padsize
            endc=startc+A
            img[startr:endr,startc:endc]=X[i,j,:,:]
    return img
|
||
if __name__ == '__main__':
    # usage: python plot_data.py <prefix> <output_data>
    prefix=sys.argv[1]
    out_file=sys.argv[2]
    # file holds [canvases, Lxs, Lzs]
    [C,Lxs,Lzs]=np.load(out_file)
    T,batch_size,img_size=C.shape
    X=1.0/(1.0+np.exp(-C)) # x_recons=sigmoid(canvas)
    B=A=int(np.sqrt(img_size))
    if interactive:
        f,arr=plt.subplots(1,T)
    for t in range(T):
        img=xrecons_grid(X[t,:,:],B,A)
        if interactive:
            arr[t].matshow(img,cmap=plt.cm.gray)
            arr[t].set_xticks([])
            arr[t].set_yticks([])
        else:
            plt.matshow(img,cmap=plt.cm.gray)
            imgname='%s_%d.png' % (prefix,t) # you can merge using imagemagick, i.e. convert -delay 10 -loop 0 *.png mnist.gif
            plt.savefig(imgname)
            # matshow opens a new figure each iteration; close it once it
            # has been written so figures do not pile up in memory
            plt.close()
            print(imgname)
    # loss curves over training iterations
    f=plt.figure()
    plt.plot(Lxs,label='Reconstruction Loss Lx')
    plt.plot(Lzs,label='Latent Loss Lz')
    plt.xlabel('iterations')
    plt.legend()
    if interactive:
        plt.show()
    else:
        plt.savefig('%s_loss.png' % (prefix))