Code Monkey home page Code Monkey logo

Comments (14)

87hbteo avatar 87hbteo commented on May 27, 2024 13

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

from pytorch-vae.

MrChenFeng avatar MrChenFeng commented on May 27, 2024 12

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

我觉得应该是latentloss和最终的loss应该先对batchsize进行累加torch.sum(latentloss,1),最后再合并loss,我这么更改之后,效果就很不错了,之前全都是一个样。
至于你说的binary_cross_entropy,我看到其他代码也用了这个,只是很好奇,因为‘标签’并不是总是0和1的,我自己更改之后,基本都是乱码。

我刚刚试了一下确实如你所说的,mse应该也是好用的,不过无论是使用mse还是bce都需要调整sum等让loss有效。不过具体为什么bce也好用我确实也没有想明白。。。

你好,torch.sum(latentloss,1)没明白说得什么意思,可以解释一下吗?
Hello, can you give more explainations about the torch.sum(latentloss,1)

同样遇到这个问题,读一下原论文你会发现整体损失里面的recon损失相对于KLdivergence而言就是考虑整个batch,而在整合loss的时候,criterion = nn.MSELoss()这里默认的是做batch average,改成criterion = nn.MSELoss(size_average=False)也能work。

from pytorch-vae.

MPolaris avatar MPolaris commented on May 27, 2024 1

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

我觉得应该是latentloss和最终的loss应该先对batchsize进行累加torch.sum(latentloss,1),最后再合并loss,我这么更改之后,效果就很不错了,之前全都是一个样。
至于你说的binary_cross_entropy,我看到其他代码也用了这个,只是很好奇,因为‘标签’并不是总是0和1的,我自己更改之后,基本都是乱码。

from pytorch-vae.

87hbteo avatar 87hbteo commented on May 27, 2024

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

我觉得应该是latentloss和最终的loss应该先对batchsize进行累加torch.sum(latentloss,1),最后再合并loss,我这么更改之后,效果就很不错了,之前全都是一个样。
至于你说的binary_cross_entropy,我看到其他代码也用了这个,只是很好奇,因为‘标签’并不是总是0和1的,我自己更改之后,基本都是乱码。

我刚刚试了一下确实如你所说的,mse应该也是好用的,不过无论是使用mse还是bce都需要调整sum等让loss有效。不过具体为什么bce也好用我确实也没有想明白。。。

from pytorch-vae.

shukoushuu avatar shukoushuu commented on May 27, 2024

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

我觉得应该是latentloss和最终的loss应该先对batchsize进行累加torch.sum(latentloss,1),最后再合并loss,我这么更改之后,效果就很不错了,之前全都是一个样。
至于你说的binary_cross_entropy,我看到其他代码也用了这个,只是很好奇,因为‘标签’并不是总是0和1的,我自己更改之后,基本都是乱码。

我刚刚试了一下确实如你所说的,mse应该也是好用的,不过无论是使用mse还是bce都需要调整sum等让loss有效。不过具体为什么bce也好用我确实也没有想明白。。。

你好,torch.sum(latentloss,1)没明白说得什么意思,可以解释一下吗?
Hello, can you give more explainations about the torch.sum(latentloss,1)

from pytorch-vae.

joray86 avatar joray86 commented on May 27, 2024

不好意思刚刚发现好像大家都是**人...因为不知道怎么放图片所以可能没有办法很好地展示我这边跑代码的结果...如果大致用语言描述的话,所有的图片都非常模糊(虽然VAE本身就会比较模糊,但是用这个代码跑出来比一般的VAE还会模糊一些)而且所有的图片可以说都一样...是一个非常模糊的9的样子。我后来参照了别人的代码发现主要是损失函数使用MSE和decoder的最后一层使用了relu的缘故,mse的话因为算的是欧氏距离,在vae中其实很不合适,因为假设由两个一模一样的数字1,他们的位置可能就差了几个像素,我们会认为两个都是很好或者说很相近的图片,但是如果使用mse的话就会认为他们非常不相似,导致在多次运行之后,他们所生成的结果无论是形状还是位置都是趋于相同,这不利于VAE的生成。而sigmoid我是参照了别人的代码之后才改的,配合torch.nn.functional.binary_cross_entropy这个函数一起使用,最后达到了很好地效果,这个评论可能有些失礼,毕竟是在评价别人的代码,而且也有可能是我这边配置的问题,但是因为这个问题我确实是调了好几周,所以还是希望能写出来供大家参考,谢谢。

