VAE

Overview of different types of autoencoders.

In this repository, we create a number of different types of autoencoders. Some of the autoencoders are written in TensorFlow, and some in PyTorch. Care has been taken to make sure that the models are easy to understand, rather than efficient or accurate. Also, most of this code does not follow good software engineering practices; this repo is not intended to be production-quality code. It is experimental software that can form the basis for rapid prototyping and experimentation.

Of special note is the fact that none of this code uses any form of regularization, batch normalization, or the like. Nor does this code contain any logic for saving models, creating checkpoints, loading from checkpoints, etc. If you wish to use any of these features, you will need to add them yourself.

Examples

The following examples are present:

| command | model | backend | data | comments |
|---|---|---|---|---|
| `python3 testTFvae.py` | `TF/VAE.py` | TensorFlow | MNIST | Both the encoder and the decoder are Dense layers. Reconstruction is based upon `sigmoid_cross_entropy_with_logits`. MNIST digits are unraveled into a 784-dimensional vector. |
| `python3 testTFcvae.py` | `TF/CVAE.py` | TensorFlow | MNIST | Conditional variational autoencoder. Both the encoder and the decoder are Dense layers. Reconstruction is based upon `sigmoid_cross_entropy_with_logits`. MNIST digits are unraveled into a 784-dimensional vector. |
| `python3 testTFcoercevae.py` | `TF/coerceVAE.py` | TensorFlow | MNIST | Coerced variational autoencoder. Both the encoder and the decoder are Dense layers. Reconstruction is based upon `sigmoid_cross_entropy_with_logits`. MNIST digits are unraveled into a 784-dimensional vector. In this variation, some coercion is applied while creating the latent space so that there is greater separation between members known to be in different groups. |
| `python3 testTFConvVAE.py` | `TF/ConvVAE.py` | TensorFlow | MNIST | Convolutional variational autoencoder. Instead of assuming that the image is based upon a flattened representation, this method uses a set of convolution layers as part of the encoder and the decoder. |
| `python3 testTemporalVAE.py` | `TF/TemporalVAE.py` | TensorFlow | Generated data | An autoencoder that looks like a Hidden Markov Model (HMM). If the number of states is very high, this might be a good method of handling the matter. Note that it is best not to use this as a VAE but as an ordinary AE. |
| `python3 testTFsigmaVae.py` | `TF/sigmaVAE.py` | TensorFlow | MNIST | The simple VAE example updated so that the loss function resembles that of the σ-VAE [1]. |
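The dense-layer rows above share the same loss structure: a per-pixel `sigmoid_cross_entropy_with_logits` reconstruction term on the 784-dimensional vector, plus a KL term on the latent Gaussian. As a rough illustration (a NumPy sketch, not the repo's code), that loss can be written as:

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid_cross_entropy_with_logits(logits, labels):
    # Numerically stable form, matching the TensorFlow op of the same name:
    # max(x, 0) - x*z + log(1 + exp(-|x|))
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

def vae_loss(x, mu, log_var, logits):
    # Reconstruction: per-pixel cross-entropy summed over the 784 dimensions.
    recon = sigmoid_cross_entropy_with_logits(logits, x).sum(axis=1)
    # KL divergence between N(mu, sigma^2) and the standard normal prior.
    kl = -0.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var), axis=1)
    return np.mean(recon + kl)

# Toy shapes: a batch of 4 flattened 28x28 "images" and a 2-D latent space.
x = rng.random((4, 784))
mu, log_var = rng.normal(size=(4, 2)), rng.normal(size=(4, 2))
eps = rng.normal(size=(4, 2))
z = mu + np.exp(0.5 * log_var) * eps   # reparameterization trick
logits = rng.normal(size=(4, 784))     # stand-in for the decoder's output
print(vae_loss(x, mu, log_var, logits))
```

In the actual models, `mu`, `log_var`, and `logits` come from the Dense encoder and decoder networks; the sketch only shows how the two loss terms combine.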

Example Results

VAE latent reconstruction with tanh activation:

(image)

VAE latent reconstruction with relu activation:

(image)

ConvVAE latent reconstruction with tanh activation:

(image)

Coerce VAE

In this case, we want to coerce the latent space so that it is easier to discriminate between the different labels when label data is available.
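One simple way to coerce separation between labeled groups is to add a hinge penalty on the pairwise distances between class centroids in latent space; the term actually used in `TF/coerceVAE.py` may differ, so treat this NumPy sketch as illustrative only:

```python
import numpy as np

def separation_penalty(z, labels, margin=2.0):
    # Hinge penalty on pairwise distances between class centroids:
    # centroids closer than `margin` get pushed apart when this term is
    # added to the VAE loss. A hypothetical coercion term, not the repo's.
    classes = np.unique(labels)
    centroids = np.stack([z[labels == c].mean(axis=0) for c in classes])
    penalty = 0.0
    for i in range(len(classes)):
        for j in range(i + 1, len(classes)):
            d = np.linalg.norm(centroids[i] - centroids[j])
            penalty += max(0.0, margin - d) ** 2
    return penalty

# Two well-separated clusters incur no penalty.
z = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
labels = np.array([0, 0, 1, 1])
print(separation_penalty(z, labels))   # 0.0
```

During training, this penalty would be scaled and added to the usual reconstruction + KL loss whenever a labeled batch is available.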

| latent space | reconstruction |
|---|---|
| (image) | (image) |

C-VAE

| 1 | 2 | 3 | 5 |
|---|---|---|---|
| (image) | (image) | (image) | (image) |

σ-VAE

The σ-VAE is a variant of the β-VAE in which β is no longer a parameter that needs to be tuned by hand, but can be learned end-to-end. This follows from the work of Rybkin et al. [1], and is supposed to yield much better reconstructions in comparison to β-VAEs. Note that all the VAEs shown above have some form of manual β-tuning performed at run-time. Compare reconstruction results from the β-VAE and the σ-VAE below:
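The core idea is to model the decoder as a Gaussian with a learned scale σ, whose negative log-likelihood automatically balances reconstruction against the KL term. As a NumPy sketch (using a single σ shared across pixels; [1] also describes per-pixel and learned-by-gradient variants, and `TF/sigmaVAE.py` may differ in detail):

```python
import numpy as np

def gaussian_nll(x, x_hat, log_sigma):
    # Per-image negative log-likelihood of a Gaussian decoder with a single
    # learned scale sigma, dropping the constant 0.5*log(2*pi) per dimension.
    return np.sum(0.5 * ((x - x_hat) / np.exp(log_sigma)) ** 2 + log_sigma, axis=-1)

# The "optimal sigma" variant of [1] sets sigma analytically to the RMS
# reconstruction error, rather than learning it by gradient descent.
x = np.array([[0.2, 0.8, 0.5, 0.1]])          # toy 4-pixel "image"
x_hat = np.array([[0.25, 0.7, 0.55, 0.2]])    # toy reconstruction
opt_log_sigma = 0.5 * np.log(np.mean((x - x_hat) ** 2))
print(gaussian_nll(x, x_hat, opt_log_sigma))
```

Because the `log_sigma` term penalizes a large σ while the squared-error term penalizes a small one, the learned σ plays the role that a hand-tuned β plays in a β-VAE.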

| VAE type | reconstruction |
|---|---|
| β-VAE | (image) |
| σ-VAE | (image) |

Requirements

The current version is written with the following configuration:

  • CudaToolkit 11.0
  • cuDNN 8.
  • TensorFlow 2.4.1
  • torch 1.8.0+cu11

The code has been tested on a GPU with the following configuration:

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  GeForce RTX 2070    Off  | 00000000:01:00.0  On |                  N/A |
|  0%   47C    P8    21W / 175W |   1456MiB /  7979MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

For some reason, the current version of TensorFlow overflows in memory usage and errors out on the RTX 2070 series. To prevent that from happening, you will need to add the following lines to your TensorFlow code.

```python
import tensorflow as tf

# Allocate GPU memory incrementally instead of grabbing it all up front.
physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
```

Authors

Sankha S. Mukherjee - Initial work (2021)

License

This project is licensed under the MIT License - see the LICENSE.txt file for details.

References

  1. Rybkin et al., "Simple and Effective VAE Training with Calibrated Decoders"
