tensorflow2-generative-models's Introduction

Generative models in Tensorflow 2

Tim Sainburg (PhD Candidate, UCSD, Gentner Laboratory)

This is a small project to implement a number of generative models in Tensorflow 2. Layers and optimizers use Keras. The models are implemented for two datasets: Fashion MNIST and NSYNTH. Networks were written with the goal of being as simple and consistent as possible while still being readable. Because each network is self-contained within its notebook, they should be easy to run in a Colab session.

Included models:

Autoencoder (AE) Open In Colab

A simple autoencoder network.

Variational Autoencoder (VAE) (article) Open In Colab

The original variational autoencoder network, implemented using tensorflow_probability.
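
As a rough illustration of the latent loss a VAE adds on top of the reconstruction error, the sketch below computes the KL divergence between a diagonal Gaussian posterior and a standard normal prior, plus the reparameterized sampling step, in plain Python. The notebook itself relies on tensorflow_probability; the function names here (`kl_to_standard_normal`, `reparameterize`) are hypothetical and not taken from it.

```python
import math
import random


def kl_to_standard_normal(mu, log_var):
    """KL( N(mu, diag(exp(log_var))) || N(0, I) ) for a single latent vector."""
    return -0.5 * sum(
        1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var)
    )


def reparameterize(mu, log_var, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), one draw per dimension."""
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, log_var)
    ]


# A posterior equal to the prior carries no KL cost.
kl_at_prior = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])
# Shifting the mean by 1 in one dimension (unit variance) costs 0.5 nats.
kl_shifted = kl_to_standard_normal([1.0], [0.0])
z = reparameterize([0.0, 0.0], [0.0, 0.0])
```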

Generative Adversarial Network (GAN) (article) Open In Colab

GANs are a form of neural network in which two sub-networks (the generator and the discriminator) are trained on opposing loss functions: a generator that is trained to produce data which is indistinguishable from the true data, and a discriminator that is trained to discriminate between the true data and generated data.
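
The two opposing losses can be sketched in plain Python. This uses the common binary cross-entropy formulation with a non-saturating generator loss, which may differ from what the notebook does; the sub-networks are called generator and discriminator here, and the inputs are assumed to be discriminator outputs already squashed to probabilities in (0, 1).

```python
import math


def discriminator_loss(d_real, d_fake, eps=1e-7):
    """Binary cross-entropy: push D(x) toward 1 and D(G(z)) toward 0."""
    real_term = sum(-math.log(p + eps) for p in d_real) / len(d_real)
    fake_term = sum(-math.log(1.0 - p + eps) for p in d_fake) / len(d_fake)
    return real_term + fake_term


def generator_loss(d_fake, eps=1e-7):
    """Non-saturating generator loss: push D(G(z)) toward 1."""
    return sum(-math.log(p + eps) for p in d_fake) / len(d_fake)


# Made-up discriminator outputs: one confident discriminator, one that is guessing.
confident = discriminator_loss(d_real=[0.9, 0.8], d_fake=[0.1, 0.2])
guessing = discriminator_loss(d_real=[0.5, 0.5], d_fake=[0.5, 0.5])
```

The generator's loss falls exactly when the discriminator is fooled more often, which is what makes the two objectives adversarial.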

Wasserstein GAN with Gradient Penalty (WGAN-GP) (article) Open In Colab

WGAN-GP is a GAN that replaces the original GAN loss with the Wasserstein loss plus a gradient penalty on the discriminator, which improves training stability.

VAE-GAN (article) Open In Colab

VAE-GAN combines the VAE and GAN to autoencode over a latent representation of data in the generator to improve over the pixelwise error function used in autoencoders.

Generative adversarial interpolative autoencoder (GAIA) (article) Open In Colab

GAIA is an autoencoder trained to learn convex latent representations by adversarially training on interpolations in latent space projections of real data. This is an experimental modification of the original algorithm. For the original algorithm, see here: https://github.com/timsainb/gaia

Other Notebooks:

Seq2Seq Autoencoder (without attention) (Fashion MNIST: Open In Colab | NSYNTH: Open In Colab)

Seq2Seq models use recurrent neural network cells (like LSTMs) to better capture sequential organization in data. This implementation uses Convolutional Layers as input to the LSTM cells, and a single Bidirectional LSTM layer.

Spectrogramming, Mel Scaling, MFCCs, and Inversion in Tensorflow Open In Colab

Tensorflow has a signal processing package that allows us to generate spectrograms from waveforms as part of our dataset iterator, rather than pregenerating a second spectrogram dataset. This notebook can serve as a reference for how this is done. Spectrogram inversion is done using the Griffin-Lim algorithm.

Iterator for NSynth Open In Colab

The NSYNTH dataset is a set of thousands of musical notes saved as waveforms. To input these into a Seq2Seq model as spectrograms, I wrote a small dataset class that converts the waveforms to spectrograms in Tensorflow (using the code from the spectrogramming notebook).

tensorflow2-generative-models's Issues

Activation function for discriminator of WGAN-GP

In the network architecture of the discriminator:
tf.keras.layers.Dense(units=1, activation="sigmoid"),
You don't need an activation function on the discriminator's output, since you are using the Wasserstein loss. Using a sigmoid here would greatly limit your convergence speed.
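
To illustrate the point, here is a small plain-Python sketch (not the notebook's code). Once the scores pass through a sigmoid, the Wasserstein-style difference of means is confined to (-1, 1) and the sigmoid's slope vanishes for large scores, whereas a linear output has neither limit. The scores are made-up numbers.

```python
import math


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def mean(values):
    return sum(values) / len(values)


def score_gap(real_scores, fake_scores, squash=False):
    """Difference of mean scores, optionally after passing them through a sigmoid."""
    f = sigmoid if squash else (lambda t: t)
    return mean([f(s) for s in real_scores]) - mean([f(s) for s in fake_scores])


def sigmoid_slope(t):
    """Derivative of the sigmoid; gradients flowing through it are scaled by this."""
    s = sigmoid(t)
    return s * (1.0 - s)


real = [8.0, 10.0, 12.0]     # made-up raw outputs for real samples
fake = [-9.0, -10.0, -11.0]  # made-up raw outputs for generated samples

linear_gap = score_gap(real, fake)                 # can grow without bound
squashed_gap = score_gap(real, fake, squash=True)  # always inside (-1, 1)
slope_at_10 = sigmoid_slope(10.0)                  # tiny, so gradients shrink
```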

WGAN-GP Loss

Hi,
In the Wasserstein GAN with Gradient Penalty (WGAN-GP) code, I think there is a mistake in the disc_loss formulation.
disc_loss = (
    tf.reduce_mean(logits_x)
    - tf.reduce_mean(logits_x_gen)
    + d_regularizer * self.gradient_penalty_weight
)
According to the paper, disc_loss has to be D(x_tilde) - D(x) + ...
So the signs of the first two terms should be swapped.
I am wondering whether this makes any difference, or whether the key is simply that the two terms have opposite signs.
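
On the sign question, a small plain-Python sketch may help. It writes the loss in the paper's convention and checks that the form quoted in this issue equals that loss applied to a discriminator whose output is negated (the gradient norm is unchanged by negation). Under that reading the two conventions would be equivalent as long as the generator loss uses the matching sign; whether the notebook's generator loss does so is not shown in this issue. Function names and numbers are illustrative, and `gp_weight=10.0` is the penalty weight suggested in the paper.

```python
def mean(values):
    return sum(values) / len(values)


def gradient_penalty(grad_norms):
    """E[(||grad D(x_hat)|| - 1)^2], given precomputed gradient norms."""
    return mean([(g - 1.0) ** 2 for g in grad_norms])


def disc_loss_paper(d_real, d_fake, grad_norms, gp_weight=10.0):
    """Paper convention: E[D(x_tilde)] - E[D(x)] + lambda * penalty."""
    return mean(d_fake) - mean(d_real) + gp_weight * gradient_penalty(grad_norms)


def disc_loss_quoted(d_real, d_fake, grad_norms, gp_weight=10.0):
    """Convention quoted in this issue: E[D(x)] - E[D(x_tilde)] + lambda * penalty."""
    return mean(d_real) - mean(d_fake) + gp_weight * gradient_penalty(grad_norms)


d_real = [0.5, 1.5, 2.0]      # illustrative outputs on real data
d_fake = [-1.0, 0.0, 0.5]     # illustrative outputs on generated data
grad_norms = [0.9, 1.1, 1.3]  # illustrative gradient norms at interpolated points

# Negating the network's output flips both score terms but not the penalty.
negated_real = [-d for d in d_real]
negated_fake = [-d for d in d_fake]
same = abs(
    disc_loss_quoted(d_real, d_fake, grad_norms)
    - disc_loss_paper(negated_real, negated_fake, grad_norms)
) < 1e-12
```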

GPU not being used

*This is not an issue

Hello Tim, first of all I would like to thank you for your work.

I've noticed that the 3.0 Fashion GAN is not running on my GPU (1080 Ti).
Does it run on yours?

Kind regards.

InfoGAN in Tf2.0

This isn’t an issue but I was wondering if you could implement the InfoGAN model in tensorflow 2.0 as well.

AttributeError: 'VAEGAN' object has no attribute 'D_prop'

Hi there,

I have tried running this code and cannot get past the create-model step. I've pasted the error below. Please let me know if you need more information.

Thanks!

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[32], line 6
      3 disc_optimizer = tf.keras.optimizers.RMSprop(1e-3)
      5 # model
----> 6 model = VAEGAN(
      7     enc = encoder,
      8     dec = decoder,
      9     vae_disc_function = vaegan_discrim,
     10     lr_base_gen = 1e-3, # 
     11     lr_base_disc = 1e-4, # the discriminator's job is easier than the generators so make the learning rate lower
     12     latent_loss_div=1, # this variable will depend on your dataset - choose a number that will bring your latent loss to ~1-10
     13     sig_mult = 10, # how binary the discriminator's learning rate is shifted (we squash it with a sigmoid)
     14     recon_loss_div = .001, # this variable will depend on your dataset - choose a number that will bring your latent loss to ~1-10
     15 )

Cell In[29], line 19, in VAEGAN.__init__(self, **kwargs)
     17 self.enc_optimizer = tf.keras.optimizers.Adam(self.lr_base_gen, beta_1=0.5)
     18 self.dec_optimizer = tf.keras.optimizers.Adam(self.lr_base_gen, beta_1=0.5)
---> 19 self.disc_optimizer = tf.keras.optimizers.Adam(self.get_lr_d, beta_1=0.5)

File ~/miniconda3/envs/spatial_3_10/lib/python3.10/site-packages/keras/optimizers/optimizer_experimental/adam.py:116, in Adam.__init__(self, learning_rate, beta_1, beta_2, epsilon, amsgrad, weight_decay, clipnorm, clipvalue, global_clipnorm, use_ema, ema_momentum, ema_overwrite_frequency, jit_compile, name, **kwargs)
     86 def __init__(
     87     self,
     88     learning_rate=0.001,
   (...)
    102     **kwargs
    103 ):
    104     super().__init__(
    105         name=name,
    106         weight_decay=weight_decay,
   (...)
    114         **kwargs
    115     )
--> 116     self._learning_rate = self._build_learning_rate(learning_rate)
    117     self.beta_1 = beta_1
    118     self.beta_2 = beta_2

File ~/miniconda3/envs/spatial_3_10/lib/python3.10/site-packages/keras/optimizers/optimizer_experimental/optimizer.py:378, in _BaseOptimizer._build_learning_rate(self, learning_rate)
    370     self._current_learning_rate = tf.Variable(
    371         current_learning_rate,
    372         name="current_learning_rate",
    373         dtype=current_learning_rate.dtype,
    374         trainable=False,
    375     )
    376     return learning_rate
--> 378 return tf.Variable(
    379     learning_rate,
    380     name="learning_rate",
    381     dtype=backend.floatx(),
    382     trainable=False,
    383 )

File ~/miniconda3/envs/spatial_3_10/lib/python3.10/site-packages/tensorflow/python/util/traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
    151 except Exception as e:
    152   filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153   raise e.with_traceback(filtered_tb) from None
    154 finally:
    155   del filtered_tb

Cell In[29], line 30, in VAEGAN.get_lr_d(self)
     29 def get_lr_d(self):
---> 30     return self.lr_base_disc * self.D_prop

AttributeError: 'VAEGAN' object has no attribute 'D_prop'
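
Reading the traceback, the optimizer appears to evaluate the callable learning rate while it is being constructed (the `tf.Variable(learning_rate, ...)` call ends up in `get_lr_d`), which happens before `D_prop` has been assigned. One possible workaround, sketched here in plain Python with a stand-in optimizer rather than the real Keras class, is to assign `D_prop` before building the optimizers. The initial value `1.0` and the class names `EagerOptimizer` and `TinyVAEGAN` are assumptions for illustration only.

```python
class EagerOptimizer:
    """Stand-in for an optimizer that evaluates a callable learning rate in __init__."""

    def __init__(self, learning_rate):
        self.learning_rate = (
            learning_rate() if callable(learning_rate) else learning_rate
        )


class TinyVAEGAN:
    def __init__(self, lr_base_disc=1e-4):
        self.lr_base_disc = lr_base_disc
        # Assumption: give D_prop a starting value before any optimizer is built,
        # so get_lr_d() can run during optimizer construction.
        self.D_prop = 1.0
        self.disc_optimizer = EagerOptimizer(self.get_lr_d)

    def get_lr_d(self):
        return self.lr_base_disc * self.D_prop


model = TinyVAEGAN()
```

Note that an optimizer which evaluates the callable only once will not pick up later changes to `D_prop`, so the learning rate would then need to be updated explicitly during training.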

Incorrect loss for AutoEncoder

In computing the testing loss for the autoencoder, you reuse the training set:

# test on holdout
loss = []
for batch, test_x in tqdm(
    zip(range(N_TRAIN_BATCHES), train_dataset), total=N_TRAIN_BATCHES  # Why is this using training again?
):
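
A sketch of what the holdout loop would presumably look like if it iterated over held-out data instead. `test_dataset` and `n_test_batches` stand in for whatever the notebook names its test split and batch count, and `compute_loss` is a placeholder for the model's loss function; none of these names are confirmed by the snippet above.

```python
def mean_holdout_loss(compute_loss, test_dataset, n_test_batches):
    """Average the loss over at most n_test_batches batches of held-out data."""
    losses = []
    for _, test_x in zip(range(n_test_batches), test_dataset):
        losses.append(compute_loss(test_x))
    return sum(losses) / len(losses)


# Toy usage: lists stand in for batches and a dummy function for the model loss.
toy_test_dataset = [[1.0, 3.0], [2.0, 2.0], [0.0, 4.0]]


def toy_loss(batch):
    return sum(batch) / len(batch)


holdout_loss = mean_holdout_loss(toy_loss, toy_test_dataset, n_test_batches=2)
```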
