
lr-gan.pytorch's Introduction

Pytorch code for Layered Recursive Generative Adversarial Networks

Introduction

This is the pytorch implementation of our ICLR 2017 paper "LR-GAN: Layered Recursive Generative Adversarial Networks for Image Generation".

In our paper, we proposed LR-GAN to generate images layer by layer and recursively, motivated by the observation that natural images have structure and context. As shown below, LR-GAN first generates a background image, and then generates foregrounds with appearance, pose and shape. Afterward, LR-GAN places the foregrounds at appropriate locations on the background.

In this way, LR-GAN can significantly reduce the blending between background and foregrounds. Both the qualitative and quantitative comparisons indicate that LR-GAN generates sharper and better images than the baseline DCGAN model.

Disclaimer

This is the reimplementation of LR-GAN based on PyTorch. It is developed based on the PyTorch DCGAN example. Our original code was implemented in Torch during the first author's internship. All the results presented in our paper were obtained with the Torch code, which cannot be released due to company restrictions. This project is an attempt to reproduce the results in our paper.

Citation

If you find this code useful, please cite the following paper:

@article{yang2017lr,
    title={LR-GAN: Layered recursive generative adversarial networks for image generation},
    author={Yang, Jianwei and Kannan, Anitha and Batra, Dhruv and Parikh, Devi},
    journal={ICLR},
    year={2017}
}

Dependencies

  1. PyTorch. Install PyTorch with the proper commands for your system. Make sure you also install torchvision.

  2. Spatial transformer network with mask (STNM). Install STNM from this project. Since GPU devices and CUDA driver versions differ across machines, please build your own stnm.so on your machine; a quick import check is sketched below.
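
Below is a minimal sanity check that the dependencies import correctly. The module path for the compiled STNM extension (_ext.stnm) is an assumption based on this repository's layout, not a guaranteed API:

import torch
import torchvision

print(torch.__version__, torchvision.__version__, torch.cuda.is_available())

try:
    from _ext.stnm import _stnm  # assumed path of the compiled stnm.so in this repo
    print("STNM extension loaded")
except ImportError as e:
    print("STNM extension is not built for this machine:", e)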

Train LR-GAN

Preparation

Pull this project to your own machine, and then make sure PyTorch is installed successfully. Create a folder datasets to hold the training sets, a folder images to save the generation results, and a folder models to save the models (generators and discriminators):

$ mkdir datasets
$ mkdir images
$ mkdir models

Then, you can try to train the LR-GAN model on the datasets: 1) MNIST-ONE; 2) MNIST-TWO; 3) CUB-200; 4) CIFAR-10. The sample images are shown below:

In the datasets folder, create subfolders for all these datasets separately:

$ mkdir datasets/mnist-one
$ mkdir datasets/mnist-two
$ mkdir datasets/cub200
$ mkdir datasets/cifar10

Training

  1. MNIST-ONE. We first run experiments on MNIST-ONE, which can be downloaded from here. Unzip it into the datasets/mnist-one folder, and then run the following command:
$ python train.py \
      --dataset mnist-one \
      --dataroot datasets/mnist-one \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 50 \
      --session 1

Here, ntimestep specifies the number of recursive layers, e.g., 2 means one background layer and one foreground layer; imageSize is the size to which training images are scaled; maxobjscale controls the maximal object (foreground) scale: the larger the value, the smaller the object (an illustrative snippet follows the sample results below); session specifies the training session; niter specifies the number of training epochs. Below are random generation results using the model trained for 50 epochs:

From left to right, they are generated background images, foreground images, foreground masks and final images.
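
To make the meaning of maxobjscale concrete, here is a purely illustrative snippet (not the repository's actual code): with a normalized sampling grid, a larger scale factor in the affine matrix samples a wider area and therefore pastes a smaller foreground onto the canvas, which is why increasing --maxobjscale shrinks the objects.

import torch
import torch.nn.functional as F

s = 1.2  # e.g. --maxobjscale 1.2; larger s -> smaller pasted foreground
theta = torch.tensor([[[s, 0.0, 0.0],
                       [0.0, s, 0.0]]])
grid = F.affine_grid(theta, size=(1, 3, 32, 32), align_corners=False)
warped = F.grid_sample(torch.ones(1, 3, 32, 32), grid, align_corners=False)
print(warped.mean())  # < 1: the all-ones foreground now covers only part of the canvas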

  2. CUB200. We train on CUB200 at 64x64 resolution. Here is the processed dataset. Similarly, download it and unzip it into datasets/cub200. Then run the following command:
$ python train.py \
      --dataset cub200 \
      --dataroot datasets/cub200 \
      --ntimestep 2 \
      --imageSize 64 \
      --ndf 128 \
      --ngf 128 \
      --maxobjscale 1.2 \
      --niter 200 \
      --session 1

With the above command, we obtained a model similar to the one in our paper. Below are randomly generated images:

The layout is similar to MNIST-ONE. As we can see, the generator produces bird-shaped masks, which makes the final images sharper and cleaner.

  3. CIFAR-10. CIFAR-10 can be downloaded automatically by the PyTorch dataloader. We use two timesteps for the generation. To train the model, run:
$ python train.py \
      --dataset cifar10 \
      --dataroot datasets/cifar10 \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 100 \
      --session 1

Here are some randomly sampled generation results:

From left to right, they are generated background images, foreground images, foreground masks and final images. We can clearly see some horse-shaped, bird-shaped and boat-shaped masks, and the final generated images are sharper.

  4. MNIST-TWO. The images are 64x64 and contain two digits. We train the model using the following command:
$ python train.py \
      --dataset mnist-two \
      --dataroot datasets/mnist-two \
      --ntimestep 3 \
      --imageSize 64 \
      --maxobjscale 2 \
      --niter 50 \
      --session 1

The layout is the same as the one in our paper.

  5. LFW. We train on 64x64 images, which can be downloaded from here. Unzip it into the folder datasets/lfw. We train the model using the following command:
$ python train.py \
      --dataset lfw \
      --dataroot datasets/lfw \
      --ntimestep 2 \
      --imageSize 64 \
      --maxobjscale 1.3 \
      --niter 100 \
      --session 1

Below are the generation results:

The leftmost 8x8 grid shows the real images, followed by generated backgrounds, foregrounds, masks and final images.

Test LR-GAN

After the training, the checkpoints will be saved to models. You can append two more options (--netG and --evaluate) to the command used for training the model. Using cifar10 as an example:

$ python train.py \
      --dataset cifar10 \
      --dataroot datasets/cifar10 \
      --ntimestep 2 \
      --imageSize 32 \
      --maxobjscale 1.2 \
      --niter 100 \
      --session 1 \
      --netG models/cifar10_netG_s_1_epoch_100.pth \
      --evaluate True

Then you will get the generation results of the session-1, epoch-100 model in the images folder.
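
If you just want to confirm that a checkpoint loads correctly before running the evaluation, the small sketch below can help; it assumes the file stores a generator state_dict under the models/<dataset>_netG_s_<session>_epoch_<epoch>.pth naming pattern used above:

import torch

# Inspect the layer names and tensor shapes stored in a saved generator checkpoint.
state_dict = torch.load("models/cifar10_netG_s_1_epoch_100.pth", map_location="cpu")
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))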

lr-gan.pytorch's People

Contributors

jwyang, kwon-young


lr-gan.pytorch's Issues

About Affine Function

Thank you for your great work.
I am studying your code recently. However, I have not used PyTorch before, so I cannot figure out how the affine function (in modules/gridgen.py & functions/gridgen.py) works.
Could you explain the meaning of the variables 'grid', 'grad_output', and 'grad_input' respectively? Or show the workflow of AffineGridGen? It will be very helpful for me to understand the full operating workflow of LR-GAN. Thank you so much.
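
For readers with the same question, the snippet below approximates what an affine grid generator does using PyTorch's built-in functions; it is an illustration, not the code in modules/gridgen.py:

import torch
import torch.nn.functional as F

# 'grid' maps each output pixel to an (x, y) sampling location computed from the
# 2x3 affine matrix theta; F.affine_grid is a built-in analogue of AffineGridGen.
theta = torch.tensor([[[1.0, 0.0, 0.0],
                       [0.0, 1.0, 0.0]]], requires_grad=True)  # identity pose
grid = F.affine_grid(theta, size=(1, 3, 32, 32), align_corners=False)  # (1, 32, 32, 2)
out = F.grid_sample(torch.randn(1, 3, 32, 32), grid, align_corners=False)
out.sum().backward()
# In a hand-written autograd Function, 'grad_output' is the incoming gradient with
# respect to the produced grid, and 'grad_input' is the gradient returned with
# respect to theta.
print(theta.grad.shape)  # torch.Size([1, 2, 3])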

STNM using pytorch official function

Hi, I implemented STNM with the official PyTorch function grid_sample(), but the model can't converge.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class STNM(nn.Module):
    def __init__(self):
        super(STNM, self).__init__()

    def forward(self, canvas, fgimg, fggrid, fgmask):
        # Warp the foreground mask and image onto the canvas grid.
        mask = F.grid_sample(fgmask, fggrid)
        fg = F.grid_sample(fgimg, fggrid)

        # Composite: mask * fg + (1 - mask) * canvas.
        # torch.addcmul(input, tensor1, tensor2) returns input + tensor1 * tensor2,
        # so it must start from zeros and its return value must be kept.
        tmp1 = torch.addcmul(torch.zeros_like(fg.data), mask.data, fg.data)
        ng_mask = torch.add(-1 * mask, 1)
        tmp2 = torch.addcmul(torch.zeros_like(canvas.data), ng_mask.data, canvas.data)

        # Note: using .data and re-wrapping the result in a fresh Variable detaches
        # the output from the graph, so no gradient reaches the generator through
        # this module, which by itself would prevent convergence.
        return Variable(tmp1 + tmp2, requires_grad=True)
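
For comparison, a minimal sketch that keeps the whole composition inside the autograd graph (no .data access and no re-wrapping in a fresh Variable) could look like this; it is an illustration, not the repository's STNM module:

import torch.nn as nn
import torch.nn.functional as F

class SoftSTNM(nn.Module):
    # Warp the foreground image and mask with grid_sample, then alpha-composite
    # them onto the canvas: mask * fg + (1 - mask) * canvas.
    def forward(self, canvas, fgimg, fggrid, fgmask):
        mask = F.grid_sample(fgmask, fggrid)
        fg = F.grid_sample(fgimg, fggrid)
        return mask * fg + (1.0 - mask) * canvas  # gradients flow to all inputs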

problem when importing stnm

Hi Jianwei, thanks a lot for sharing the work!

I met a problem when importing stnm:

ImportError: dlopen(lr-gan.pytorch-master/_ext/stnm/_stnm.so, 2): no suitable image found. Did find: /lr-gan.pytorch-master/_ext/stnm/_stnm.so: unknown file type, first eight bytes: 0x7F 0x45 0x4C 0x46 0x02 0x01 0x01 0x00

I tried both on my Mac and on a Linux server with CUDA installed, and this problem was still raised. I'm wondering if you have any ideas about this? I also tried to build your stnm.pytorch project, where I met this problem too.

thanks!

Saved image seems weird.

Hi @jwyang, I downloaded your source code and ran it without any modification. But the generated images and the real_cpu images saved by vutils.save_image look weird; they appear too dark. Could you tell me what I should do to fix this?
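
One common cause (an assumption, since the saving code may differ): if the generator output is in [-1, 1] from a tanh, saving it directly makes the image look dark; rescale it or pass normalize=True to vutils.save_image:

import torch
import torchvision.utils as vutils

fake = torch.tanh(torch.randn(64, 3, 32, 32))                   # stand-in for a tanh generator output
vutils.save_image(fake, "fake_normalized.png", normalize=True)  # rescale to [0, 1] before saving
# equivalently: vutils.save_image((fake + 1) / 2, "fake_rescaled.png")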

why doesn't the background layer learn everything?

It seems the background layer has enough capacity to model complex images in your paper's figures, but I didn't find any regularization term for G about this in your code. It's amazing; what do you think the possible reason is?

lr-gan.pytorch update

Hello,
First of all, thank you for your open-source implementation of the lr-gan model.

My end goal is to build on this model in order to do unsupervised detection, and I have a few questions:

  • Could you update the dead link to the mnist1 and mnist2 datasets (or maybe publish the script you used to generate them)?
  • There is a discrepancy between what is explained in the original paper and the actual implementation regarding the use of the LSTM layer when generating the background.
    • In the paper, the input noise is fed to the LSTM layer before being used by the background generator.
    • In the implementation (train.py:L290), you note that you feed the noise directly, without explaining why.

I opened a pull request to contribute back: it updates the code to Python 3 and the most recent version of PyTorch, and corrects some minor bugs in the data handling.
I've also implemented a pytorch version of your STNM module.
Could you check if the implementation is correct?

Finally, would you be interested in accepting pull requests concerning the use of the model for detection (or what you call conditional image generation in your paper)? The other issue where you talked about detection is quite helpful for implementing the encoder and the reconstruction loss, but by actually implementing the architecture I could make sure that I'm doing the correct thing.

Thank you very much for your help!

How to reproduce the category-specific models in Section 5.4?

Hi @jwyang, I'm very impressed by your work, and thank you very much for publishing the code.
I have a few questions about Section 5.4 on how to train the category-specific models:

  1. Did you insert the category label into z before inputting it to the LSTM?
  2. Did you feed the category label to the discriminator together with the input image (fake and real)?

Could you please give me some advice? (For reference, one common conditioning scheme is sketched below.)
Thanks a lot.
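
For context, the sketch below shows one common way to condition both G and D on a class label (concatenating a one-hot label with the noise vector, and stacking label maps with the discriminator input); it only illustrates the two questions above and is not necessarily what the paper did:

import torch
import torch.nn.functional as F

nz, num_classes, batch = 100, 10, 16
z = torch.randn(batch, nz)
labels = torch.randint(0, num_classes, (batch,))
onehot = F.one_hot(labels, num_classes).float()

z_cond = torch.cat([z, onehot], dim=1)                        # question 1: label appended to z
label_maps = onehot[:, :, None, None].expand(-1, -1, 32, 32)  # question 2: label maps for D
img = torch.randn(batch, 3, 32, 32)                           # stand-in for a real or fake image
d_input = torch.cat([img, label_maps], dim=1)                 # (batch, 3 + num_classes, 32, 32)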

How to reproduce the results in the paper?

Hi, I'm very impressed by your work LR-GAN. So I'm trying to reproduce the results (CUB200, CIFAR10) in the paper, but it doesn't learn at all.

  • When I train the algorithm with the given parameters in the README (CUB200, CIFAR10):
    • The discriminator loss goes to zero, while the generator can't generate any meaningful pattern with a fixed seed.
    • I've tried several different learning rates for the discriminator and generator, but it didn't work either.
  • If I don't use the foreground & mask generator and the STN (i.e., with time step = 1), it produces good results like DCGAN after training.

Could you give some information for reproducing the results? I'm using PyTorch version '0.2.0_2' on Python 3.6 and NVIDIA driver 367.57 with CUDA 8.0.

Thank you.

Question on conditional image generation

With the help of your advice, I reproduced the main part of the paper and succeeded in training your model on ImageNet with specific classes (minibus, dog).

Now, I'm trying to reproduce Section 6.8 (conditional image generation) in the paper, to use this architecture for unsupervised/weakly-supervised object segmentation.
But I'm having a hard time finding a trainable setup. Can you share the setting used in the experiment?

[Encoder architecture]

  • Use the same architecture as the discriminator.
  • Given an image, the encoder returns a vector with the same dimension as the random noise vector (nz).
  • To generate the background vector, I use Encoder(real image) = background vector.
  • To generate the foreground vector, I use Encoder(real image - generated image) = foreground vector.

[Optimization method]

  • The original code (1) maximizes log(D(x)) + log(1 - D(G(z))) to update the discriminator, then (2) maximizes log(D(G(z))) to update the generator.
  • I tried three options in the optimization: (1) additionally update the auto-encoder + generator part; (2) update the auto-encoder part within the maximize-log(D(G(z))) step; (3) additionally update only the auto-encoder part. But none of these variants work.

Thanks.

Version of software

Hello, could you tell us the versions of the software on which you ran this model?
CUDA, NVIDIA driver, PyTorch, torchvision, Python version, and operating system?
