
augmentation-enhanced-Soft-Intro-VAE

In this project we studied the effect of regularization and augmentation on the Soft-Introspective VAE (Soft-IntroVAE) presented in "Soft-IntroVAE: Analyzing and Improving the Introspective Variational Autoencoder" by Tal Daniel and Aviv Tamar, with the goal of improving its generation quality. Using differentiable augmentations on both real and fake data, which also act as a regularizer, we improved the FID score on CIFAR-10 from 4.30, as reported in the paper, to 2.96.

This repository contains implementations of the methods we used, integrated with the code from the official Soft-IntroVAE repository.

Project created by: Yuval Baruch & Dan Haramati

General Description: The recently introduced Soft Introspective Variational Autoencoder (Soft-IntroVAE) is an explicit deep generative model that combines the strengths of variational autoencoders (VAEs) and generative adversarial networks (GANs) through a variational approach to adversarial training, and it performs strongly on tasks such as density estimation and image generation. However, in adversarial training the discriminative module (the discriminator in GANs, the encoder in Soft-IntroVAE) often overfits to the training data at some point, which prevents further improvement of the generative module (the generator in GANs, the decoder in Soft-IntroVAE). In this project, we researched and implemented different data augmentation and regularization methods with the aim of improving the generation quality of Soft-IntroVAE on the CIFAR-10 dataset. Our best results came from differentiable augmentations applied to both real and fake data, which also act as a regularizer; they improved the FID score on CIFAR-10 from 4.30 to 2.96, further narrowing the gap between VAE-based models and GANs. We include theoretical and empirical justification for our method, as well as an analysis of the exact source of the improvement. In addition, we present results on downstream tasks for which our model is useful as a consequence of being trained on augmented data.


  1. Datasets
  2. Training
  3. Method Overview
  4. Results
  5. Files and Directories in the Repository
  6. Credits

Datasets

We used the CIFAR-10 dataset to optimize and evaluate our model.
For evaluation and analysis we also used MNIST, which is simpler and more structured than CIFAR-10.

Training

python main.py --help

Train with main.py using the following arguments:

| Argument | Description | Legal Values |
|---|---|---|
| -h, --help | show the argument descriptions | |
| -d, --dataset | dataset to train on | str: 'cifar10', 'mnist', 'fmnist', 'svhn', 'monsters128', 'celeb128', 'celeb256', 'celeb1024' |
| -n, --num_epochs | total number of epochs to run | int: default=250 |
| -z, --z_dim | latent dimension | int: default=128 |
| -s, --seed | random seed to use (-1 for a random seed) | int: -1, 0, 1, 2, ... |
| -v, --num_vae | number of iterations of vanilla VAE training | int: default=0 |
| -l, --lr | learning rate | float: default=2e-4 |
| -r, --beta_rec | beta coefficient for the reconstruction loss | float: default=1.0 |
| -k, --beta_kl | beta coefficient for the KL divergence | float: default=1.0 |
| -e, --beta_neg | beta coefficient for the KL divergence in the expELBO function | float: default=256.0 |
| -g, --gamma_r | coefficient for the reconstruction loss of fake data in the decoder | float: default=1e-8 |
| -b, --batch_size | batch size | int: default=32 |
| -p, --pretrained | path to a pretrained model, to continue training | str: default="None" |
| -c, --device | device: -1 for CPU, 0 and up for a specific CUDA device | int: default=-1 |
| -f, --fid | if specified, FID will be calculated during training | bool: default=False |
| --gn_real_sigma | standard deviation of the Gaussian noise added to real data | float: default=0.0 |
| --nn_sigma_enc | relative standard deviation of the in-layer Gaussian noise in the encoder | float: default=0.0 |
| --nn_sigma_dec | relative standard deviation of the in-layer Gaussian noise in the decoder | float: default=0.0 |
| --nn_gn_rel | if specified, the in-layer Gaussian noise std will be a learned parameter | bool: default=True |
| --p_enc_s | dropout probability after the first layer of the encoder | float: default=0.0 |
| --p_enc_e | dropout probability after the last layer of the encoder | float: default=0.0 |
| --p_dec_s | dropout probability after the first layer of the decoder | float: default=0.0 |
| --p_dec_e | dropout probability after the last layer of the decoder | float: default=0.0 |
| --drop_dis | if specified, dropout will be disabled in the other module during training | bool: default=False |
| --p_augment | data augmentation probability in each augmentation layer | float: default=0.0 |
| --p_ls_enc | label switch percentage per batch in encoder training | float: default=0.0 |
| --p_ls_dec | label switch percentage per batch in decoder training | float: default=0.0 |

Command with recommended hyperparameters:

CIFAR-10 - python main.py --dataset cifar10 --device 0 --lr 2e-4 --num_epochs 1000 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 128 --batch_size 32 --p_augment 0.5

MNIST - python main.py --dataset mnist --device 0 --lr 2e-4 --num_epochs 500 --beta_kl 1.0 --beta_rec 1.0 --beta_neg 256 --z_dim 32 --batch_size 128 --p_augment 0.5

Notes:

  • Our best experiment was with --p_augment 0.5, but you can experiment with the other methods as well.
  • The optimizer uses a step learning-rate scheduler with its milestone set to epoch 350.
  • During training:
    • figures of samples and reconstructions are saved locally.
    • statistics are printed (reconstruction error, KLD, expELBO).
    • at the end of each epoch, a summary of statistics will be printed.
    • FID can be calculated only for 3-channel (RGB) images.
  • Tips:
    • KL of fake/rec samples should be >= KL of real data.
    • It is usually better to choose beta_kl >= beta_rec.
    • FID calculation is slow, so leave it off if you don't need it.
  • The baseline model code supports additional datasets that we did not optimize our method for.
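The learning-rate schedule from the notes can be sketched with PyTorch's MultiStepLR. This is a minimal sketch, not the repository's code: the Adam optimizer, the separate encoder/decoder optimizers, the decay factor gamma=0.1 and the helper name make_optimizers are assumptions; only the 2e-4 learning rate and the epoch-350 milestone come from this README.

```python
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import MultiStepLR


def make_optimizers(encoder: nn.Module, decoder: nn.Module,
                    lr: float = 2e-4, milestone: int = 350, gamma: float = 0.1):
    """Hypothetical helper: one optimizer and one step scheduler per module.

    gamma=0.1 is an assumed decay factor; the README only states the milestone.
    """
    opt_e = Adam(encoder.parameters(), lr=lr)
    opt_d = Adam(decoder.parameters(), lr=lr)
    # The learning rate is multiplied by gamma once the scheduler has stepped
    # `milestone` times (i.e. after `milestone` epochs if stepped once per epoch).
    sched_e = MultiStepLR(opt_e, milestones=[milestone], gamma=gamma)
    sched_d = MultiStepLR(opt_d, milestones=[milestone], gamma=gamma)
    return opt_e, opt_d, sched_e, sched_d
```

Call sched_e.step() and sched_d.step() once at the end of every epoch, after the optimizer steps, so that epochs 0 to 349 run at 2e-4 and later epochs at 2e-5 under the assumed gamma.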

Method Overview

The code in this repository includes implementations of several adjustable regularization methods:

  • Dropout in different layers of the encoder and decoder
  • Gaussian noise added to the original data or to intermediate layers of the encoder and decoder
  • Label switching (Real <-> Fake)
  • Differentiable data augmentations

Of these, data augmentation achieved the most significant improvement in the FID score.
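Two of these regularizers can be illustrated in a few lines of PyTorch. This is a hedged sketch under our own assumptions, not the repository's implementation: RelativeGaussianNoise and switch_labels are hypothetical names; "relative" standard deviation is read as a multiple of the activations' standard deviation (cf. --nn_sigma_enc/--nn_sigma_dec, with --nn_gn_rel making the scale learnable); and label switching is read as swapping a fraction of the samples between the real and fake branches of the loss (cf. --p_ls_enc/--p_ls_dec).

```python
import torch
from torch import nn


class RelativeGaussianNoise(nn.Module):
    """Adds N(0, (sigma * std(h))^2) noise to activations h, in training mode only."""

    def __init__(self, sigma: float = 0.0, learnable: bool = False):
        super().__init__()
        init = torch.tensor(float(sigma))
        if learnable:
            self.sigma = nn.Parameter(init)  # noise scale is trained with the model
        else:
            self.register_buffer("sigma", init)  # fixed noise scale

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        if not self.training:
            return h
        # Noise scale is relative to the (detached) spread of the current activations.
        scale = self.sigma * h.detach().std()
        return h + scale * torch.randn_like(h)


def switch_labels(real: torch.Tensor, fake: torch.Tensor, p: float):
    """Swap a fraction p of the samples between the real and fake batches."""
    n = min(real.size(0), fake.size(0))
    k = int(round(p * n))
    if k == 0:
        return real, fake
    idx = torch.randperm(n, device=real.device)[:k]
    real_out, fake_out = real.clone(), fake.clone()
    # The right-hand side indexes the untouched originals, so this is a true swap.
    real_out[idx], fake_out[idx] = fake[idx], real[idx]
    return real_out, fake_out
```

The swapped batches would then be fed to the "real" and "fake" loss terms as usual, so the encoder (or decoder) sees some mislabeled samples in every batch.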

Differentiable Data Augmentation
We used kornia to implement differentiable augmentations, incorporating them as the initial layers of the encoder module. As proposed in "Differentiable augmentation for data-efficient GAN training" by Shengyu Zhao et al., we applied them to both real and fake data in all training stages, as portrayed in the following figure (applied in (i), (ii) and (iii)):
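Placing the augmentations at the front of the encoder means that every batch the encoder sees during training, real or fake, is augmented. A minimal sketch of that wiring, assuming a generic augmentation module: AugmentedEncoder is a hypothetical name, and turning augmentation off in eval mode is our own design choice rather than something stated here.

```python
import torch
from torch import nn


class AugmentedEncoder(nn.Module):
    """Hypothetical wrapper: runs `augment` as the initial layer(s) of `encoder`."""

    def __init__(self, encoder: nn.Module, augment: nn.Module):
        super().__init__()
        self.augment = augment
        self.encoder = encoder

    def forward(self, x: torch.Tensor):
        # Real images, reconstructions and generated samples all pass through
        # here, so all of them are augmented in training mode.
        if self.training:
            x = self.augment(x)
        return self.encoder(x)
```

Any differentiable nn.Module can serve as `augment`, for example a kornia.augmentation.RandomAffine configured with translate=(0.125, 0.125). Because augmentation happens inside the encoder, the reconstruction loss can still compare the decoder output with the clean input image.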

The augmentations that worked best for us, each applied with probability 0.5, were:

  • random cutout - block of up to half the size of the image
  • random translation - of up to 1/8 of the width/height of the image
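The two augmentations can also be sketched directly in PyTorch, in the spirit of DiffAugment. This is an illustration under stated assumptions, not the repository's kornia code: the function names are hypothetical, "up to half the size" is read as up to half of each side, translation zero-fills the uncovered border, and each augmentation is applied to each sample independently with probability p. Both are built from padding, slicing and masking, so gradients flow back through them (for fake data, all the way to the decoder).

```python
import torch
import torch.nn.functional as F


def random_cutout(x: torch.Tensor, max_frac: float = 0.5) -> torch.Tensor:
    """Zero one random block per image, with sides up to max_frac of the image sides."""
    n, _, h, w = x.shape
    mask = torch.ones(n, 1, h, w, dtype=x.dtype, device=x.device)
    for i in range(n):
        bh = int(torch.randint(1, max(1, int(h * max_frac)) + 1, (1,)))
        bw = int(torch.randint(1, max(1, int(w * max_frac)) + 1, (1,)))
        top = int(torch.randint(0, h - bh + 1, (1,)))
        left = int(torch.randint(0, w - bw + 1, (1,)))
        mask[i, :, top:top + bh, left:left + bw] = 0
    return x * mask  # multiplying by a constant mask keeps the op differentiable


def random_translate(x: torch.Tensor, max_frac: float = 0.125) -> torch.Tensor:
    """Shift each image by up to max_frac of its height/width, zero-filling the gap."""
    n, _, h, w = x.shape
    ph, pw = int(h * max_frac), int(w * max_frac)
    padded = F.pad(x, (pw, pw, ph, ph))  # zero padding: left, right, top, bottom
    out = []
    for i in range(n):
        top = int(torch.randint(0, 2 * ph + 1, (1,)))
        left = int(torch.randint(0, 2 * pw + 1, (1,)))
        out.append(padded[i, :, top:top + h, left:left + w])  # crop = shifted image
    return torch.stack(out)


def augment(x: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """Apply cutout, then translation, each to every sample with probability p."""
    n = x.size(0)
    # Per-sample Bernoulli(p) switches, broadcast over channels and pixels.
    use_cut = (torch.rand(n, 1, 1, 1, device=x.device) < p).to(x.dtype)
    x = use_cut * random_cutout(x) + (1 - use_cut) * x
    use_shift = (torch.rand(n, 1, 1, 1, device=x.device) < p).to(x.dtype)
    x = use_shift * random_translate(x) + (1 - use_shift) * x
    return x
```

The per-sample Python loops keep the sketch readable; a batched implementation would build index grids instead, as DiffAugment does.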

We based our choice on our own trials and on "Training generative adversarial networks with limited data" by Tero Karras et al. An important difference in our setting is that the model has access to the original images through the reconstruction loss, which makes it robust to augmentations leaking into the generated data.

We believe there were two main sources for this method's success in this setting:

  • The augmentations enrich our dataset (CIFAR-10 is a limited-data benchmark according to "Training generative adversarial networks with limited data").
  • Because both augmentations we used effectively delete blocks of pixels from the image, they also act as a regularizer, creating a better balance between the two competing modules.

Results

All results can be reproduced with the provided S_IntroVAE_Analysis.ipynb notebook by loading trained models from checkpoints and running the relevant cells.


| | Baseline | Ours | SOTA (June 2021) |
|---|---|---|---|
| FID score on CIFAR-10 | 4.30 | 2.96 | 2.10 |

Generation

We improved the Fréchet inception distance (FID) score on CIFAR-10 from 4.30 to 2.96.
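For reference, FID measures the distance between Gaussian fits of Inception-v3 features of real and generated images, FID = ||μ_r − μ_g||² + Tr(Σ_r + Σ_g − 2(Σ_r Σ_g)^(1/2)), where lower is better. The sketch below computes only this final step from precomputed feature statistics using NumPy and SciPy; it is not the repository's metrics/fid_score.py, and frechet_distance is a hypothetical name.

```python
import numpy as np
from scipy import linalg


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Fréchet distance between Gaussians N(mu1, sigma1) and N(mu2, sigma2)."""
    mu1, mu2 = np.asarray(mu1, dtype=np.float64), np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # Matrix square root of the covariance product; numerical error can leave
    # tiny imaginary parts, which are discarded.
    covmean = np.real(linalg.sqrtm(sigma1 @ sigma2))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))
```

For example, two unit-covariance Gaussians in 3 dimensions whose means differ by 1 in every coordinate are at distance 3.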

These are a few samples generated by our model:

Reconstruction

We were able to maintain good reconstruction quality even though generation improved significantly.

In-Painting
Reconstructing images with missing blocks, by encoding and then decoding them, is a downstream task our model performs well thanks to how it was trained.

Our model is much more robust to this kind of noise.
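In-painting with a trained model amounts to masking a block, encoding the damaged image and decoding the latent code. A minimal sketch, assuming an encoder that returns (mu, logvar) and a decoder that maps latent vectors back to images, as in a standard VAE; mask_block and inpaint are hypothetical helpers, and decoding the posterior mean rather than a sample is our choice.

```python
import torch


def mask_block(x: torch.Tensor, top: int, left: int, size: int) -> torch.Tensor:
    """Return a copy of x with a size x size block of pixels set to zero."""
    x = x.clone()
    x[:, :, top:top + size, left:left + size] = 0
    return x


@torch.no_grad()
def inpaint(encoder, decoder, x_masked: torch.Tensor) -> torch.Tensor:
    """Encode the damaged images and decode their posterior means."""
    mu, _logvar = encoder(x_masked)
    return decoder(mu)
```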

Interpolation in the Latent Space

Interpolation between two images in the latent space, obtained by decoding combinations of their latent codes for several values of the interpolation coefficient.
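Latent interpolation decodes the convex combinations (1 − α)·z₁ + α·z₂ of two images' latent codes for α between 0 and 1. A minimal sketch, assuming an encoder that returns (mu, logvar) and a decoder that takes a batch of latent vectors; interpolate is a hypothetical helper that uses the posterior means as z₁ and z₂.

```python
import torch


@torch.no_grad()
def interpolate(encoder, decoder, x1: torch.Tensor, x2: torch.Tensor, steps: int = 8):
    """Decode `steps` evenly spaced points between the latent codes of x1 and x2.

    x1 and x2 are single images with a batch dimension of 1.
    """
    z1, _ = encoder(x1)  # posterior mean, shape (1, latent_dim)
    z2, _ = encoder(x2)
    alphas = torch.linspace(0.0, 1.0, steps, device=z1.device).unsqueeze(1)  # (steps, 1)
    z = (1 - alphas) * z1 + alphas * z2  # broadcasts to (steps, latent_dim)
    return decoder(z)
```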

Files and Directories in the Repository

| Folder | File name | Purpose |
|---|---|---|
| main | environment.yml | environment file |
| training | main.py | main application for training augmentation-enhanced Soft-Intro VAE |
| | dataset.py | holds functions that manage the datasets |
| | train_soft_intro_vae.py | builds and trains the model |
| | metrics | folder containing the generation-quality metric files fid_score.py and inception.py |
| analysis | S_IntroVAE_Analysis.ipynb | interactive notebook for loading a model from a checkpoint and producing results |
| figures | generation_examples.jpg | images generated by our model |
| | image_inpainting_CIFAR10.jpg | example of in-painting performance on CIFAR-10 |
| | image_inpainting_MNIST.jpg | example of in-painting performance on MNIST |
| | image_interpolation_CIFAR10.jpg | interpolation between 2 images from CIFAR-10, performed in latent space |
| | image_interpolation_MNIST.jpg | interpolation between 2 images from MNIST, performed in latent space |
| | image_reconstruction_CIFAR10.jpg | examples of the reconstruction quality of our model on CIFAR-10 |
| | image_reconstruction_MNIST.jpg | examples of the reconstruction quality of our model on MNIST |
| | model_structure.jpg | structure of the model |
| | where_to_augment.jpg | different augmentation approaches in an adversarial setting |

Credits

Contributors: baruch1192, danhrmti, taldatech
