forked from ericjang/draw

TensorFlow Implementation of "DRAW: A Recurrent Neural Network For Image Generation"

sebandraos/draw

This branch is 4 commits behind ericjang/draw:master.

Latest commit 00c55a4 · Aug 27, 2016 · 14 commits

draw

TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.

Figure: generated samples, With Attention | Without Attention

Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.

For a gentle walkthrough of the paper and the implementation, see the write-up here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.
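The heart of the attentive read is the paper's Gaussian filterbank. Below is a minimal NumPy sketch of those equations. It was written independently of draw.py, which uses TensorFlow. The function names are hypothetical. Pixel and filter indices are 1-based, matching the paper's formulas. The raw attention outputs (gx_tilde, gy_tilde, log σ², log δ̃, log γ) are assumed to come from a linear layer on the decoder state.

```python
import numpy as np


def attention_params(gx_tilde, gy_tilde, log_sigma2, log_delta_tilde, log_gamma, N, A, B):
    """Map raw decoder outputs to grid centre, variance, stride and intensity (needs N > 1)."""
    g_x = (A + 1) / 2 * (gx_tilde + 1)  # grid centre along x (image width A)
    g_y = (B + 1) / 2 * (gy_tilde + 1)  # grid centre along y (image height B)
    delta = (max(A, B) - 1) / (N - 1) * np.exp(log_delta_tilde)  # stride between filters
    return g_x, g_y, np.exp(log_sigma2), delta, np.exp(log_gamma)


def filterbank(g_x, g_y, sigma2, delta, N, A, B):
    """Return F_X (N x A) and F_Y (N x B); each row is a normalised 1-D Gaussian."""
    i = np.arange(1, N + 1)  # filter indices 1..N
    mu_x = g_x + (i - N / 2 - 0.5) * delta  # filter centres along x
    mu_y = g_y + (i - N / 2 - 0.5) * delta  # filter centres along y
    a = np.arange(1, A + 1)  # pixel columns, 1-based
    b = np.arange(1, B + 1)  # pixel rows, 1-based
    F_x = np.exp(-((a[None, :] - mu_x[:, None]) ** 2) / (2 * sigma2))
    F_y = np.exp(-((b[None, :] - mu_y[:, None]) ** 2) / (2 * sigma2))
    # Divide by Z_X, Z_Y so each filter sums to 1 (guarded against underflow)
    F_x /= np.maximum(F_x.sum(axis=1, keepdims=True), 1e-8)
    F_y /= np.maximum(F_y.sum(axis=1, keepdims=True), 1e-8)
    return F_x, F_y


def read_attn(x, x_hat, F_x, F_y, gamma):
    """N x N glimpses of the image x (B x A) and of the error image x_hat, scaled by gamma."""
    glimpse = F_y @ x @ F_x.T
    glimpse_err = F_y @ x_hat @ F_x.T
    return gamma * np.concatenate([glimpse.ravel(), glimpse_err.ravel()])


# Example: a centred 5x5 grid on a 28x28 image with every raw output set to 0
g_x, g_y, sigma2, delta, gamma = attention_params(0.0, 0.0, 0.0, 0.0, 0.0, N=5, A=28, B=28)
F_x, F_y = filterbank(g_x, g_y, sigma2, delta, N=5, A=28, B=28)
x = np.random.rand(28, 28)
r = read_attn(x, x - 0.5, F_x, F_y, gamma)  # 2 * 5 * 5 = 50 values
```

The write operation in the paper reuses the same filterbank, transposed, to paste an N x N patch back onto the full canvas. That reuse is why one function covers both cases.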

Usage

python draw.py --data_dir=/tmp/draw downloads the binarized MNIST dataset to /tmp/draw/mnist and trains the DRAW model with attention enabled for both reading and writing. After training, the output data is written to /tmp/draw/draw_data.npy.
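Training minimises the sum of the paper's two loss terms. The reconstruction loss L_x is the Bernoulli negative log-likelihood of the image under sigmoid(c_T), the final canvas. The latent loss L_z is the KL divergence from each step's Gaussian posterior to a unit Gaussian prior, summed over the T steps. The NumPy sketch below computes both terms from hypothetical arrays. It follows the paper's definitions rather than draw.py itself, so the epsilon inside the logarithm and the averaging over the minibatch are assumptions that may differ from the actual code.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def draw_loss(x, c_T, mus, log_sigmas, eps=1e-8):
    """Return (L_x, L_z), each averaged over the minibatch.

    x:               (batch, 784) target images with values in [0, 1]
    c_T:             (batch, 784) final canvas, before the sigmoid
    mus, log_sigmas: (T, batch, z_size) per-step posterior parameters
    """
    x_recon = sigmoid(c_T)
    # Reconstruction loss: binary cross-entropy summed over pixels
    bce = -(x * np.log(x_recon + eps) + (1 - x) * np.log(1 - x_recon + eps))
    L_x = bce.sum(axis=1).mean()
    # Latent loss: KL(N(mu, sigma^2) || N(0, 1)), summed over steps and latent dimensions
    sigma2 = np.exp(2 * log_sigmas)
    kl = 0.5 * (mus ** 2 + sigma2 - 2 * log_sigmas - 1)
    L_z = kl.sum(axis=(0, 2)).mean()
    return L_x, L_z


# Example with random stand-in values: batch of 2, T = 10 steps, 10-d latent
rng = np.random.default_rng(0)
x = (rng.random((2, 784)) > 0.5).astype(float)
L_x, L_z = draw_loss(x, rng.normal(size=(2, 784)),
                     rng.normal(size=(10, 2, 10)), rng.normal(size=(10, 2, 10)) * 0.1)
total_loss = L_x + L_z  # the quantity the optimiser minimises
```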

You can visualize the results by running the script: python plot_data.py <prefix> <output_data>

For example,

python plot_data.py myattn /tmp/draw/draw_data.npy
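If you want to build your own figures instead of using plot_data.py, one option is to tile a minibatch of flattened 28x28 canvases into one grid image per timestep. The sketch assumes the canvases are stored as an array of shape (T, batch, 784) holding pre-sigmoid canvas values. That layout and the helper names tile_images and frames_from_canvases are hypothetical; check them against what draw.py actually saves.

```python
import numpy as np


def tile_images(batch, rows, cols, h=28, w=28):
    """Arrange the first rows*cols flattened images of `batch` into one 2-D grid.

    Requires at least rows*cols images in `batch`.
    """
    imgs = np.asarray(batch)[: rows * cols].reshape(rows, cols, h, w)
    # (rows, cols, h, w) -> (rows, h, cols, w) -> (rows*h, cols*w)
    return imgs.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)


def frames_from_canvases(canvases, rows=10, cols=10):
    """Turn a (T, batch, 784) array of canvas logits into T grid images with values in [0, 1]."""
    probs = 1.0 / (1.0 + np.exp(-np.asarray(canvases, dtype=float)))
    return [tile_images(step, rows, cols) for step in probs]
```

Each frame can then be written to disk with, for example, matplotlib.pyplot.imsave(path, frame, cmap="gray"). If the .npy file holds a list of arrays rather than a single array, recent NumPy versions require np.load(path, allow_pickle=True) to read it.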

To run training without attention, do:

python draw.py --data_dir=/tmp/draw --read_attn=False --write_attn=False
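With both flags set to False, DRAW falls back to the paper's non-attentive read and write. The encoder receives the whole image concatenated with the whole error image, and the decoder writes to the entire canvas through a single linear layer. Below is a minimal NumPy sketch of those two operations. The function names and the 256-unit decoder state are illustrative assumptions, and the LSTM updates that happen between the read and the write are omitted.

```python
import numpy as np


def error_image(x, c_prev):
    """x_hat = x - sigmoid(c_{t-1}): the part of the image the canvas still gets wrong."""
    return x - 1.0 / (1.0 + np.exp(-c_prev))


def read_no_attn(x, x_hat):
    """Non-attentive read: full image and full error image, concatenated."""
    return np.concatenate([x, x_hat], axis=-1)


def write_no_attn(h_dec, W, b):
    """Non-attentive write: a linear map from the decoder state to a canvas-sized update."""
    return h_dec @ W + b


# Example: batch of 2 flattened 28x28 images, hypothetical 256-unit decoder
batch, img_size, dec_size = 2, 784, 256
x = np.random.rand(batch, img_size)
c_prev = np.zeros((batch, img_size))
r = read_no_attn(x, error_image(x, c_prev))  # shape (2, 1568)
W = np.random.randn(dec_size, img_size) * 0.01
b = np.zeros(img_size)
h_dec = np.random.randn(batch, dec_size)  # stand-in for the decoder LSTM output
c = c_prev + write_no_attn(h_dec, W, b)  # canvas update c_t = c_{t-1} + write(h_dec)
```

The main contrast with the attentive version is the encoder's input size. Here the encoder sees 2 x 784 = 1568 values per step, while an N x N attentive read returns only 2N² values (50 for N = 5).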

Restoring from Pre-trained Model

Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py and editing the path to your checkpoint file as needed. Save electricity!

saver.restore(sess, "/tmp/draw/drawmodel.ckpt")

This git repository contains the following pre-trained files in the data/ folder:

Filename                 Description
draw_data_attn.npy       Training outputs for DRAW with attention
drawmodel_attn.ckpt      Saved weights for DRAW with attention
draw_data_noattn.npy     Training outputs for DRAW without attention
drawmodel_noattn.ckpt    Saved weights for DRAW without attention

These were trained for 10,000 iterations with a minibatch size of 100 on a GTX 970 GPU.
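For a rough sense of scale, 10,000 iterations at 100 images each amounts to one million training examples. The split size is not stated above, so assume draw.py samples from the 55,000-image training split returned by TensorFlow's MNIST reader. On that assumption, training makes roughly 18 passes over the data:

```python
iterations = 10_000
batch_size = 100
train_size = 55_000  # assumption: size of TensorFlow's MNIST training split

examples_seen = iterations * batch_size  # 1,000,000
epochs = examples_seen / train_size
print(f"{examples_seen:,} examples, about {epochs:.1f} epochs")
```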

Useful Resources
