
Wasserstein GANs in Tensorflow

This is my implementation of the Wasserstein GAN algorithm (see the paper) in TensorFlow/Keras, along with my takeaways from experimenting with it.

Training a model

To train a model, use the run.py script:

python run.py train \
  --dataset <dataset> \
  --model <model> \
  --n-batch <batch_size> \
  --lr <learning_rate> \
  --c <c_parameter> \
  --n-critic <critic_steps> \
  --log-dir <log_dir> \
  --epochs <training_epochs>

The model will report training/validation losses in the logfile.
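
For example, a run with the hyperparameters proposed in the paper might look like this (the values below are illustrative, not prescriptive):

python run.py train \
  --dataset mnist \
  --model wdcgan \
  --n-batch 64 \
  --lr 0.00005 \
  --c 0.01 \
  --n-critic 5 \
  --log-dir ./logs \
  --epochs 100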

Configuration

The repository contains code for a standard DC-GAN, trained using the usual GAN loss, as well as a Wasserstein GAN that uses a similar architecture.

The model parameters are configured via command-line options:

  • --model is either dcgan or wdcgan (standard or Wasserstein)
  • --dataset is only mnist for now
  • --c and --n-critic are the WGAN hyperparameters: the weight-clipping bound and the number of critic updates per generator update. The default values are the ones proposed in the paper; the sketch below shows what they control.

Other flags are pretty self-explanatory.
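
To make the roles of --c and --n-critic concrete, here is a minimal sketch of one WGAN training step, written in TF2 style rather than as the repository's actual training loop; critic, generator, and the batch plumbing are hypothetical placeholders:

import tensorflow as tf

c, n_critic = 0.01, 5                      # the paper's defaults
opt_c = tf.keras.optimizers.RMSprop(5e-5)  # the paper uses RMSProp
opt_g = tf.keras.optimizers.RMSprop(5e-5)

def train_step(real_batches, z_dim=100):
    # n_critic critic updates per generator update
    for x_real in real_batches[:n_critic]:
        z = tf.random.normal([tf.shape(x_real)[0], z_dim])
        with tf.GradientTape() as tape:
            # The critic maximizes E[f(x_real)] - E[f(G(z))],
            # so we minimize the negation.
            loss_c = (tf.reduce_mean(critic(generator(z)))
                      - tf.reduce_mean(critic(x_real)))
        grads = tape.gradient(loss_c, critic.trainable_variables)
        opt_c.apply_gradients(zip(grads, critic.trainable_variables))
        # Clip the critic's weights to [-c, c]; this is how the paper
        # enforces the Lipschitz constraint.
        for w in critic.trainable_variables:
            w.assign(tf.clip_by_value(w, -c, c))

    # One generator update: minimize -E[f(G(z))].
    z = tf.random.normal([64, z_dim])
    with tf.GradientTape() as tape:
        loss_g = -tf.reduce_mean(critic(generator(z)))
    grads = tape.gradient(loss_g, generator.trainable_variables)
    opt_g.apply_gradients(zip(grads, generator.trainable_variables))

With n_critic set to 5, the generator sees one batch out of every six; this ratio comes up again in the observations below.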

Observations

These are the results you get from a fully-trained DCGAN (left) or a WGAN (right):

[Image: samples from a fully-trained DCGAN (left) and Wasserstein DCGAN (right)]

The end results are quite similar. However, I found the training process to be very different.

Training speed

The most obvious difference is that the standard DCGAN starts outputting reasonable-looking images after fewer than 5 epochs (left), while the WGAN takes nearly 100 epochs to get to the same level and still produces noise at that early stage of training (right):

[Image: early-training samples from the DCGAN (left) and the Wasserstein DCGAN (right)]

To be fair, with the default settings the generator is trained on only one batch out of six (the other five batches are used to train the critic). However, the difference in training speed is much more than five-fold. The WGAN also uses a smaller learning rate, but a DCGAN trained at that same rate still converges much faster. There seems to be a deeper difference between the gradients the two methods provide.

Several others have also reported that WGANs are relatively slow to train. Soumith Chintala suggested on Twitter increasing the generator's learning rate; however, this will probably help only in some cases.

Stability

One of the central claims of the WGAN paper is that WGANs are much less sensitive to hyperparameter choices and to the model architecture. I am not yet convinced that this is entirely true.

For example, if you add batch normalization to the first layer of the discriminator, the WGAN starts giving really bad results (even after hundreds of epochs; the layer is commented out in the code). However, batch norm in the first layer of the standard DCGAN works perfectly fine for me. I'm a bit disappointed to see this happen on a dataset as simple as MNIST.
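
For reference, this is roughly where the problematic layer sits in the critic. A hypothetical Keras sketch, not the repository's exact architecture:

from tensorflow.keras import layers, Sequential

critic = Sequential([
    layers.Conv2D(64, 5, strides=2, padding='same',
                  input_shape=(28, 28, 1)),
    # layers.BatchNormalization(),   # <- the first-layer batch norm
    #                                #    that breaks WGAN training for me
    layers.LeakyReLU(0.2),
    layers.Conv2D(128, 5, strides=2, padding='same'),
    layers.BatchNormalization(),     # deeper batch norms are fine
    layers.LeakyReLU(0.2),
    layers.Flatten(),
    layers.Dense(1),                 # linear output: the critic's score
])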

Convergence metric

The other claim made in the paper is that the Earth Mover's distance (approximated by the negative of the discriminator/critic loss) provides a useful metric of convergence: a lower estimated distance should correspond to higher-quality images.
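
Concretely, the quantity being tracked is just the negated critic objective. A minimal sketch (placeholder names, not the repo's code):

import tensorflow as tf

def em_estimate(critic, x_real, x_fake):
    # Negative of the critic's training loss: an estimate (up to a
    # constant scale) of the Earth Mover's distance between the real
    # and generated distributions.
    return (tf.reduce_mean(critic(x_real))
            - tf.reduce_mean(critic(x_fake)))

Plotting this value over training gives the convergence curve discussed below.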

This I have found to be accurate. At first, the critic loss goes down as the critic learns to tell the real examples apart from the fake ones. Then the generator starts catching up and forces the critic's loss to steadily climb until convergence.

In the case above (when I add batch norm and the model doesn't train well), the Earth Mover's distance indeed gets stuck at a high plateau.

Feedback

Send feedback to Volodymyr Kuleshov.