我觉得应该是latentloss和最终的loss应该先对batchsize进行累加torch.sum(latentloss,1),最后再合并loss,我这么更改之后,效果就很不错了,之前全都是一个样。
至于你说的binary_cross_entropy,我看到其他代码也用了这个,只是很好奇,因为‘标签’并不是总是0和1的,我自己更改之后,基本都是乱码。

我刚刚试了一下确实如你所说的,mse应该也是好用的,不过无论是使用mse还是bce都需要调整sum等让loss有效。不过具体为什么bce也好用我确实也没有想明白。。。

你好,torch.sum(latentloss,1)没明白说得什么意思,可以解释一下吗?
Hello, can you give more explainations about the torch.sum(latentloss,1)

同样遇到这个问题,读一下原论文你会发现整体损失里面的recon损失相对于KLdivergence而言就是考虑整个batch,而在整合loss的时候,criterion = nn.MSELoss()这里默认的是做batch average,改成criterion = nn.MSELoss(size_average=False)也能work。

非常感谢,确实能用。

from pytorch-vae.

lantudou avatar lantudou commented on May 27, 2024

This is not a very good demo for VAE. The relu non-linear function for decoding output will limit the range of the pixels in your reconstruction image. If your input image has been normalized using the transform in pytorch, it may cause your loss cannot be decreased.

For the loss funtion, mse is ok for reconstruction loss but you should take care of the dimension problem. Firstly, you should finish the loss caculation for every samples in the batch. and finally get the mean value of the total batch loss.

Last but not least, using resize for tensor will cause the extra memory copy, try to use view() to instead it.

Here is my simple correction version, when I am debugging the VAE model, I prefer to printing the KL loss, which is a good representation to indicate whether your model is work. If this value increases with training, it proves that your model has learned the characteristics of input and is committed to output more diversified images.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt



class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))



class VAE(torch.nn.Module):

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = 32
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

class VAE2(torch.nn.Module):

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super(VAE2, self).__init__()
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        self.latent_dim = latent_dim
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)



def gaussian_likelihood(x_hat,x):

    log_scale = nn.Parameter(torch.Tensor([0.0]))
    scale = torch.exp(log_scale)
    mean = x_hat
    dist = torch.distributions.Normal(mean, scale)

    # measure prob of seeing image under p(x|z)
    log_pxz = dist.log_prob(x)
    return log_pxz.sum(dim=(1))

def kl_divergence(z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)

            # 2. get the probabilities from the equation
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

            # kl
    kl = (log_qzx - log_pz)
    kl = kl.sum(-1)
    return kl

def show_image_grid(images, batch_size=8):
    fig = plt.figure(figsize=(8, batch_size/10))
    #fig.suptitle("Pass {}".format(pass_id))
    gs = plt.GridSpec(int(batch_size/10)+1, 10)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
       # print(image)
        plt.imshow(image.reshape(28, 28)*0.5+0.5,cmap='gray')
    
    plt.show()


