srvae's Introduction

VAE and Super-Resolution VAE in PyTorch

Python 3.6 · PyTorch 1.3 · MIT license

Code release for Super-Resolution Variational Auto-Encoders

Abstract

The framework of Variational Auto-Encoders (VAEs) provides a principled manner of reasoning in latent-variable models using variational inference. However, the main drawback of this approach is the blurriness of generated images. Some studies link this effect to the objective function, namely, the (negative) log-likelihood function. Here, we propose to enhance VAEs by adding a random variable that is a downscaled version of the original image, while still using the log-likelihood function as the learning objective. Further, we provide the downscaled image as an input to the decoder and use it in a manner similar to super-resolution. We show empirically that the proposed approach performs comparably to VAEs in terms of the negative log-likelihood function, but obtains a better FID score.

Features

  • Models

    • VAE
    • Super-resolution VAE (srVAE)
  • Priors

    • Standard (unimodal) Gaussian
    • Mixture of Gaussians
    • RealNVP
  • Reconstruction Loss

    • Discretized Mixture of Logistics Loss
  • Neural Networks

    • DenseNet
  • Datasets

    • CIFAR-10

Quantitative results

Model    nll
VAE      3.51
srVAE    3.65

Results on CIFAR-10. The negative log-likelihood (nll) was estimated using 500 weighted samples on the test set (10k images).
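As a rough illustration of how such an estimate is commonly computed, the sketch below averages sample weights in log space. It assumes the "weighted samples" are importance weights of the form log p(x, z_k) - log q(z_k | x); the function names are hypothetical and not taken from this repository, and any normalization of the reported numbers (for example per dimension) is not covered here.

```python
import math


def log_mean_exp(log_weights):
    """Numerically stable log of the mean of exp(log_weights)."""
    m = max(log_weights)
    total = sum(math.exp(lw - m) for lw in log_weights)
    return m + math.log(total) - math.log(len(log_weights))


def importance_weighted_nll(log_weights):
    """Estimate the negative log-likelihood of one image from K log-weights.

    Assumption: each entry is log p(x, z_k) - log q(z_k | x) with z_k ~ q(z | x).
    """
    return -log_mean_exp(log_weights)


# Two samples with weights 0.25 and 0.75 average to 0.5, so the
# estimate is approximately log(2).
print(importance_weighted_nll([math.log(0.25), math.log(0.75)]))
```

Subtracting the maximum before exponentiating keeps the sum from underflowing when the log-weights are large and negative, which is typical for image likelihoods.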

Qualitative results

VAE

Results from the VAE with a RealNVP prior trained on CIFAR-10.

Interpolations

Reconstructions.

Unconditional generations.

Super-Resolution VAE

Results from the Super-Resolution VAE trained on CIFAR-10.

Interpolations

Super-Resolution results of the srVAE on CIFAR-10

Unconditional generations. Left: the generations of the first step, the compressed representations that capture the global structure. Right: the final result after enhancing the images with local content.
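The two-step generation described in the caption can be sketched as follows. This is only an illustration of the sampling order, not the repository's implementation: all four callables and the toy stand-ins are hypothetical, and the exact conditioning of the second latent is an assumption.

```python
import random


def generate(sample_u, decode_y, sample_z, decode_x):
    """Two-step ancestral sampling: compressed image first, then the full image."""
    u = sample_u()        # latent for the compressed image
    y = decode_y(u)       # step 1: low-resolution image (global structure)
    z = sample_z(y)       # latent for local detail (assumed to depend on y)
    x = decode_x(y, z)    # step 2: full-resolution image conditioned on y
    return y, x


# Toy stand-ins on 1-D "images" so the sketch runs without any model.
rng = random.Random(0)


def toy_sample_u():
    return [rng.gauss(0.0, 1.0) for _ in range(2)]


def toy_decode_y(u):
    return list(u)


def toy_sample_z(y):
    return [rng.gauss(0.0, 0.1) for _ in range(2 * len(y))]


def toy_decode_x(y, z):
    # Upsample y by a factor of 2 and add the local detail carried by z.
    return [y[i // 2] + z[i] for i in range(len(z))]


y, x = generate(toy_sample_u, toy_decode_y, toy_sample_z, toy_decode_x)
print(len(y), len(x))
```

During training the compressed image is computed from the data, whereas in this sampling path it is produced by the first step and then fed to the second.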

Requirements

The code is compatible with:

  • python 3.6
  • pytorch 1.3

Usage

  • To run the VAE with a RealNVP prior on CIFAR-10, execute:
    python main.py --model VAE --network densenet32 --prior RealNVP
  • To run the srVAE instead:
    python main.py --model srVAE --network densenet16x32 --prior RealNVP

Cite

Please cite our paper if you use this code in your own work:

@misc{gatopoulos2020superresolution,
    title={Super-resolution Variational Auto-Encoders},
    author={Ioannis Gatopoulos and Maarten Stol and Jakub M. Tomczak},
    year={2020},
    eprint={2006.05218},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}

Acknowledgements

This work was supported and funded by the University of Amsterdam and BrainCreators B.V.

Repo Author

Ioannis Gatopoulos, 2020

srvae's People

Contributors

ioangatop

srvae's Issues

Ground truth in forward pass?

Hello, first of all thank you for the code!
I have a question about the forward pass in the srVAE model:

def forward(self, x, **kwargs):
    """ Forward pass through the inference and the generative model. """
    # y ~ f(x) (deterministic)
    y = self.compressed_transoformation(x)
    # u ~ q(u| y)
    u_q_mean, u_q_logvar = self.q_u(y)
    u_q = self.reparameterize(u_q_mean, u_q_logvar)
    # z ~ q(z| x, y)
    z_q_mean, z_q_logvar = self.q_z(x)
    z_q = self.reparameterize(z_q_mean, z_q_logvar)
    # x ~ p(x| y, z)
    x_logits = self.p_x((y, z_q))
    # y ~ p(y| u)
    y_logits = self.p_y(u_q)

It looks like p_x gets the ground truth y as input. Should this not be the y that is predicted by p_y instead?

Best regards!

Scale factor has NaN entries when training my own dataset

I defined my own dataset class in src/data/datasets.py and imported it in the dataloader, but I get the following error when running:

    python3 main.py --model VAE --network densenet32 --prior RealNVP --dataset mydataset --img_resize 128

Thanks for your help!

Traceback (most recent call last):
  File "main.py", line 119, in <module>
    main()
  File "main.py", line 106, in main
    train_model(args.dataset, args.model, writer)
  File "main.py", line 21, in train_model
    model.module.initialize(train_loader)  
  File "/data3/daniel/srVAE/src/models/vae/vae.py", line 53, in initialize
    self.calculate_elbo(x, output)
  File "/data3/daniel/srVAE/src/models/vae/vae.py", line 89, in calculate_elbo
    log_p_z = self.p_z.log_p(z_q)
  File "/data3/daniel/srVAE/src/modules/priors/realnvp/model/real_nvp.py", line 56, in log_p
    z, sldj = self.forward(x, reverse=False)
  File "/data3/daniel/srVAE/src/modules/priors/realnvp/model/real_nvp.py", line 75, in forward
    x, sldj = self.flows(x, sldj, reverse)
  File "/home/daniel/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/data3/daniel/srVAE/src/modules/priors/realnvp/model/real_nvp.py", line 137, in forward
    x, sldj = coupling(x, sldj, reverse)
  File "/home/daniel/.local/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "/data3/daniel/srVAE/src/modules/priors/realnvp/model/coupling_layer.py", line 69, in forward
    raise RuntimeError('Scale factor has NaN entries')

FID score of VAE with RealNVP prior

Hi,
thanks for your great work!

May I ask what FID score you get for the model trained with python main.py --model VAE --network densenet32 --prior RealNVP, i.e. the VAE with the RealNVP prior?
According to the paper it should be 41, but I get 65 instead.

Thanks!

Training with multiple GPUs

Hi! When I run the code on 2 or 3 GPUs, the process goes into the Sl+ (sleeping) state. Do you know the reason?

There are some discrepancies between your paper and code.

class DenseNetBlock(nn.Module):
    def __init__(self, inplanes, growth_rate, drop_prob=0.0):
        super().__init__()
        self.dense_block = nn.Sequential(
            Conv2d(inplanes, 4 * growth_rate,
                   kernel_size=1, stride=1, padding=0, drop_prob=drop_prob),
            Conv2d(4 * growth_rate, growth_rate,
                   kernel_size=3, stride=1, padding=1, drop_prob=drop_prob, act=None)
        )

    def forward(self, input):
        y = self.dense_block(input)
        y = torch.cat([input, y], dim=1)
        return y

In this part, I think an ELU activation may be missing between the two Conv2d modules. Am I right?

Training larger images

Thank you for writing such excellent and convenient code! When I try to train on larger images, I get NaN errors in the 'CALayer' modules. I don't understand why the conv2d outputs NaN values. Looking forward to your answer.
