
pytorch-gan-collections's Introduction

Collections of GANs

PyTorch implementations of basic unsupervised GANs on CIFAR10.

More details about calculating the Inception Score and FID with PyTorch can be found in pytorch_gan_metrics.
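As a rough sketch of how that library is typically used to score a batch of generated images (the function name and the expected [0, 1] image range are assumptions based on pytorch_gan_metrics' documentation, not verified against this repo):

    import torch
    # Assumed API of pytorch_gan_metrics; check its README for the exact names.
    from pytorch_gan_metrics import get_inception_score_and_fid

    # images: [N, 3, 32, 32] tensor scaled to [0, 1]; a random placeholder here
    images = torch.rand(50000, 3, 32, 32)
    (IS, IS_std), FID = get_inception_score_and_fid(
        images, './stats/cifar10.train.npz')  # precomputed CIFAR10 train stats
    print(IS, IS_std, FID)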

Models

  • DCGAN
  • WGAN
  • WGAN-GP
  • SN-GAN

Requirements

  • Install python packages
    pip install -U pip setuptools
    pip install -r requirements.txt

Results

The FID is calculated from 50k generated images and the CIFAR10 training set.

Model            Dataset  Inception Score  FID
DCGAN            CIFAR10  6.01 (0.05)      42.72
WGAN(CNN)        CIFAR10  6.62 (0.09)      40.03
WGAN-GP(CNN)     CIFAR10  7.66 (0.10)      19.83
WGAN-GP(ResNet)  CIFAR10  7.95 (0.14)      16.95
SNGAN(CNN)       CIFAR10  7.84 (0.12)      17.81
SNGAN(ResNet)    CIFAR10  8.31 (0.10)      14.32

Examples

  • DCGAN

    dcgan_gif dcgan_png

  • WGAN(CNN)

    wgan_gif wgan_png

  • WGAN-GP(CNN)

    wgangp_cnn_gif wgangp_cnn_png

  • WGAN-GP(ResNet)

    wgangp_res_gif wgangp_res_png

  • SNGAN(CNN)

    sngan_cnn_gif sngan_cnn_png

  • SNGAN(ResNet)

    sngan_res_gif sngan_res_png

Reproduce

  • Download cifar10.train.npz for calculating FID. Then create a stats folder and put the npz file in it (a quick sanity check of the file is sketched after the tree):

    stats
    └── cifar10.train.npz
    
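    If you want to sanity-check the downloaded file, a minimal sketch that makes no assumptions about its exact contents is:

    import numpy as np

    # list the arrays stored in the precomputed FID statistics file
    stats = np.load('./stats/cifar10.train.npz')
    for name in stats.files:
        print(name, stats[name].shape)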
  • Train from scratch

    Each method is implemented in its own file for readability; a note on overriding config flags follows the commands below.

    # DCGAN
    python dcgan.py --flagfile ./configs/DCGAN_CIFAR10.txt
    # WGAN(CNN)
    python wgan.py --flagfile ./configs/WGAN_CIFAR10_CNN.txt
    # WGAN-GP(CNN)
    python wgangp.py --flagfile ./configs/WGANGP_CIFAR10_CNN.txt
    # WGAN-GP(ResNet)
    python wgangp.py --flagfile ./configs/WGANGP_CIFAR10_RES.txt
    # SNGAN(CNN)
    python sngan.py --flagfile ./configs/SNGAN_CIFAR10_CNN.txt
    # SNGAN(ResNet)
    python sngan.py --flagfile ./configs/SNGAN_CIFAR10_RES.txt
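
    Since the configs are plain --flagfile text files, individual flags can also be overridden on the command line after the --flagfile argument. For example (flag names such as --logdir and --seed are taken from the SNGAN config shown in the issues below; their availability in every script is assumed rather than verified):

    # illustrative override of a few flags from the config file
    python sngan.py --flagfile ./configs/SNGAN_CIFAR10_CNN.txt \
        --logdir ./logs/SNGAN_CIFAR10_CNN_seed1 --seed 1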

Learning Curves

inception_score_curve fid_curve

Change Log

  • 2022-01-10

    • Update PyTorch to 1.10.1 and CUDA to 11.3
    • Use pytorch_gan_metrics to calculate FID and Inception Score
    • Use 50k generated images and the CIFAR10 train set to calculate FID
    • Fix default parameters, especially for wgan.py
  • 2021-04-16

    • Update PyTorch to 1.8.1
    • Move metrics to a submodule.
    • Evaluate FID on CIFAR10 test set instead of training set.
    • Fix cifar10.test.npz download link and sample images.

pytorch-gan-collections's People

Contributors

w86763777


pytorch-gan-collections's Issues

If I only train for 200 or 500 steps, should I change 'num_images' as well?

Hi, Dear Yi-Lun,
I want to train for only 200 or 500 steps because 10k or 50k steps take a long time to run.
I changed sample_step=50 and total_steps=200.
Should I change num_images=50000 to num_images=50? I guess it should be smaller than the number of training steps, or does it not matter?
Thank you in advance :)

--arch=cnn32
--batch_size=128
--dataset=cifar10
--fid_cache=./stats/cifar10.train.npz
--logdir=./logs/SNGAN_CIFAR10_CNN
--loss=hinge
--lr_D=0.0002
--lr_G=0.0002
--n_dis=1
--num_images=50
--record
--sample_step=50
--sample_size=64
--seed=0
--total_steps=500
--z_dim=100

Performance replication for WGAN

Hello, thanks for creating this repo; I found it very helpful for playing with basic GAN architectures. I tried to replicate the reported performance by running the code with the following setups:

python wgan.py
python wgan.py --arch=cnn32 --logdir=./logs/WGAN_CIFAR10_CNN32/
python wgangp.py

I got the FID plot below. The FID of WGAN on CIFAR10 is ~80, which is much larger than the reported 33.27, and the FID of WGAN(RES) is even worse. I was wondering how I could replicate the 33 FID. I am also confused about why the ResNet performs worse than the CNN. It would be helpful if you could give any hints. Thanks!

[attached FID plot]

If I want to change the neural network, how can I do it?

Hi, I want to test the neural network below:

[CNN architecture image]

I tried to change the model file as below:

import torch
import torch.nn as nn


class Generator(nn.Module):
    def __init__(self, z_dim, M):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.main = nn.Sequential(
            nn.ConvTranspose2d(self.z_dim, 256, M, 1, 0, bias=False),  # M x M
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, z):
        return self.main(z.view(-1, self.z_dim, 1, 1))


class Discriminator(nn.Module):
    def __init__(self, M):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            # 32
            nn.Conv2d(3, 32, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # 16
            nn.Conv2d(32, 64, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(64),
            # 8
            nn.Conv2d(64, 128, 5, 2, 2, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(128),
            # 4
            nn.Conv2d(128, 10, 5, 2, 2, bias=False),  # 10 x (M//16) x (M//16)
            nn.ReLU(True),
        )
        self.linear = nn.Linear(M // 16 * M // 16 * 10, 1)

    def forward(self, x):
        x = self.main(x)
        x = torch.flatten(x, start_dim=1)
        x = self.linear(x)
        return x


class Generator32(Generator):
    def __init__(self, z_dim):
        super().__init__(z_dim, M=3)


class Generator48(Generator):
    def __init__(self, z_dim):
        super().__init__(z_dim, M=4)


class Discriminator32(Discriminator):
    def __init__(self):
        super().__init__(M=16)


class Discriminator48(Discriminator):
    def __init__(self):
        super().__init__(M=48)

It seems I have to change M and the other parameters. How should I do that? Thank you in advance.
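
For reference, here is a small sketch of how M relates to the image size in the snippet above; it is inferred from that snippet alone, not from the repository's documentation:

def generator_output_size(M, n_upsamples=3):
    # the first ConvTranspose2d (kernel M, stride 1) maps 1x1 -> MxM;
    # each following stride-2 ConvTranspose2d doubles the spatial size
    return M * 2 ** n_upsamples

assert generator_output_size(4) == 32   # Generator32 would need M=4
assert generator_output_size(6) == 48   # Generator48 would need M=6

# In the Discriminator, M plays the role of the input resolution: four
# stride-2 convolutions shrink it to M // 16, which matches
# nn.Linear(M // 16 * M // 16 * 10, 1).
assert (32 // 16) * (32 // 16) * 10 == 40   # 32x32 input -> 2x2x10 features
assert (48 // 16) * (48 // 16) * 10 == 90   # 48x48 input -> 3x3x10 features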

Unable to allocate 2.57 GiB for an array with shape (2764800000,) and data type uint8

everything = np.fromfile(f, dtype=np.uint8)

numpy.core._exceptions.MemoryError: Unable to allocate 2.57 GiB for an array with shape (2764800000,) and data type uint8

Hi!
When I tried to use the STL-10 dataset, the error above happened.
I have changed the file as below:

flags.DEFINE_enum('dataset', 'stl10', ['cifar10', 'stl10'], "dataset")
flags.DEFINE_string('fid_cache', './stats/stl10.unlabeled.48.npz', 'FID cache')

But it seems to say there is not enough memory. What should I do? Please advise.
Thank you.
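
One common workaround for this MemoryError, sketched under the assumption that the failing np.fromfile call is reading the STL-10 unlabeled binary (the path below is a hypothetical local one), is to memory-map the file instead of loading it all into RAM at once:

import numpy as np

# memory-map the STL-10 binary instead of np.fromfile, so the 2.57 GiB
# array is backed by disk rather than held fully in memory
path = './data/stl10_binary/unlabeled_X.bin'  # hypothetical local path
everything = np.memmap(path, dtype=np.uint8, mode='r')
images = everything.reshape(-1, 3, 96, 96)    # 3x96x96 uint8 images
print(images.shape)                           # e.g. (100000, 3, 96, 96)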
