TensorFlow implementation of DRAW: A Recurrent Neural Network For Image Generation on the MNIST generation task.
With Attention | Without Attention |
---|---|
Although open-source implementations of this paper already exist (see links below), this implementation focuses on simplicity and ease of understanding. I tried to make the code resemble the raw equations as closely as possible.
For a gentle walkthrough of the paper and implementation, see the writeup here: http://blog.evjang.com/2016/06/understanding-and-implementing.html.
python draw.py --data_dir=/tmp/draw --log_dir=/tmp/draw/logs
The command above uses the images provided in data_dir and trains the DRAW model with attention enabled for both reading and writing. After training, output data is written to log_dir.
TensorBoard summaries can be monitored by running tensorboard --logdir=<log-dir>
You can visualize the results by running the script python plot_data.py <prefix> <output_data>
For example,
python plot_data.py myattn /tmp/draw/draw_data.npy
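Before plotting, it can help to check what the saved file actually contains. The sketch below is not part of this repository: `describe_npy` is a hypothetical helper, and it assumes nothing about the layout of draw_data.npy beyond it being a file written with `np.save`.

```python
import numpy as np

def describe_npy(path):
    """Load a .npy file and return its (shape, dtype) without assuming a layout."""
    # allow_pickle=True is only needed if the file stores Python objects
    # (for example a list of arrays); enable it only for files you trust.
    data = np.load(path, allow_pickle=True)
    return data.shape, data.dtype
```

Calling `describe_npy("/tmp/draw/draw_data.npy")` after training should report the dimensions of the stored outputs, which is a quick sanity check before passing the file to plot_data.py.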
Parameters can be modified in config.py
Instead of training from scratch, you can load pre-trained weights by uncommenting the following line in draw.py
and editing the path to your checkpoint file as needed. Save electricity!
saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
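Restoring from a path that does not exist fails at runtime, so you may want to check the path first. The following is a minimal sketch, not code from draw.py: `checkpoint_exists` is a hypothetical helper, and it assumes the checkpoint is either a single file at the given path (the older TensorFlow format) or an `.index` file plus `.data-*` shards sharing that prefix (the newer format).

```python
import glob
import os

def checkpoint_exists(prefix):
    """Return True if something that looks like a TensorFlow checkpoint exists at `prefix`."""
    # Older checkpoints are a single file named exactly `prefix`.
    if os.path.isfile(prefix):
        return True
    # Newer checkpoints are split into `prefix.index` plus `prefix.data-*` shards.
    has_index = os.path.isfile(prefix + ".index")
    has_data = bool(glob.glob(glob.escape(prefix) + ".data-*"))
    return has_index and has_data

# Hypothetical usage inside draw.py:
# if checkpoint_exists("/tmp/draw/drawmodel.ckpt"):
#     saver.restore(sess, "/tmp/draw/drawmodel.ckpt")
```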
This git repository contains the following pre-trained files in the data/
folder:
Filename | Description |
---|---|
draw_data_attn.npy | Training outputs for DRAW with attention |
drawmodel_attn.ckpt | Saved weights for DRAW with attention |
draw_data_noattn.npy | Training outputs for DRAW without attention |
drawmodel_noattn.ckpt | Saved weights for DRAW without attention |
These were produced by training for 10000 iterations with a minibatch size of 100 on a GTX 970 GPU.
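To put those numbers in context, here is a rough calculation of how many passes over the data that training run represents. The training-set size of 60000 is an assumption (the standard MNIST training set); the split this code actually uses may be smaller, which would raise the epoch count slightly.

```python
ITERATIONS = 10000       # from the text above
BATCH_SIZE = 100         # from the text above
TRAIN_SET_SIZE = 60000   # assumption: the full standard MNIST training set

examples_seen = ITERATIONS * BATCH_SIZE
epochs = examples_seen / TRAIN_SET_SIZE

print(examples_seen)     # 1000000
print(round(epochs, 1))  # 16.7
```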