if __name__ == '__main__':

    input_dim = 28 * 28
    batch_size = 128

    transform = transforms.Compose([transforms.Resize([28, 28]),
        # transforms.Normalize(0.5,0.5),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
    


    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=2)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(32, 100, input_dim)
    vae = VAE(encoder, decoder)

    #criterion = nn.MSELoss(size_average=False)

    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        vae.train()
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
  
            #inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
            optimizer.zero_grad()
            dec = vae(inputs.view(-1,784))
            #ll = latent_loss(vae.z_mean, vae.z_sigma)
            ll = kl_divergence(vae.z, vae.z_mean, vae.z_sigma)
            #dec_ll = criterion(dec, inputs)
            dec_ll =gaussian_likelihood(dec,inputs.view(-1,784))

            elbo = (ll - dec_ll)

            loss = elbo.mean()

            loss.backward()
            optimizer.step()
        vae.eval()
        dec = vae(inputs.view(-1,784))
        out = dec.detach().numpy()
        show_image_grid(inputs, batch_size)
        print(epoch, dec_ll.mean().item(),ll.mean().item(),loss.item())

        show_image_grid(out,batch_size)

        
    plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')
    plt.show(block=True)

The epoch 100 result:
Orignal images
image
Reconstruction images
image

from pytorch-vae.

zl457 avatar zl457 commented on May 27, 2024

This is not a very good demo for VAE. The relu non-linear function for decoding output will limit the range of the pixels in your reconstruction image. If your input image has been normalized using the transform in pytorch, it may cause your loss cannot be decreased.

For the loss funtion, mse is ok for reconstruction loss but you should take care of the dimension problem. Firstly, you should finish the loss caculation for every samples in the batch. and finally get the mean value of the total batch loss.

Last but not least, using resize for tensor will cause the extra memory copy, try to use view() to instead it.

Here is my simple correction version, when I am debugging the VAE model, I prefer to printing the KL loss, which is a good representation to indicate whether your model is work. If this value increases with training, it proves that your model has learned the characteristics of input and is committed to output more diversified images.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt



class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))



class VAE(torch.nn.Module):

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = 32
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

class VAE2(torch.nn.Module):

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super(VAE2, self).__init__()
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        self.latent_dim = latent_dim
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)



def gaussian_likelihood(x_hat,x):

    log_scale = nn.Parameter(torch.Tensor([0.0]))
    scale = torch.exp(log_scale)
    mean = x_hat
    dist = torch.distributions.Normal(mean, scale)

    # measure prob of seeing image under p(x|z)
    log_pxz = dist.log_prob(x)
    return log_pxz.sum(dim=(1))

def kl_divergence(z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)

            # 2. get the probabilities from the equation
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

            # kl
    kl = (log_qzx - log_pz)
    kl = kl.sum(-1)
    return kl

def show_image_grid(images, batch_size=8):
    fig = plt.figure(figsize=(8, batch_size/10))
    #fig.suptitle("Pass {}".format(pass_id))
    gs = plt.GridSpec(int(batch_size/10)+1, 10)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
       # print(image)
        plt.imshow(image.reshape(28, 28)*0.5+0.5,cmap='gray')
    
    plt.show()


