
variational-inequality-gan

This is the code associated with the paper A Variational Inequality Perspective on Generative Adversarial Networks. If you find this code useful, please cite it as:

Gauthier Gidel, Hugo Berard, Gaetan Vignoud, Pascal Vincent, Simon Lacoste-Julien, A Variational Inequality Perspective on Generative Adversarial Networks, International Conference on Learning Representations (ICLR 2019), 2019.

The code was mostly written by Hugo Berard while at Facebook AI Research. For any questions about the code, please contact ([email protected]).

Requirements

The code is written in PyTorch and was tested with:

  • pytorch=0.4.0

(Optional) The inception score is computed using the original implementation from OpenAI, and therefore requires TensorFlow <= 1.5.0 to be installed.

A conda environment is also provided (requires CUDA 9): conda env create -f environment.yml

class Extragradient

The extragradient method is packaged as a torch.optim.Optimizer with an additional method, extrapolation(). Two variants are available: ExtraSGD and ExtraAdam.
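The following minimal pure-Python sketch illustrates what extrapolation() and step() do for plain SGD updates. It is not the repository's ExtraSGD. The class name ExtraSGDSketch and its grad_fn argument are hypothetical, the parameters are plain floats rather than tensors, and gradients come from grad_fn instead of loss.backward(). extrapolation() saves the current parameters and takes a look-ahead step. step() evaluates the gradient at the look-ahead point and applies it to the saved parameters.

```python
from typing import Callable, List


class ExtraSGDSketch:
    """Hypothetical stand-in for ExtraSGD: plain floats instead of tensors.

    grad_fn maps a list of parameters to the list of their gradients; in the
    real optimizer the gradients come from loss.backward() instead.
    """

    def __init__(self, params: List[float],
                 grad_fn: Callable[[List[float]], List[float]], lr: float):
        self.params = list(params)
        self.grad_fn = grad_fn
        self.lr = lr
        self.params_copy: List[float] = []  # point saved before extrapolating

    def extrapolation(self) -> None:
        # Save the current point (only once, even if extrapolation() is called
        # several times in a row), then take a look-ahead gradient step.
        if not self.params_copy:
            self.params_copy = list(self.params)
        grads = self.grad_fn(self.params)
        self.params = [p - self.lr * g for p, g in zip(self.params, grads)]

    def step(self) -> None:
        if not self.params_copy:
            raise RuntimeError("call extrapolation() before step()")
        # Gradient evaluated at the extrapolated (look-ahead) point ...
        grads = self.grad_fn(self.params)
        # ... but applied to the parameters saved before extrapolation.
        self.params = [p - self.lr * g
                       for p, g in zip(self.params_copy, grads)]
        self.params_copy = []
```

For example, take f(x) = x² with lr=0.1, starting at x = 1. Extrapolation moves x to 0.8. The subsequent step() then uses the gradient at 0.8 and moves the saved point 1 to 0.84.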

Example of how to run Extragradient:

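# optimizer: an ExtraSGD or ExtraAdam instance built on model.parameters()
for i, (input, target) in enumerate(dataset):
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    if i % 2 == 0:
        # Even iterations: extrapolation (look-ahead) step
        optimizer.extrapolation()
    else:
        # Odd iterations: update starting from the parameters saved before extrapolation
        optimizer.step()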
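The bilinear game min_x max_y x·y shows why the two calls alternate. On this game, simultaneous gradient descent spirals away from the equilibrium (0, 0), while extragradient converges. The pure-Python sketch below compares the two from the same starting point. Its function names are hypothetical, and it uses exact gradients rather than minibatches.

```python
from typing import Tuple


def field(x: float, y: float) -> Tuple[float, float]:
    # Vector field of min_x max_y x*y: (d/dx of x*y, -(d/dy of x*y)).
    return y, -x


def simultaneous_gd(x: float, y: float, lr: float,
                    n_steps: int) -> Tuple[float, float]:
    for _ in range(n_steps):
        gx, gy = field(x, y)
        x, y = x - lr * gx, y - lr * gy
    return x, y


def extragradient(x: float, y: float, lr: float,
                  n_steps: int) -> Tuple[float, float]:
    for _ in range(n_steps):
        # Extrapolation: look-ahead step from the current point.
        gx, gy = field(x, y)
        x_half, y_half = x - lr * gx, y - lr * gy
        # Update: gradient at the look-ahead point, applied to the current point.
        gx, gy = field(x_half, y_half)
        x, y = x - lr * gx, y - lr * gy
    return x, y
```

On this game, each simultaneous GD step multiplies the distance to (0, 0) by sqrt(1 + lr^2) > 1. Each extragradient step multiplies it by sqrt(1 - lr^2 + lr^4), which is below 1 for 0 < lr < 1.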

Example of how to run Extragradient from the past:

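# optimizer: an ExtraSGD or ExtraAdam instance built on model.parameters()
for input, target in dataset:
    # Extrapolate with the gradient left over from the previous iteration
    optimizer.extrapolation()
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()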

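Note that for extragradient from the past, the extrapolation is performed before the forward and backward passes. It reuses the gradient computed at the previous iteration, which is still stored because zero_grad() is only called afterwards.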
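The bilinear game min_x max_y x·y also illustrates extragradient from the past. The extrapolation reuses the gradient computed at the previous extrapolated point, so each iteration needs only one fresh gradient evaluation instead of two. The minimal pure-Python sketch below is a hypothetical illustration with hypothetical names. Its stored gradient starts at zero, which plays the role of the missing gradient on the first iteration.

```python
from typing import Tuple


def field(x: float, y: float) -> Tuple[float, float]:
    # Vector field of min_x max_y x*y: (d/dx of x*y, -(d/dy of x*y)).
    return y, -x


def past_extragradient(x: float, y: float, lr: float,
                       n_steps: int) -> Tuple[float, float, int]:
    # Gradient from the previous extrapolated point; it starts at zero, so the
    # first extrapolation leaves the parameters unchanged.
    gx_prev, gy_prev = 0.0, 0.0
    n_grad_evals = 0
    for _ in range(n_steps):
        # Extrapolation with the stored gradient: no new evaluation needed.
        x_half, y_half = x - lr * gx_prev, y - lr * gy_prev
        # The only fresh gradient of the iteration, at the extrapolated point.
        gx_prev, gy_prev = field(x_half, y_half)
        n_grad_evals += 1
        # Update from the current (non-extrapolated) point.
        x, y = x - lr * gx_prev, y - lr * gy_prev
    return x, y, n_grad_evals
```

With lr = 0.1 and 200 iterations from (1, 1), the iterate ends closer to (0, 0) than it started, while evaluating the field only once per iteration.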

Experiments

To run the WGAN-GP experiment with ExtraAdam and the ResNet architecture on CIFAR10 using the hyperparameters from the paper: python train_extraadam.py results/ --default --model resnet --cuda

The --default option loads the hyperparameters used in the paper for each experiment; they are available as JSON files in the config folder.
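The argparse sketch below illustrates how a flag like --default can merge JSON hyperparameters into the parsed arguments. It is not the repository's train_extraadam.py. The config file layout (one <model>.json per model), the --learning-rate option, the default model value, and the JSON keys are all assumptions.

```python
import argparse
import json
import os
from typing import List


def parse_args(argv: List[str], config_dir: str) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("output")                     # e.g. results/
    parser.add_argument("--default", action="store_true")
    parser.add_argument("--model", default="resnet")  # default is an assumption
    parser.add_argument("--cuda", action="store_true")
    # Hypothetical hyperparameter; the real option names may differ.
    parser.add_argument("--learning-rate", type=float, default=1e-4)
    args = parser.parse_args(argv)
    if args.default:
        # Assumed layout: one JSON file per model, e.g. <config_dir>/resnet.json.
        path = os.path.join(config_dir, args.model + ".json")
        with open(path) as f:
            # Values from the JSON file override the command-line values.
            for key, value in json.load(f).items():
                setattr(args, key, value)
    return args
```

With this design, the JSON file is the single source of truth for a paper experiment, while the command-line defaults still apply when --default is omitted.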

The weights of our ResNet WGAN-GP model trained with ExtraAdam are available in the results folder.

For evaluation: python eval_inception_score.py results/ExtraAdam/best_model.state
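The inception score is exp(E_x[KL(p(y|x) || p(y))]). Here p(y|x) is the class distribution a pretrained Inception network assigns to a generated image x, and p(y) is its average over the samples. The sketch below only evaluates this formula on given probabilities, and its function name is hypothetical. It does not reproduce the OpenAI implementation, which also runs the Inception network and averages the score over several splits of the samples.

```python
import math
from typing import List


def inception_score(probs: List[List[float]]) -> float:
    """Inception score from per-sample class probabilities p(y|x)."""
    n = len(probs)
    n_classes = len(probs[0])
    # Marginal p(y): average of p(y|x) over the samples.
    marginal = [sum(row[k] for row in probs) / n for k in range(n_classes)]
    mean_kl = 0.0
    for row in probs:
        # KL(p(y|x) || p(y)); terms with p(y|x) = 0 contribute nothing, and
        # p(y|x) > 0 implies p(y) > 0, so the logarithm is always defined.
        mean_kl += sum(p * math.log(p / q)
                       for p, q in zip(row, marginal) if p > 0) / n
    return math.exp(mean_kl)
```

A score of 1 means every image receives the same class distribution. Conversely, n confidently classified images spread evenly over n classes give a score of n.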

An IPython notebook for the bilinear example is also available.

Results

With averaging:

[Figure: AvgExtraAdam samples on CIFAR10 for ResNet WGAN-GP]

Without averaging:

[Figure: ExtraAdam samples on CIFAR10 for ResNet WGAN-GP]
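Averaging here refers to averaging the parameters across training iterations and generating samples from the averaged model. The sketch below shows uniform averaging with a running mean, which is one common choice. The helper class is hypothetical, and it is an assumption that AvgExtraAdam uses exactly this scheme.

```python
from typing import List


class UniformAverage:
    """Hypothetical helper: running uniform average of parameter vectors."""

    def __init__(self) -> None:
        self.n = 0
        self.avg: List[float] = []

    def update(self, params: List[float]) -> None:
        self.n += 1
        if self.n == 1:
            self.avg = list(params)
        else:
            # Incremental mean: avg_n = avg_{n-1} + (theta_n - avg_{n-1}) / n
            self.avg = [a + (p - a) / self.n
                        for a, p in zip(self.avg, params)]
```

Training continues with the non-averaged parameters. The averaged copy is only read when generating samples or evaluating.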
