Comments (14)

87hbteo commented on May 27, 2024 13


from pytorch-vae.

MrChenFeng commented on May 27, 2024 12




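Hello, can you give more explanation about the torch.sum(latentloss,1)

I ran into the same problem. If you read the original paper, you will find that the reconstruction loss in the overall loss, relative to the KL divergence, is taken over the whole batch. When the losses are combined, however, criterion = nn.MSELoss() averages over the batch by default. Changing it to criterion = nn.MSELoss(size_average=False) also works.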
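To make the scaling issue concrete, here is a minimal sketch in plain Python (no PyTorch needed) of the two reductions. The helper names `squared_errors`, `mse_mean` and `mse_sum` are illustrative only and are not part of any library. Note that the default `nn.MSELoss()` averages over every element (batch and pixels), and that in current PyTorch the deprecated `size_average=False` is written `nn.MSELoss(reduction='sum')`.

```python
def squared_errors(pred, target):
    """Flat list of per-element squared errors for a batch of rows."""
    return [
        (p - t) ** 2
        for pred_row, target_row in zip(pred, target)
        for p, t in zip(pred_row, target_row)
    ]


def mse_mean(pred, target):
    """Average over every element, like the default 'mean' reduction."""
    errs = squared_errors(pred, target)
    return sum(errs) / len(errs)


def mse_sum(pred, target):
    """Sum over every element, like reduction='sum'."""
    return sum(squared_errors(pred, target))


# Toy batch: 2 samples with 3 "pixels" each.
pred = [[0.0, 0.5, 1.0], [1.0, 1.0, 0.0]]
target = [[0.0, 1.0, 1.0], [0.0, 1.0, 0.0]]
num_elements = 6

# The mean-reduced loss is the summed loss divided by batch_size * num_pixels,
# so it is much smaller relative to a KL term that is summed over latents.
print(mse_sum(pred, target))
print(mse_mean(pred, target) * num_elements)
```

With the mean reduction the reconstruction term shrinks by a factor of batch size times pixel count, which changes its weight against the KL term.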

MPolaris commented on May 27, 2024 1



87hbteo commented on May 27, 2024




shukoushuu commented on May 27, 2024




Hello, can you give more explanation about the torch.sum(latentloss,1)

joray86 commented on May 27, 2024




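Hello, can you give more explanation about the torch.sum(latentloss,1)

I ran into the same problem. If you read the original paper, you will find that the reconstruction loss in the overall loss, relative to the KL divergence, is taken over the whole batch. When the losses are combined, however, criterion = nn.MSELoss() averages over the batch by default. Changing it to criterion = nn.MSELoss(size_average=False) also works.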


lantudou commented on May 27, 2024

This is not a very good demo for VAE. The ReLU non-linearity on the decoder output limits the range of the pixels in your reconstructed image. If your input images have been normalized using the transforms in PyTorch, this can prevent the loss from decreasing.
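A small plain-Python sketch of that range problem, assuming the inputs were normalized to [-1, 1] (for example with mean 0.5 and std 0.5, which is an assumption here). The `relu` helper and the list of pre-activations are illustrative values, not taken from the repository.

```python
import math


def relu(x):
    """ReLU clamps every negative value to zero."""
    return max(0.0, x)


# Values the last decoder layer might produce before the output activation.
pre_activations = [-3.0, -1.0, 0.0, 1.0, 3.0]

relu_out = [relu(a) for a in pre_activations]
tanh_out = [math.tanh(a) for a in pre_activations]

# A dark pixel after normalization to [-1, 1].
target = -1.0

# Best squared error each activation can reach on this target.
best_relu_error = min((y - target) ** 2 for y in relu_out)
best_tanh_error = min((y - target) ** 2 for y in tanh_out)

print(relu_out)
print(best_relu_error, best_tanh_error)
```

Because a ReLU output can never go below zero, the error on negative targets has a floor, whereas tanh can get arbitrarily close to -1.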

For the loss function, MSE is OK for the reconstruction loss, but you should take care with the dimensions. First compute the loss for every sample in the batch, and then take the mean over the batch.
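The reduction order described above can be sketched in plain Python as follows. The function name `vae_style_reduction` is hypothetical; with a 2-D PyTorch tensor of per-element losses the same idea would typically be written as `loss_per_element.sum(dim=1).mean()`.

```python
def vae_style_reduction(per_element_loss):
    """Sum each sample's per-pixel losses, then average those sums over the batch."""
    per_sample = [sum(row) for row in per_element_loss]
    return sum(per_sample) / len(per_sample)


# Per-pixel squared errors for a toy batch of 2 samples with 3 pixels each.
errors = [[0.0, 0.25, 0.0], [1.0, 0.0, 0.0]]
num_pixels = 3

loss = vae_style_reduction(errors)

# For comparison: the plain mean over every element of the batch.
elementwise_mean = sum(sum(row) for row in errors) / (len(errors) * num_pixels)

print(loss, elementwise_mean)
```

The two results differ by exactly the number of pixels per sample, which is why the choice of reduction changes the balance against the KL term.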

Last but not least, using resize on a tensor causes an extra memory copy; try using view() instead.

Here is my simple corrected version. When I am debugging a VAE model, I prefer to print the KL loss, which is a good indicator of whether the model is working. If this value increases with training, it indicates that the model has learned the characteristics of the input and is working to output more diverse images.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt

class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))

class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))

class VAE(torch.nn.Module):

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = 32
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)

    def _sample_latent(self, h_enc):
        """Return the latent normal sample z ~ N(mu, sigma^2)."""
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        # Standard normal noise with the same shape, dtype and device as sigma
        std_z = torch.randn_like(sigma)
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        # Reparameterization trick: z = mu + sigma * eps
        self.z = self.z_mean + self.z_sigma * std_z
        return self.z

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

class VAE2(torch.nn.Module):

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super(VAE2, self).__init__()
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(

        self.latent_dim = latent_dim
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
    def _sample_latent(self, h_enc):
        """Return the latent normal sample z ~ N(mu, sigma^2)."""
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

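def gaussian_likelihood(x_hat, x):
    # log_scale is re-created on every call, so it is never trained:
    # this is a Gaussian with fixed scale exp(0) = 1.
    log_scale = nn.Parameter(torch.Tensor([0.0]))
    scale = torch.exp(log_scale)
    mean = x_hat
    dist = torch.distributions.Normal(mean, scale)

    # measure log-prob of seeing image under p(x|z), summed over pixels
    log_pxz = dist.log_prob(x)
    return log_pxz.sum(dim=1)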

def kl_divergence(z, mu, std):
    # --------------------------
    # Monte carlo KL divergence
    # --------------------------
    # 1. define the first two probabilities (in this case Normal for both)
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)

    # 2. get the probabilities from the equation
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

    # kl
    kl = (log_qzx - log_pz)
    kl = kl.sum(-1)
    return kl

def show_image_grid(images, batch_size=8):
    fig = plt.figure(figsize=(8, batch_size/10))
    #fig.suptitle("Pass {}".format(pass_id))
    gs = plt.GridSpec(int(batch_size/10)+1, 10)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        ax = plt.subplot(gs[i])
       # print(image)
        plt.imshow(image.reshape(28, 28)*0.5+0.5,cmap='gray')

if __name__ == '__main__':

    input_dim = 28 * 28
    batch_size = 128

    transform = transforms.Compose([transforms.Resize([28, 28]),
        transforms.ToTensor(),
        # transforms.Normalize(0.5,0.5),
        ])

    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=2)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(32, 100, input_dim)
    vae = VAE(encoder, decoder)

    #criterion = nn.MSELoss(size_average=False)

    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
            #inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
            optimizer.zero_grad()
            dec = vae(inputs.view(-1,784))
            #ll = latent_loss(vae.z_mean, vae.z_sigma)
            ll = kl_divergence(vae.z, vae.z_mean, vae.z_sigma)
            #dec_ll = criterion(dec, inputs)
            dec_ll = gaussian_likelihood(dec, inputs.view(-1,784))

            # Per-sample negative ELBO: KL term minus reconstruction log-likelihood
            elbo = (ll - dec_ll)

            # Average over the batch, then take one optimizer step
            loss = elbo.mean()
            loss.backward()
            optimizer.step()

        dec = vae(inputs.view(-1,784))
        out = dec.detach().numpy()
        show_image_grid(inputs, batch_size)
        print(epoch, dec_ll.mean().item(),ll.mean().item(),loss.item())


    plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')

The epoch 100 result: [image: original images] [image: reconstruction images]
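The `kl_divergence` helper in the code above estimates the KL term from sampled values of z (a Monte Carlo estimate). For a diagonal Gaussian posterior and a standard normal prior there is also a well-known closed form per latent dimension, and as far as I know `torch.distributions.kl_divergence` uses that closed form for two `Normal` distributions. The following plain-Python sketch compares the two for a single dimension; the function names and the values of `mu` and `sigma` are illustrative assumptions.

```python
import math
import random


def log_normal_pdf(z, mu, sigma):
    """Log density of N(mu, sigma^2) evaluated at z."""
    return (-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
            - 0.5 * ((z - mu) / sigma) ** 2)


def kl_closed_form(mu, sigma):
    """KL( N(mu, sigma^2) || N(0, 1) ) for one latent dimension."""
    return 0.5 * (mu ** 2 + sigma ** 2 - 1.0) - math.log(sigma)


def kl_monte_carlo(mu, sigma, num_samples, rng):
    """Average of log q(z) - log p(z) over samples z ~ q."""
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(mu, sigma)
        total += log_normal_pdf(z, mu, sigma) - log_normal_pdf(z, 0.0, 1.0)
    return total / num_samples


rng = random.Random(0)  # fixed seed so the run is repeatable
mu, sigma = 0.5, 0.8

exact = kl_closed_form(mu, sigma)
estimate = kl_monte_carlo(mu, sigma, 100_000, rng)
print(exact, estimate)
```

A single-sample estimate, as used in the training loop, is unbiased but noisy; the closed form removes that noise when both distributions are Gaussian.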

zl457 commented on May 27, 2024

Hi, what's the resnet18_encoder?

lantudou commented on May 27, 2024

That is from another test of mine. Ignore the VAE2 class; you can just delete it.

ethanluoyc commented on May 27, 2024

Thanks to everyone who's commenting here. I have not updated this repo for a while and also stopped using PyTorch (and switched to JAX).

This code was written when I first learned about VAE, so it's probably different from what I would write if I were to do it now. Just so you know, PyTorch now has a torch.distributions package which allows you to compute log-likelihoods and KL divergences easily instead of hard-coding them as I did, so you should definitely use that. A very good example of VAE in JAX (which should read very well for anyone who is familiar with NumPy) is

My PyTorch is a bit rusty now, but just to answer a few general questions raised here, in the hope it clarifies things:

Using MSE is OK, since if you assume a Gaussian likelihood with fixed stddev, the log-likelihood takes the form of a squared error. The squared error differs from the proper Gaussian log-likelihood by a scale factor and an additive constant, so using MSE corresponds to a beta-VAE with beta != 1. You are right, though, that you should sum across the dimensions of the image but take the mean across the batch. As a rule of thumb, for any VAE-based model, to implement what's in the original VAE paper, you should sum across the dimensions of the inputs and latents and average across the batch. Doing it differently can still work in practice, and you can refer to the beta-VAE paper for more details.
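As a quick check of that relationship, the sketch below (plain Python, with illustrative function names and toy pixel values that are not from the repository) shows that for sigma = 1 the Gaussian log-likelihood equals -0.5 times the summed squared error plus a constant that does not depend on the reconstruction.

```python
import math


def gaussian_log_likelihood(x, x_hat, sigma=1.0):
    """Sum over pixels of log N(x_i | x_hat_i, sigma^2)."""
    return sum(
        -0.5 * math.log(2.0 * math.pi * sigma ** 2)
        - 0.5 * ((xi - mi) / sigma) ** 2
        for xi, mi in zip(x, x_hat)
    )


def sum_squared_error(x, x_hat):
    """Squared error summed over pixels."""
    return sum((xi - mi) ** 2 for xi, mi in zip(x, x_hat))


x = [0.0, 0.5, 1.0, 0.25]        # toy "image" with 4 pixels
recon_a = [0.1, 0.4, 0.9, 0.25]  # a good reconstruction
recon_b = [0.5, 0.5, 0.5, 0.5]   # a poor reconstruction

# Additive constant for sigma = 1: -(D / 2) * log(2 * pi)
constant = -0.5 * len(x) * math.log(2.0 * math.pi)

gap_a = gaussian_log_likelihood(x, recon_a) + 0.5 * sum_squared_error(x, recon_a)
gap_b = gaussian_log_likelihood(x, recon_b) + 0.5 * sum_squared_error(x, recon_b)
print(gap_a, gap_b, constant)
```

Since the gap is the same for every reconstruction, maximizing this log-likelihood and minimizing the summed squared error pick out the same optimum; only the relative weight against the KL term changes.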

In addition, the above assumes a fixed stddev. In practice you can also learn the stddev, although this is quite difficult in general.

A Gaussian decoder is not the only choice. If you work with binary MNIST, then the Bernoulli distribution may be a more appropriate choice, and that corresponds to the cross-entropy loss. If you work with a Gaussian decoder, then it is indeed problematic to add an activation function such as ReLU, since the parameters of the Gaussian distribution may be negative. It's OK to add ReLU if you work with a Bernoulli decoder and assume that the output of the decoder is used to parameterize the probs (instead of logits), but I am not sure how well it would perform in practice. It also seems in practice that the Bernoulli decoder works better for MNIST, but theoretically both a Gaussian likelihood and a Bernoulli likelihood are fine; it is only a difference in the modeling choice. There's nothing wrong in theory, but which one works best depends on the practical application.
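The correspondence between the Bernoulli likelihood and cross-entropy can be checked with a few lines of plain Python. The function names, the binary pixels and the logits below are illustrative assumptions; the decoder here is assumed to output logits that are passed through a sigmoid.

```python
import math


def sigmoid(logit):
    """Map a real-valued logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def bernoulli_log_likelihood(x, probs):
    """Sum over pixels of log Bernoulli(x_i | p_i) for binary x_i."""
    return sum(
        math.log(p) if xi == 1 else math.log(1.0 - p)
        for xi, p in zip(x, probs)
    )


def binary_cross_entropy_sum(x, probs):
    """Binary cross-entropy summed over pixels."""
    return -sum(
        xi * math.log(p) + (1 - xi) * math.log(1.0 - p)
        for xi, p in zip(x, probs)
    )


x = [1, 0, 1, 1]                 # binarized pixels
logits = [2.0, -1.0, 0.5, 3.0]   # hypothetical decoder outputs
probs = [sigmoid(l) for l in logits]

log_lik = bernoulli_log_likelihood(x, probs)
bce = binary_cross_entropy_sum(x, probs)
print(log_lik, bce)
```

The Bernoulli log-likelihood is exactly the negative of the summed binary cross-entropy, so minimizing BCE (summed over pixels, averaged over the batch) is the reconstruction term of the ELBO for a Bernoulli decoder.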

Normalizing the image to be [-0.5, 0.5] may be useful if you work with Gaussian likelihoods as typically the output of the NN would be initialized to be centered around zero. This may have some positive impact on the optimization.
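For reference, the arithmetic behind such a normalization is just (pixel - mean) / std per channel. The small sketch below uses a hypothetical `normalize` helper and toy pixel values to show which (mean, std) pair gives which range for pixels that start in [0, 1].

```python
def normalize(pixel, mean, std):
    """Per-channel normalization arithmetic: (pixel - mean) / std."""
    return (pixel - mean) / std


pixels = [0.0, 0.25, 1.0]  # raw pixel values in [0, 1]

# mean 0.5, std 1 maps [0, 1] to [-0.5, 0.5]
centered = [normalize(p, 0.5, 1.0) for p in pixels]

# mean 0.5, std 0.5 maps [0, 1] to [-1, 1]
scaled = [normalize(p, 0.5, 0.5) for p in pixels]

print(centered)
print(scaled)
```

Either range lies within what a tanh output can represent, whereas a ReLU output cannot reach the negative half.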

ethanluoyc commented on May 27, 2024

Updated code which uses new PyTorch as well as some comments.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from torch import distributions

class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, latent_size):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.enc_mu = torch.nn.Linear(H, latent_size)
        self.enc_log_sigma = torch.nn.Linear(H, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mu = self.enc_mu(x)
        log_sigma = self.enc_log_sigma(x)
        sigma = torch.exp(log_sigma)
        return torch.distributions.Normal(loc=mu, scale=sigma)

class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        mu = torch.tanh(self.linear2(x))
        return torch.distributions.Normal(mu, torch.ones_like(mu))

class VAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, state):
        q_z = self.encoder(state)
        z = q_z.rsample()
        return self.decoder(z), q_z

transform = transforms.Compose([
     transforms.ToTensor(),
     # Normalize the images to be -0.5, 0.5
     transforms.Normalize(0.5, 1)])
mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

input_dim = 28 * 28
batch_size = 128
num_epochs = 100
learning_rate = 0.001
hidden_size = 512
latent_size = 8

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

dataloader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size,
    shuffle=True)

print('Number of samples: ', len(mnist))

encoder = Encoder(input_dim, hidden_size, latent_size)
decoder = Decoder(latent_size, hidden_size, input_dim)

vae = VAE(encoder, decoder).to(device)

optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
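for epoch in range(num_epochs):
    for data in dataloader:
        inputs, _ = data
        inputs = inputs.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        p_x, q_z = vae(inputs)
        # Sum over pixels, average over the batch
        log_likelihood = p_x.log_prob(inputs).sum(-1).mean()
        # KL(q(z|x) || N(0, 1)): sum over latent dimensions, average over the batch
        kl = torch.distributions.kl_divergence(
            q_z,
            torch.distributions.Normal(0, 1.)
        ).sum(-1).mean()
        loss = -(log_likelihood - kl)
        loss.backward()
        optimizer.step()
        l = loss.item()
    print(epoch, l, log_likelihood.item(), kl.item())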

coderaBruce commented on May 27, 2024

I believe I ran the updated version, @ethanluoyc, and it still doesn't seem to work well, whether the task is generating from the latent space or recovering the original images. The loss does not drop at all either.
I wonder if anyone has had a successful attempt? @87hbteo @MrChenFeng @joray86 @MrChenFeng ?

I also doubt whether it is proper to calculate the log-likelihood as p_x.log_prob(inputs).sum(-1).mean(). This seems to compute the log-likelihood of p(x|z), while we should target the log-likelihood of p(x)?
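For context on this question, the standard evidence lower bound (ELBO) decomposition relates the two quantities. This is the textbook identity, written here with q(z|x) for the encoder distribution and p(z) for the prior; it is not specific to this repository's code.

```latex
\log p(x)
  = \mathbb{E}_{q(z \mid x)}\!\left[\log p(x \mid z)\right]
  - \mathrm{KL}\!\left(q(z \mid x)\,\|\,p(z)\right)
  + \mathrm{KL}\!\left(q(z \mid x)\,\|\,p(z \mid x)\right)
  \;\ge\;
  \mathbb{E}_{q(z \mid x)}\!\left[\log p(x \mid z)\right]
  - \mathrm{KL}\!\left(q(z \mid x)\,\|\,p(z)\right)
```

The last KL term is non-negative, so the expected log-likelihood of p(x|z) minus the KL to the prior is a lower bound on log p(x). Training on p(x|z) together with the KL term therefore still targets log p(x), only indirectly.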

ethanluoyc commented on May 27, 2024

I just tested and indeed the new version I posted does not work well. Looking at it though, I am not sure what I did differently from @zl457.

@coderaBruce I believe the calculation of the expected log-likelihood is correct. If you look at @zl457's answer the new version looks the same.

I don't have time to look at this now. My hypothesis right now is that using a Gaussian likelihood in conjunction with sigmoid/tanh causes a vanishing gradient problem that would not occur with a BCE likelihood. Referring to @zl457's answer, it looks like there is an additional BatchNorm layer that is missing from my updated version. I can try to take a look again later this week to see what caused the bad result.

handesome commented on May 27, 2024

Hi! Thank you for the code! When I run your line "plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')", I get the error "RuntimeError: mat1 and mat2 shapes cannot be multiplied (2688x28 and 784x100)". Could you please give some advice?
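A likely cause, assuming the tensor that reaches that line is the final MNIST training batch with shape (96, 1, 28, 28), is that `inputs` is passed to the model without being flattened: a linear layer acts on the last dimension only, so every image row becomes its own 28-feature input. The plain-Python sketch below reproduces the shape arithmetic; both helper functions are hypothetical and only mimic how shapes are interpreted.

```python
def linear_input_rows_cols(shape):
    """A linear layer acts on the last dimension; all leading dims act as rows."""
    rows = 1
    for dim in shape[:-1]:
        rows *= dim
    return rows, shape[-1]


def flattened_shape(shape, features=784):
    """Shape after the equivalent of inputs.view(-1, 784)."""
    total = 1
    for dim in shape:
        total *= dim
    return total // features, features


# The MNIST training set has 60000 images, so the last batch of 128 is smaller.
last_batch_size = 60000 % 128
last_batch = (last_batch_size, 1, 28, 28)

print(linear_input_rows_cols(last_batch))  # what the first Linear(784, 100) receives
print(flattened_shape(last_batch))         # what it expects
```

Under that assumption, flattening first with `inputs.view(-1, 784)` should resolve the shape error. Separately, `.item()` only works on one-element tensors, so selecting a single sample and calling `.detach().numpy()` on it is probably what was intended.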