if __name__ == '__main__':

    input_dim = 28 * 28
    batch_size = 128

    transform = transforms.Compose([transforms.Resize([28, 28]),
        # transforms.Normalize(0.5,0.5),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
    


    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=2)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(32, 100, input_dim)
    vae = VAE(encoder, decoder)

    #criterion = nn.MSELoss(size_average=False)

    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        vae.train()
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
  
            #inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
            optimizer.zero_grad()
            dec = vae(inputs.view(-1,784))
            #ll = latent_loss(vae.z_mean, vae.z_sigma)
            ll = kl_divergence(vae.z, vae.z_mean, vae.z_sigma)
            #dec_ll = criterion(dec, inputs)
            dec_ll =gaussian_likelihood(dec,inputs.view(-1,784))

            elbo = (ll - dec_ll)

            loss = elbo.mean()

            loss.backward()
            optimizer.step()
        vae.eval()
        dec = vae(inputs.view(-1,784))
        out = dec.detach().numpy()
        show_image_grid(inputs, batch_size)
        print(epoch, dec_ll.mean().item(),ll.mean().item(),loss.item())

        show_image_grid(out,batch_size)

        
    plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')
    plt.show(block=True)

The epoch 100 result: Orignal images image Reconstruction images image

hi, what's the resnet_encoder?

from pytorch-vae.

lantudou avatar lantudou commented on May 27, 2024

This is for my another test. Do not care about the VAE2 class and just delete it.

from pytorch-vae.

ethanluoyc avatar ethanluoyc commented on May 27, 2024

Thanks to everyone who's commenting here. I have not updated this repo for a while and also stopped using PyTorch (and switched to JAX).

This code was written when I first learned about VAE, so it's probably different from what I would write if I were to do it now. Just so you know, PyTorch now has a torch.distributions package which would allow you to compute the log-likelihood and KL divergences easily instead of hard-coding it as I did, so you should definitely use that. A very good example of VAE in JAX (but should read very well for anyone who is familiar with NumPy is
https://github.com/deepmind/dm-haiku/blob/main/examples/vae.py)

My PyTorch is a bit rusty now but just to answer a few general questions raised here, in the hope it clarifies things

Using MSE is OK since if you assume a Gaussian likelihood with fixed stddev, then the log-likelihood takes the form of square error. They will be different from the proper Gaussian likelihood by a constant so using MSE corresponds to beta-VAE with beta != 1, but you are right that you should sum across the dimension of the image but take the mean across the batch. As a rule of thumb, for any VAE-based models, to implement what's in the original VAE paper, you should sum across the dimensions of the inputs and latents and average across the batch. Doing it differently can still work in practice and you can refer to the beta-VAE paper for more details.

In addition, the above makes the assumption of fixed stddev, in practice, you can also learn the stddev albeit being quite difficult in general.

A Gaussian decoder is not the only choice. If you work with binary MNIST then using the Bernoulli distribution may be a more proper choice, and that corresponds to the cross-entropy loss. If you work with a Gaussian decoder, then indeed problematic to use another activation function since the parameters of the Gaussian distribution may be negative. It's ok to add ReLU if you work with a Bernoulli decoder and assume that the output of the decoder is used to parameterize the probs (instead of logits), but I am not sure how good it would perform in practice. It also seems in practice that using the Bernoulli decoder works better for MNIST, but theoretically, both a Gaussian likelihood and Bernoulli are fine, it is only a difference in the modeling choice. There's nothing wrong in theory, but which one works best depends on the practical application.

Normalizing the image to be [-0.5, 0.5] may be useful if you work with Gaussian likelihoods as typically the output of the NN would be initialized to be centered around zero. This may have some positive impact on the optimization.

from pytorch-vae.

ethanluoyc avatar ethanluoyc commented on May 27, 2024

Updated code which uses new PyTorch as well as some comments.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt
from torch import distributions

class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, latent_size):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, H)
        self.enc_mu = torch.nn.Linear(H, latent_size)
        self.enc_log_sigma = torch.nn.Linear(H, latent_size)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        mu = self.enc_mu(x)
        log_sigma = self.enc_log_sigma(x)
        sigma = torch.exp(log_sigma)
        return torch.distributions.Normal(loc=mu, scale=sigma)


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)
        

    def forward(self, x):
        x = F.relu(self.linear1(x))
        mu = torch.tanh(self.linear2(x))
        return torch.distributions.Normal(mu, torch.ones_like(mu))

class VAE(torch.nn.Module):
    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, state):
        q_z = self.encoder(state)
        z = q_z.rsample()
        return self.decoder(z), q_z


transform = transforms.Compose(
    [transforms.ToTensor(),
     # Normalize the images to be -0.5, 0.5
     transforms.Normalize(0.5, 1)]
    )
mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

input_dim = 28 * 28
batch_size = 128
num_epochs = 100
learning_rate = 0.001
hidden_size = 512
latent_size = 8

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

dataloader = torch.utils.data.DataLoader(
    mnist, batch_size=batch_size,
    shuffle=True, 
    pin_memory=True)

print('Number of samples: ', len(mnist))

encoder = Encoder(input_dim, hidden_size, latent_size)
decoder = Decoder(latent_size, hidden_size, input_dim)

vae = VAE(encoder, decoder).to(device)

