
Discriminator Gradient flow (DGflow)

Preprint | License: MIT | Venue: ICLR 2021


Fig 1. An illustration of refinement using DGflow, with the gradient flow in the 2-Wasserstein space (top) and the corresponding discretized SDE in the latent space (bottom).


This repository contains code for reproducing the experiments presented in the ICLR 2021 paper Refining Deep Generative Models via Discriminator Gradient Flow. In this paper, we propose DGflow, a technique that improves samples from deep generative models using the gradient flow of entropy-regularized f-divergences between the generated and real data distributions. This gradient flow has an equivalent Stochastic Differential Equation (SDE), which can be simulated using the Euler-Maruyama method; we simulate the SDE in the latent space of the generative model.
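Assuming the SDE has the standard Langevin-type form dx_t = −∇_x f′(ρ_t/μ)(x_t) dt + √(2γ) dw_t, with ρ_t the density of the generated samples, μ the data density, and w_t a Wiener process, one Euler-Maruyama step of size η taken in the latent space of a generator g_θ can be written as below. This is a reconstruction for orientation, not a verbatim copy of the paper's equation; η and γ are assumed to correspond to `eta` (step size) and `gamma` (noise regularizer) in the config file.

```latex
\mathbf{z}_{k+1} = \mathbf{z}_k
  - \eta \, \nabla_{\mathbf{z}_k} f'\!\left( \frac{\rho_k}{\mu}\big( g_\theta(\mathbf{z}_k) \big) \right)
  + \sqrt{2 \gamma \eta} \; \boldsymbol{\xi}_k,
  \qquad \boldsymbol{\xi}_k \sim \mathcal{N}(\mathbf{0}, \mathbf{I})
```

Here ρ_k/μ is the density ratio term that DGflow estimates with a discriminator.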

The density ratio term in this SDE is estimated using a pretrained discriminator.
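If the discriminator's logit d(x) approximates log(μ(x)/ρ(x)), as it does for an optimal discriminator trained with the standard logistic GAN loss, then ρ/μ ≈ exp(−d(x)). For the KL choice f(u) = u log u we have f′(u) = log u + 1, so f′(ρ/μ) ≈ 1 − d(x) and the drift −η∇_z f′ becomes +η∇_z d(g_θ(z)): each step nudges the latents uphill on the discriminator logit, plus Gaussian noise. The NumPy sketch below implements that single step under these assumptions. `dgflow_kl_step` and the `grad_logit` callable are hypothetical names; the sketch treats d as fixed across steps, and in a real model the gradient would come from autodiff through the generator and discriminator rather than a closed form.

```python
import numpy as np


def dgflow_kl_step(z, grad_logit, eta, gamma, rng):
    """One Euler-Maruyama step of the latent update, KL case (sketch).

    Assumes rho/mu ~= exp(-d(x)) with d the discriminator logit, so that
    -grad_z f'(rho/mu) = +grad_z d(g(z)) for f(u) = u log u.

    z          : latent vectors, shape (batch, z_dim)
    grad_logit : hypothetical callable returning grad_z d(g(z)), same shape as z
    eta        : step size
    gamma      : noise regularizer
    rng        : numpy.random.Generator used for the Gaussian noise xi_k
    """
    drift = eta * grad_logit(z)
    noise = np.sqrt(2.0 * gamma * eta) * rng.standard_normal(z.shape)
    return z + drift + noise


# Toy usage: identity generator and a linear logit d(x) = m . x, whose
# gradient with respect to z is simply m for every sample.
m = np.array([1.0, -2.0])
rng = np.random.default_rng(0)
z = rng.standard_normal((1000, 2))
for _ in range(25):  # `steps: 25` in the example config
    z = dgflow_kl_step(z, lambda zz: np.broadcast_to(m, zz.shape), eta=0.1, gamma=0.01, rng=rng)
```

With a constant logit gradient, the batch mean simply drifts by η·m per step; with a real discriminator the gradient varies per sample and steers each latent individually.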

We further propose a technique to extend DGflow to deep generative models that lack a corresponding discriminator, such as VAEs, Normalizing Flows, and GANs with vector-valued critics. For such models, we use a pretrained discriminator trained on the same dataset as the generative model, combined with a density ratio corrector that corrects the resulting density ratio estimate. Please see the paper for more details and empirical results.
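To illustrate how a corrector can fix the estimate, suppose (an assumption for illustration, not a statement of the paper's exact parameterization) that the pretrained discriminator's logit estimates log(p_data/p_ref), where p_ref is the distribution of the generator it was trained against, and that the corrector's output estimates log(p_ref/p_g) for the base model p_g. Because density ratios multiply, the log-ratios add, giving an estimate of log(p_data/p_g). The standard-library sketch below checks this identity on 1-D Gaussians; all function names are hypothetical.

```python
import math


def gaussian_logpdf(x, mean, std=1.0):
    """Log-density of a 1-D Gaussian N(mean, std^2) at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def corrected_log_ratio(x, disc_logit, corr_logit):
    """log(p_data/p_g)(x) = log(p_data/p_ref)(x) + log(p_ref/p_g)(x)."""
    return disc_logit(x) + corr_logit(x)


# Toy densities: data N(2, 1), reference generator N(1, 1), base generator N(0, 1).
def disc_logit(x):  # stands in for the pretrained discriminator
    return gaussian_logpdf(x, 2.0) - gaussian_logpdf(x, 1.0)


def corr_logit(x):  # stands in for the density ratio corrector
    return gaussian_logpdf(x, 1.0) - gaussian_logpdf(x, 0.0)


for x in (-1.0, 0.0, 3.0):
    # The exact log-ratio of N(2, 1) to N(0, 1) is 2x - 2.
    print(x, corrected_log_ratio(x, disc_logit, corr_logit), 2.0 * x - 2.0)
```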


Fig 2. Improvement in the quality of samples generated from the base model (leftmost columns) over the steps of DGflow for the CIFAR10 (left) and STL10 (right) datasets.


Environment Setup

  • Install the required libraries:
$ pip install -r requirements.txt
  • Download the datasets and the Inception model, then extract the Inception features:
$ python download_dataset.py --root torchdata/ --data cifar10  # Download CIFAR10
$ python download_dataset.py --root torchdata/ --data stl10  # Download STL10
$ python scale_stl10.py --gpu 0  # Scale STL10 to 48 x 48 and 32 x 32
$ mkdir metric
$ python download_inception.py --outfile metric/inception_score.model    # Download inception model
$ python extract_inception_feats.py --data CIFAR --gpu 0  # Extract inception features for CIFAR10
$ python extract_inception_feats.py --data STL48 --gpu 0  # Extract inception features for STL10 48 x 48
$ python extract_inception_feats.py --data STL32 --gpu 0  # Extract inception features for STL10 32 x 32
  • Download pretrained SNGAN models from [Tanaka, 2019]'s official repository.
  • Download the pretrained MMDGAN, OCFGAN, and VAE generators and the density ratio correctors from this repository's Releases.
  • Move the pretrained models to ./trained_models/.
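A small pre-flight check can confirm the setup before running DGflow. The sketch below uses a hypothetical helper and only the Python standard library, and checks just the paths named in the steps above: the `torchdata/` dataset root, the Inception model at `metric/inception_score.model`, and the `./trained_models/` directory. The exact files written by the download and extraction scripts are not listed here, so they are not checked.

```python
from pathlib import Path

# Paths named in the setup steps; individual dataset and feature files
# produced by the scripts are not documented, so they are not checked.
EXPECTED_PATHS = ["torchdata", "metric/inception_score.model", "trained_models"]


def missing_setup_paths(root="."):
    """Return the expected setup paths that do not exist under `root`."""
    root = Path(root)
    return [p for p in EXPECTED_PATHS if not (root / p).exists()]


if __name__ == "__main__":
    missing = missing_setup_paths()
    print("Setup looks complete." if not missing else f"Missing: {missing}")
```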

Usage

To refine samples using DGflow, run:

python dgflow.py --config <path to config file>

Example for SN-DCGAN (hinge):

python dgflow.py --config configs/sngan-hi.yml

The results will be saved in ./exps/.
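To refine with several configurations in turn, the same command can be generated per config file. This is a hypothetical convenience sketch: only `configs/sngan-hi.yml` is named in this README, so the glob over `configs/*.yml` assumes the other configs live alongside it.

```python
import os
import shlex
from pathlib import Path


def dgflow_commands(config_paths):
    """Build one `python dgflow.py --config <file>` command per config file."""
    return [["python", "dgflow.py", "--config", os.fspath(p)] for p in config_paths]


# Dry run: print the command for every YAML file under configs/.
for cmd in dgflow_commands(sorted(Path("configs").glob("*.yml"))):
    print(shlex.join(cmd))
```

To actually run them one after another, each command list can be passed to `subprocess.run(cmd, check=True)`, which stops at the first failing run.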

Example of a Config File

dataset: cifar10
image_size: 32
eval_file_prefix: CIFAR
gen_type: t_cnn # Type of generator
gen_path: trained_models/OCFGAN_CIFAR.pth # Path to generator checkpoint
disc_type: c_sndcgan # Type of discriminator
disc_path: trained_models/DCGAN_D_CIFAR_SNGAN_NonSaturating_150001.npz # Path to discriminator checkpoint
corr_type: t_sndcgan # Type of density ratio corrector
corr_path: trained_models/OCFGAN_CIFAR_DRC.pth # Path to density ratio corrector checkpoint
f_div: KL # One of 'KL', 'logD', or 'JS'
eta: 0.1 # Step-Size
gamma: 0.01 # Noise regularizer
save_interval: 5 # Save samples every save_interval steps
keep_samples: true # Keep samples on disk after execution
steps: 25 # Number of update steps
bottom_width: 4 # Spatial width of the generator's initial feature map
num_imgs: 50000 # Number of images to generate
batch_size: 500
exp_root: ./exps/
z_dim: 32 # Dimension of the prior distribution
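Before a long run, it can help to sanity-check a config and derive a few quantities from it. The sketch below is hypothetical (it is not part of `dgflow.py`): it validates `f_div` against the three options listed in the example, computes how many batches of `batch_size` are needed for `num_imgs` samples, and reports the per-step noise scale √(2·gamma·eta), which assumes a standard Euler-Maruyama discretization of the SDE.

```python
import math

VALID_F_DIVS = {"KL", "logD", "JS"}


def summarize_config(cfg):
    """Sanity-check a DGflow config dict and derive a few quantities (sketch)."""
    if cfg["f_div"] not in VALID_F_DIVS:
        raise ValueError(f"f_div must be one of {sorted(VALID_F_DIVS)}, got {cfg['f_div']!r}")
    if cfg["eta"] <= 0 or cfg["gamma"] < 0:
        raise ValueError("eta must be positive and gamma non-negative")
    return {
        # Batches needed to produce num_imgs samples.
        "num_batches": math.ceil(cfg["num_imgs"] / cfg["batch_size"]),
        # Per-step noise scale sqrt(2 * gamma * eta), assuming a standard
        # Euler-Maruyama discretization of the SDE.
        "noise_std": math.sqrt(2.0 * cfg["gamma"] * cfg["eta"]),
    }


# Values copied from the example config above.
example = {"f_div": "KL", "eta": 0.1, "gamma": 0.01, "num_imgs": 50000, "batch_size": 500}
print(summarize_config(example))  # num_batches = 100, noise_std ≈ 0.0447
```

With PyYAML installed, a real config can be loaded with `yaml.safe_load` (for example, `with open("configs/sngan-hi.yml") as fh: cfg = yaml.safe_load(fh)`) and passed in directly, since the key names match the example.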

Questions

For any questions regarding the code or the paper, please email Abdul Fatir.

BibTeX

If you find this repository or the ideas presented in our paper useful for your research, please consider citing our paper.

@inproceedings{
ansari2021refining,
title={Refining Deep Generative Models via Discriminator Gradient Flow},
author={Abdul Fatir Ansari and Ming Liang Ang and Harold Soh},
booktitle={International Conference on Learning Representations},
year={2021},
url={https://openreview.net/forum?id=Zbc-ue9p_rE}
}

Acknowledgement

This repository contains code based on the following repositories:

  • AkinoriTanaka-phys/DOT: Copyright (c) 2019 AkinoriTanaka-phys, MIT License
  • pfnet-research/sngan_projection: Copyright (c) 2018 Preferred Networks, Inc., MIT License
  • pfnet-research/chainer-gan-lib: Copyright (c) 2017 pfnet-research, MIT License
  • xjtuygao/VGrow: Copyright (c) 2019 xjtuygao, License: NA

References

[Miyato et al., 2018] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, and Yuichi Yoshida. Spectral normalization for generative adversarial networks. In ICLR, 2018.
[Tanaka, 2019] Akinori Tanaka. Discriminator optimal transport. In NeurIPS, 2019.
[Ansari et al., 2020] Abdul Fatir Ansari, Jonathan Scarlett, and Harold Soh. A characteristic function approach to deep implicit generative modeling. In CVPR, 2020.
