
Prediction Optimizer (to stabilize GAN training)

Introduction

This is a PyTorch implementation of the 'prediction method' introduced in the following paper:

  • Abhay Yadav et al., Stabilizing Adversarial Nets with Prediction Methods, ICLR 2018, Link
  • (Just for clarification, I'm not an author of the paper.)

The authors proposed a simple but effective method to stabilize GAN training. With this Prediction Optimizer, you can easily apply the method to your existing GAN code. This implementation is compatible with most PyTorch optimizers and network structures. (Please let me know if you run into any issues using it.)
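
At a high level, while one network is being updated, its opponent is evaluated at an extrapolated ('predicted') point rather than at its current parameters. Following the paper, the predicted parameters after an update are:

    theta_pred = theta_new + step * (theta_new - theta_old)

where theta_old denotes the parameters before the last optimizer update, and step defaults to 1.0.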

How-to-use

Instructions

  • Import prediction.py
    • from prediction import PredOpt
  • Initialize just like an optimizer
    • pred = PredOpt(net.parameters())
  • Run the model in a 'with' block to get results from a model with predicted params.
    • With the 'step' argument, you can control the lookahead step size (1.0 by default)
    • with pred.lookahead(step=1.0):
          output = net(input)
  • Call step() after each update of the network parameters (a minimal sketch of what this does internally follows this list)
    • optim_net.step()
      pred.step()
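
For intuition, here is a minimal sketch of what a PredOpt-like helper could look like internally. This is an illustrative assumption, not the repository's actual code (see prediction.py for the real implementation); it only shows the bookkeeping behind lookahead() and step():

    import contextlib
    import torch

    class SimplePredOpt:
        def __init__(self, params):
            self.params = list(params)
            self.prev = [p.detach().clone() for p in self.params]
            self.diff = [torch.zeros_like(p) for p in self.params]

        def step(self):
            # Call right after optimizer.step(): remember the last update direction.
            with torch.no_grad():
                for p, prev, d in zip(self.params, self.prev, self.diff):
                    d.copy_(p - prev)
                    prev.copy_(p)

        @contextlib.contextmanager
        def lookahead(self, step=1.0):
            # Temporarily move to the predicted point: theta + step * (theta - theta_prev)
            with torch.no_grad():
                for p, d in zip(self.params, self.diff):
                    p.add_(step * d)
            try:
                yield
            finally:
                # Restore the original (non-predicted) parameters on exit.
                with torch.no_grad():
                    for p, d in zip(self.params, self.diff):
                        p.sub_(step * d)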

Samples

  • You can find sample code in this repository (example_gan.py)
  • A sample snippet
  • import torch.optim as optim
    from prediction import PredOpt
    
    
    # ...
    
    optim_G = optim.Adam(netG.parameters(), lr=0.01)
    optim_D = optim.Adam(netD.parameters(), lr=0.01)
    
    pred_G = PredOpt(netG.parameters())             # Create a prediction optimizer with target parameters
    pred_D = PredOpt(netD.parameters())
    
    
    for i, data in enumerate(dataloader, 0):
        # (1) Training D with samples from predicted generator
        with pred_G.lookahead(step=1.0):            # in the 'with' block, the model works as a 'predicted' model
            fake_predicted = netG(Z)                           
        
            # Compute gradients and loss 
        
            optim_D.step()
            pred_D.step()
        
        # (2) Training G
    with pred_D.lookahead(step=1.0):            # 'Predicted D'
        fake = netG(Z)                          # Draw samples from the real model (not the predicted one)
            D_outs = netD(fake)
    
            # Compute gradients and loss
    
            optim_G.step()
            pred_G.step()                           # You should call PredOpt.step() after each update
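
The '# Compute gradients and loss' lines in the snippet are intentionally elided. As a rough illustration only (not code from this repository), the D update inside the first 'with' block could be expanded as below, assuming a standard DCGAN-style BCE loss with a sigmoid output on D; real_images, criterion, and the label tensors are hypothetical names:

    criterion = torch.nn.BCELoss()

    netD.zero_grad()
    d_real = netD(real_images)                       # real_images: a batch from the dataloader (assumed)
    loss_real = criterion(d_real, torch.ones_like(d_real))
    d_fake = netD(fake_predicted.detach())           # detach so gradients don't flow into G here
    loss_fake = criterion(d_fake, torch.zeros_like(d_fake))
    (loss_real + loss_fake).backward()

    optim_D.step()
    pred_D.step()                                    # record D's update for its next prediction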

Output samples

You can find more images in the following issues (reproduced below).

Training w/ large learning rate (0.01)

Vanilla DCGAN vs. DCGAN w/ prediction (step=1.0)
[Images: CIFAR-10 and CelebA samples after 25 epochs]

Training w/ medium learning rate (1e-4)

Vanilla DCGAN vs. DCGAN w/ prediction (step=1.0)
[Images: CIFAR-10 and CelebA samples after 25 epochs]

Training w/ small learning rate (1e-5)

Vanilla DCGAN vs. DCGAN w/ prediction (step=1.0)
[Images: CIFAR-10 and CelebA samples after 25 epochs]

External links

TODOs

  • Implement as an optimizer
  • Support pip install
  • Add some experimental results


Issues

CelebA experiments

Notes

  • This experiment is only meant to show sample outputs. Different random seeds can lead to totally different outcomes, so repeated trials would be needed to compare the two GAN methods properly.
  • For faster training, I used only 50k images from CelebA (resized to 64x64).
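
As a rough sketch of that preprocessing (the dataset path and the way the 50k subset is taken are assumptions for illustration, not taken from this repository):

    import torchvision.transforms as transforms
    from torchvision.datasets import ImageFolder
    from torch.utils.data import DataLoader, Subset

    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    dataset = ImageFolder(root='data/celeba', transform=transform)  # path is an assumption
    subset = Subset(dataset, range(50000))                          # first 50k images
    dataloader = DataLoader(subset, batch_size=64, shuffle=True)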

Large learning rate (0.01)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Medium learning rate (0.0001)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Small learning rate (1e-5)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

CIFAR-10 experiments with different LRs (PredictionOpt only on the generator)

Notes

  • This experiment is only meant to show sample outputs. Different random seeds can lead to totally different outcomes, so repeated trials would be needed to compare the two GAN methods properly.
  • For these results, the prediction method has been applied only to G (a sketch of this setup follows below).
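
Concretely, applying the prediction step only to G amounts to keeping a single PredOpt for the generator and dropping the pred_D bookkeeping. A minimal sketch under that assumption (not taken from example_gan.py):

    pred_G = PredOpt(netG.parameters())    # no PredOpt for D

    # (1) D update: train against the predicted generator
    with pred_G.lookahead(step=1.0):
        fake_predicted = netG(Z)
        # ... compute D loss and gradients ...
        optim_D.step()

    # (2) G update: train against the current (non-predicted) D
    fake = netG(Z)
    D_outs = netD(fake)
    # ... compute G loss and gradients ...
    optim_G.step()
    pred_G.step()                          # record G's update for the next prediction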

Large learning rate (0.01)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Medium learning rate (0.0001)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Small learning rate (1e-5)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

CIFAR-10 experiments with different LRs

Notes

  • This experiment is only meant to show sample outputs. Different random seeds can lead to totally different outcomes, so repeated trials would be needed to compare the two GAN methods properly.

Large learning rate (0.01)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Medium learning rate (0.0001)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Small learning rate (1e-5)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

CelebA experiments (PredictionOpt only on the generator)

Notes

  • This experiment is only meant to show sample outputs. Different random seeds can lead to totally different outcomes, so repeated trials would be needed to compare the two GAN methods properly.
  • For faster training, I used only 50k images from CelebA (resized to 64x64).
  • For these results, the prediction method has been applied only to G.

Large learning rate (0.01)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Medium learning rate (0.0001)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]

Small learning rate (1e-5)

Vanilla DCGAN

[Images: samples after 2, 10, and 25 epochs]

DCGAN w/ prediction

[Images: samples after 2, 10, and 25 epochs]