optimizer = optim.Adam(vae.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    for data in dataloader:
        inputs, _ = data
        inputs = inputs.view(-1, input_dim).to(device)
        optimizer.zero_grad()
        p_x, q_z = vae(inputs)
        log_likelihood = p_x.log_prob(inputs).sum(-1).mean()
        kl = torch.distributions.kl_divergence(
            q_z, 
            torch.distributions.Normal(0, 1.)
        ).sum(-1).mean()
        loss = -(log_likelihood - kl)
        loss.backward()
        optimizer.step()
        l = loss.item()
    print(epoch, l, log_likelihood.item(), kl.item())

from pytorch-vae.

coderaBruce avatar coderaBruce commented on May 27, 2024

I believe I did run the updated version @ethanluoyc, it seems still doesn't work well whether the task is generating from latent space or trying to recover the original images. Plus loss does not drop at all either.
Wander if anyone had a successful attempt? @87hbteo @MrChenFeng @joray86 @MrChenFeng ?

I just doubt if it is proper to calculate log-likelihood as p_x.log_prob(inputs).sum(-1).mean()? Since this seems to be calculating the log-likelihood of p(x|z) while we should target the log-likelihood of p(x)?

from pytorch-vae.

ethanluoyc avatar ethanluoyc commented on May 27, 2024

I just tested and indeed the new version I posted does not work well. Looking at it though, I am not sure what I did differently from @zl457.

@coderaBruce I believe the calculation of the expected log-likelihood is correct. If you look at @zl457's answer the new version looks the same.

I don't have time to look at this now. My hypothesis right now is that using a Gaussian likelihood in conjunction with sigmoid/tanh causes some vanishing gradient problem that otherwise would not occur if you use a BCE likelihood. Referring to @zl457 's answer, it looks like there is an additional BatchNorm layer that was missing from my updated version. I can try to take a look again later this week to see what caused the bad result.

from pytorch-vae.

handesome avatar handesome commented on May 27, 2024

This is not a very good demo for VAE. The relu non-linear function for decoding output will limit the range of the pixels in your reconstruction image. If your input image has been normalized using the transform in pytorch, it may cause your loss cannot be decreased.

For the loss funtion, mse is ok for reconstruction loss but you should take care of the dimension problem. Firstly, you should finish the loss caculation for every samples in the batch. and finally get the mean value of the total batch loss.

Last but not least, using resize for tensor will cause the extra memory copy, try to use view() to instead it.

Here is my simple correction version, when I am debugging the VAE model, I prefer to printing the KL loss, which is a good representation to indicate whether your model is work. If this value increases with training, it proves that your model has learned the characteristics of input and is committed to output more diversified images.

import torch
from torch.autograd import Variable
import numpy as np
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import torch.optim as optim
from torch import nn
import matplotlib.pyplot as plt



class Encoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Encoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))


class Decoder(torch.nn.Module):
    def __init__(self, D_in, H, D_out):
        super(Decoder, self).__init__()
        self.linear1 = torch.nn.Linear(D_in, H)
        self.linear2 = torch.nn.Linear(H, D_out)

    def forward(self, x):
        x = F.relu(self.linear1(x))
        return F.tanh(self.linear2(x))



class VAE(torch.nn.Module):

    def __init__(self, encoder, decoder):
        super(VAE, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.latent_dim = 32
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)

class VAE2(torch.nn.Module):

    def __init__(self, enc_out_dim=512, latent_dim=256, input_height=32):
        super(VAE2, self).__init__()
        self.encoder = resnet18_encoder(False, False)
        self.decoder = resnet18_decoder(
            latent_dim=latent_dim,
            input_height=input_height,
            first_conv=False,
            maxpool1=False
        )

        self.latent_dim = latent_dim
        self._enc_mu = torch.nn.Linear(100, self.latent_dim)
        self._enc_log_sigma = torch.nn.Linear(100, self.latent_dim)
        self.mu_bn = torch.nn.BatchNorm1d(self.latent_dim)
        self.mu_bn.weight.requires_grad = False
        nn.init.constant_(self.mu_bn.bias, 0.0)
        self.mu_bn.weight.fill_(0.5)
                                
    def _sample_latent(self, h_enc):
        """
        Return the latent normal sample z ~ N(mu, sigma^2)
        """
        mu = self._enc_mu(h_enc)
        log_sigma = self._enc_log_sigma(h_enc)
        sigma = torch.exp(log_sigma)
        std_z = torch.from_numpy(np.random.normal(0, 1, size=sigma.size())).float()
        
        self.z_mean = self.mu_bn(mu)
        self.z_sigma = sigma

        self.z = self.z_mean + self.z_sigma * Variable(std_z, requires_grad=False)
        return self.z  # Reparameterization trick

    def forward(self, state):

        h_enc = self.encoder(state)
        z = self._sample_latent(h_enc)
        return self.decoder(z)



