
pytorch-MNIST-CelebA-GAN-DCGAN

PyTorch implementation of Generative Adversarial Networks (GAN) [1] and Deep Convolutional Generative Adversarial Networks (DCGAN) [2] for the MNIST [3] and CelebA [4] datasets.

  • If you want to train on the cropped CelebA dataset, change isCrop = False to isCrop = True.

  • You can download the MNIST and CelebA datasets from their official sites.

  • pytorch_CelebA_DCGAN.py requires 64 x 64 images, so you have to resize the CelebA dataset first (celebA_data_preprocess.py); see the resize sketch after this list.

  • pytorch_CelebA_DCGAN.py also adds learning-rate decay.
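
A minimal sketch of that resize step (assuming Pillow and illustrative input/output directories; this is not the repository's exact celebA_data_preprocess.py):

    # Minimal sketch: resize every CelebA image to the 64 x 64 size expected by
    # pytorch_CelebA_DCGAN.py. Directories and use of Pillow are assumptions,
    # not the repository's exact preprocessing script.
    import os
    from PIL import Image

    src_dir = './data/celebA/'          # hypothetical input directory
    dst_dir = './data/resized_celebA/'  # hypothetical output directory
    os.makedirs(dst_dir, exist_ok=True)

    for name in os.listdir(src_dir):
        img = Image.open(os.path.join(src_dir, name))
        img = img.resize((64, 64), Image.BILINEAR)
        img.save(os.path.join(dst_dir, name))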

Implementation details

  • GAN (network architecture figure)

  • DCGAN (network architecture and loss figure)
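
For reference, a minimal sketch of the adversarial objective both scripts optimize with binary cross-entropy (illustrative names and shapes, not the repository's exact training loop; D is assumed to output probabilities of shape (N, 1)):

    # Minimal sketch of the standard GAN losses with BCE (illustrative, not the
    # repository's exact code). D maps images to probabilities; G maps noise to images.
    import torch
    import torch.nn as nn

    bce = nn.BCELoss()

    def discriminator_loss(D, G, x_real, z):
        # D should output 1 for real samples and 0 for generated ones.
        real_labels = torch.ones(x_real.size(0), 1)
        fake_labels = torch.zeros(z.size(0), 1)
        return bce(D(x_real), real_labels) + bce(D(G(z).detach()), fake_labels)

    def generator_loss(D, G, z):
        # G tries to make D output 1 for generated samples (non-saturating GAN loss).
        real_labels = torch.ones(z.size(0), 1)
        return bce(D(G(z)), real_labels)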

Results

MNIST

  • Generation using fixed noise (fixed_z_): GAN and DCGAN samples (figures)
  • MNIST vs. generated images: real MNIST, GAN after 100 epochs, DCGAN after 20 epochs (figures)
  • Training loss

    • GAN loss curve (figure)
  • Learning Time

    • MNIST DCGAN - Avg. per epoch: 197.86 sec. (To reduce training time, change 'generator(128)' and 'discriminator(128)' to 'generator(64)' and 'discriminator(64)'; in my environment this brings the average down to about 67 sec per epoch. See the sketch below.)
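
A minimal sketch of what that constructor argument typically controls (illustrative, not the repository's exact class): d sets the base channel width of every layer, so passing 64 instead of 128 roughly quarters the per-layer compute.

    # Minimal sketch of a DCGAN-style generator where the constructor argument d sets
    # the base channel width (illustrative; not the repository's exact class).
    import torch.nn as nn

    class Generator(nn.Module):
        def __init__(self, d=128):  # generator(64) halves every layer's width
            super().__init__()
            self.net = nn.Sequential(
                nn.ConvTranspose2d(100, d * 8, 4, 1, 0), nn.BatchNorm2d(d * 8), nn.ReLU(True),
                nn.ConvTranspose2d(d * 8, d * 4, 4, 2, 1), nn.BatchNorm2d(d * 4), nn.ReLU(True),
                nn.ConvTranspose2d(d * 4, d * 2, 4, 2, 1), nn.BatchNorm2d(d * 2), nn.ReLU(True),
                nn.ConvTranspose2d(d * 2, d, 4, 2, 1), nn.BatchNorm2d(d), nn.ReLU(True),
                nn.ConvTranspose2d(d, 1, 4, 2, 1), nn.Tanh(),  # 1-channel output, assumed 64 x 64
            )

        def forward(self, z):  # z: noise tensor of shape (N, 100, 1, 1)
            return self.net(z)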

CelebA

  • Generation using fixed noise (fixed_z_): DCGAN and DCGAN-crop samples (figures)
  • CelebA vs. generated images: real CelebA, DCGAN after 20 epochs, DCGAN crop after 30 epochs (figures)
  • Learning Time
    • CelebA DCGAN - Avg. per epoch: 732.54 sec; total training time for 20 epochs: 14744.66 sec

Development Environment

  • Ubuntu 14.04 LTS
  • NVIDIA GTX 1080 Ti
  • CUDA 8.0
  • Python 2.7.6
  • PyTorch 0.1.12
  • torchvision 0.1.8
  • matplotlib 1.3.1
  • imageio 2.2.0
  • scipy 0.19.1

Reference

[1] Goodfellow, Ian, et al. "Generative adversarial nets." Advances in Neural Information Processing Systems. 2014.

(Full paper: http://papers.nips.cc/paper/5423-generative-adversarial-nets.pdf)

[2] Radford, Alec, Luke Metz, and Soumith Chintala. "Unsupervised representation learning with deep convolutional generative adversarial networks." arXiv preprint arXiv:1511.06434 (2015).

(Full paper: https://arxiv.org/pdf/1511.06434.pdf)

[3] Y. LeCun, L. Bottou, Y. Bengio, and P. Haffner. "Gradient-based learning applied to document recognition." Proceedings of the IEEE, 86(11):2278-2324, November 1998.

[4] Liu, Ziwei, et al. "Deep learning face attributes in the wild." Proceedings of the IEEE International Conference on Computer Vision. 2015.


pytorch-mnist-celeba-gan-dcgan's Issues

crop

Dear author,
I'm interested in your work and want to know how to obtain the cropped images.

MNIST DCGAN training is too slow

As you said, the average time per epoch should be around 180 s, but on my server it shows:
[1/20] - ptime: 372.38, loss_d: 0.597, loss_g: 5.759

My environment is:
Ubuntu 16.04, CUDA 8.0, cuDNN 6, PyTorch 0.2, Titan Xp

I also set the number of workers for the training data loader to 2, so it shouldn't be an I/O bottleneck.
Do you have any idea what's going wrong?
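
A minimal sketch of the data-loading and cuDNN settings commonly checked in this situation (illustrative values, written against the newer torchvision API; not the repository's configuration):

    # Minimal sketch of settings commonly checked when GPU training is slower than
    # expected (illustrative values; not the repository's configuration).
    import torch
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    torch.backends.cudnn.benchmark = True  # let cuDNN pick the fastest conv algorithms for fixed-size inputs

    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])
    train_dataset = datasets.MNIST('data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True,
                              num_workers=2,    # parallel data loading, as mentioned above
                              pin_memory=True)  # faster host-to-GPU copies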

Training output does not show

Hello, your code is very good, but I have a problem I can't solve. When I replace the transform with

    transform = transforms.Compose([
        transforms.Scale(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])

and use

    D_train_loss.item()

I then get these two warnings:

UserWarning: volatile was removed and now has no effect. Use with torch.no_grad(): instead.
fixed_z_ = Variable(fixed_z_.cuda(),volatile=True)
UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
"please use transforms.Resize instead.")

After that the program keeps running, but it doesn't produce any output such as the losses. I would be most grateful if you could answer me.
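
For reference, a minimal sketch of the newer idioms those two warnings point to, torch.no_grad() in place of volatile=True and transforms.Resize in place of transforms.Scale (illustrative; the generator here is a placeholder, not the repository's model):

    # Minimal sketch of the replacements the two warnings suggest (illustrative,
    # not the repository's exact code).
    import torch
    import torch.nn as nn
    from torchvision import transforms

    img_size = 64  # assumed image size

    # transforms.Scale is deprecated; transforms.Resize is the drop-in replacement.
    transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ])

    # volatile=True was removed; wrap no-gradient evaluation in torch.no_grad() instead.
    G = nn.Linear(100, img_size * img_size)  # placeholder generator, just to make the example runnable
    fixed_z_ = torch.randn(25, 100)
    with torch.no_grad():                    # replaces Variable(fixed_z_, volatile=True)
        test_images = G(fixed_z_)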

isCrop

I used Python 3.5 and PyTorch 0.4 and can successfully train on CelebA, but when I set the parameter isCrop = True I get another error: "RuntimeError: sizes must be non-negative".
Do you know how I can fix it? I hope to get your help.

RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]

for x_, _ in train_loader:
    # train discriminator D
    D.zero_grad()

    x_ = x_.view(-1, 28 * 28)

    mini_batch = x_.size()[0]

When I tried to run these lines, it raised "output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]", and I don't know why.

Could you help me to fix this bug?

And by the way, in the loop "for x_, _ in train_loader", what is the structure of x_ and _? I don't understand what x_ and _ mean.

I am new to GAN. Thanks!
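
For context, this error typically appears when a 3-channel Normalize is applied to single-channel MNIST tensors; a minimal sketch of a single-channel transform that avoids it, and of what the loop unpacks (illustrative; the script's own transform may differ):

    # Minimal sketch: MNIST images have 1 channel, so Normalize must use 1-element
    # mean/std. Normalizing with 3-channel statistics raises the
    # [1, 28, 28] vs [3, 28, 28] broadcast error. (Illustrative; the script's own
    # transform may differ.)
    from torchvision import datasets, transforms

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),  # one value per channel for grayscale MNIST
    ])
    train_set = datasets.MNIST('data', train=True, download=True, transform=transform)
    x_, _ = train_set[0]   # x_ is the image tensor of shape (1, 28, 28); _ is the class label
    print(x_.shape)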

The imshow of celebA project

In the CelebA project,

ax[i, j].imshow((test_images[k].cpu().data.numpy().transpose(1, 2, 0) + 1) / 2)

is there a reason you rescale the image as (value + 1) / 2?
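
A short note on the likely reason (assuming the generator ends in tanh, as is standard for DCGAN): tanh outputs lie in [-1, 1], while imshow expects float images in [0, 1], so (value + 1) / 2 maps one range onto the other.

    # Mapping tanh outputs in [-1, 1] to the [0, 1] range expected by imshow for floats.
    # (Assumes the generator's last activation is tanh, as in the standard DCGAN setup.)
    import numpy as np

    x = np.array([-1.0, 0.0, 1.0])  # sample tanh outputs
    print((x + 1) / 2)              # -> [0.  0.5 1. ]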

Why does the loss of D quickly go down to zero?

Hello,
I ran your DCGAN implementation on MNIST, but the quality of the generated images was poor. I have already tried reducing the learning rate, but it didn't help and the results were a far cry from yours. I am new to GANs and quite confused. Did you change your parameter settings or make other adjustments? I hope to get some advice from you, thank you very much!