def gaussian_likelihood(x_hat,x):

    log_scale = nn.Parameter(torch.Tensor([0.0]))
    scale = torch.exp(log_scale)
    mean = x_hat
    dist = torch.distributions.Normal(mean, scale)

    # measure prob of seeing image under p(x|z)
    log_pxz = dist.log_prob(x)
    return log_pxz.sum(dim=(1))

def kl_divergence(z, mu, std):
        # --------------------------
        # Monte carlo KL divergence
        # --------------------------
        # 1. define the first two probabilities (in this case Normal for both)
    p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
    q = torch.distributions.Normal(mu, std)

            # 2. get the probabilities from the equation
    log_qzx = q.log_prob(z)
    log_pz = p.log_prob(z)

            # kl
    kl = (log_qzx - log_pz)
    kl = kl.sum(-1)
    return kl

def show_image_grid(images, batch_size=8):
    fig = plt.figure(figsize=(8, batch_size/10))
    #fig.suptitle("Pass {}".format(pass_id))
    gs = plt.GridSpec(int(batch_size/10)+1, 10)
    gs.update(wspace=0.05, hspace=0.05)

    for i, image in enumerate(images):
        
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
       # print(image)
        plt.imshow(image.reshape(28, 28)*0.5+0.5,cmap='gray')
    
    plt.show()


if __name__ == '__main__':

    input_dim = 28 * 28
    batch_size = 128

    transform = transforms.Compose([transforms.Resize([28, 28]),
        # transforms.Normalize(0.5,0.5),
        transforms.ToTensor(),
        transforms.Normalize(0.5,0.5)
        ])
    


    mnist = torchvision.datasets.MNIST('./', download=True, transform=transform)

    dataloader = torch.utils.data.DataLoader(mnist, batch_size=batch_size, shuffle=True, num_workers=2)

    print('Number of samples: ', len(mnist))

    encoder = Encoder(input_dim, 100, 100)
    decoder = Decoder(32, 100, input_dim)
    vae = VAE(encoder, decoder)

    #criterion = nn.MSELoss(size_average=False)

    optimizer = optim.Adam(vae.parameters(), lr=0.0001)
    l = None
    for epoch in range(100):
        vae.train()
        for i, data in enumerate(dataloader, 0):
            inputs, classes = data
  
            #inputs, classes = Variable(inputs.resize_(batch_size, input_dim)), Variable(classes)
            optimizer.zero_grad()
            dec = vae(inputs.view(-1,784))
            #ll = latent_loss(vae.z_mean, vae.z_sigma)
            ll = kl_divergence(vae.z, vae.z_mean, vae.z_sigma)
            #dec_ll = criterion(dec, inputs)
            dec_ll =gaussian_likelihood(dec,inputs.view(-1,784))

            elbo = (ll - dec_ll)

            loss = elbo.mean()

            loss.backward()
            optimizer.step()
        vae.eval()
        dec = vae(inputs.view(-1,784))
        out = dec.detach().numpy()
        show_image_grid(inputs, batch_size)
        print(epoch, dec_ll.mean().item(),ll.mean().item(),loss.item())

        show_image_grid(out,batch_size)

        
    plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')
    plt.show(block=True)

The epoch 100 result: Orignal images image Reconstruction images image

hi!Thank you for the code!!when I run your code"plt.imshow(vae(inputs).item().numpy().reshape(28, 28), cmap='gray')",the error says"RuntimeError: mat1 and mat2 shapes cannot be multiplied (2688x28 and 784x100)",could you please give some advice?

from pytorch-vae.

Related Issues (3)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.