sinha-raunak / gan-toolkit

This project forked from ibm/gan-toolkit

License: MIT License

Python 100.00%

gan-toolkit

The aim of the toolkit is to provide a highly flexible, no-code way of implementing GAN models. By providing the details of a GAN model in an intuitive config file or as command-line arguments, the training code is generated automatically. With little or no prior knowledge of GANs, users can experiment with different GAN formulations and mix and match GAN modules to create novel models.

Modular GAN Architecture

[Figure: modular GAN architecture diagram]

Quick Start

  1. (Optional) If you want to set up an Anaconda environment

    a. Install Anaconda from here

    b. Create a conda environment

    $ conda create -n gantoolkit python=3.6 anaconda

    c. Activate the conda environment

    $ source activate gantoolkit
  2. Clone the code

    $ git clone https://github.com/IBM/gan-toolkit
  3. Install all the requirements. Tested for Python 3.5.x+

    $ pip install -r requirements.txt
  4. Train the model using a configuration file. (Many samples are provided in the configs folder)

    $ cd agant
    $ python main.py --config configs/gan_gan.json
  5. Default input and output paths (override these paths in the config file)

    logs/ : training logs

    saved_models/ : saved trained models

    train_results/ : saved all the intermediate generated images

    datasets/ : input dataset path
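The default paths listed in step 5 can be overridden from the config file. A minimal sketch of writing such a config (the file name my_gan.json, the directory values, and this exact key set are illustrative assumptions; see "Config File Structure and Details" for the documented keys):

```python
import json

# Illustrative only: a config that overrides the default output paths.
config = {
    "generator": {"choice": "gan"},
    "discriminator": {"choice": "gan"},
    "data_path": "datasets/dataset1.p",
    "result_path": "my_results/",      # instead of train_results/
    "save_model_path": "my_models/",   # instead of saved_models/
    "performance_log": "my_logs/",     # instead of logs/
}

with open("my_gan.json", "w") as f:
    json.dump(config, f, indent=4)
```

The resulting file would then be passed to training via `python main.py --config my_gan.json`.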

Implemented GAN Models

  1. Vanilla GAN: Generative Adversarial Networks (Goodfellow et al., 2014)

  2. C-GAN: Conditional Generative Adversarial Networks (Mirza et al., 2014)

  3. DC-GAN: Deep Convolutional Generative Adversarial Network (Radford et al., 2016)

  4. Cycle-GAN: Cycle-Consistent Adversarial Networks (Zhu et al., 2017)

  5. W-GAN: Wasserstein GAN (Arjovsky et al., 2017)

  6. W-GAN-GP: Improved Training of Wasserstein GANs (Gulrajani et al., 2017)
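For reference, the minimax objective behind the vanilla GAN (model 1) can be sketched numerically. This is a self-contained illustrative snippet, not toolkit code:

```python
import math

def discriminator_loss(d_real, d_fake):
    """Binary cross-entropy form of the discriminator objective:
    maximize log D(x) + log(1 - D(G(z))), i.e. minimize the negative."""
    return -(math.log(d_real) + math.log(1.0 - d_fake))

def generator_loss(d_fake):
    """Non-saturating generator objective: maximize log D(G(z))."""
    return -math.log(d_fake)

# A discriminator that outputs 0.5 everywhere sits at the theoretical
# equilibrium, where the discriminator loss equals 2 * log 2 (~1.386).
print(discriminator_loss(0.5, 0.5))
```

As the generator improves and D(G(z)) rises toward 1, `generator_loss` falls, which is what the training loop drives toward.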

Config File Structure and Details

The config file is a set of key-value pairs in JSON format. A collection of sample config files is provided here.

The basic structure of the config JSON file is as follows:

    { 
        "generator":{
            "choice":"gan"
        },
        "discriminator":{
            "choice":"gan"
        },
        "data_path":"datasets/dataset1.p",
        "metric_evaluate":"MMD"
    }

The details of the config files are provided here:

  • generator: < json > value which contains the details of the generator module. The available parameters and possible values are:

    • choice: ["gan", "cgan", "dcgan", "cycle_gan", "wgan", "wgan_gp"] // choice of the generator module
    • input_shape: < int > // row size of the input image
    • channels: < int > // number of channels in the input image
    • latent_dim: < int > // the size of the input random vector
    • input: "[(g_channels, g_input_shape, g_input_shape), g_latent_dim]" // input data in the given format
    • loss: ["Mean", "MSE", "BCE", "NLL"] // choice of the loss function
    • optimizer: < json > value of the optimizer and its parameters
      • choice: ["Adam", "RMSprop"]
      • learning_rate: < float > // learning rate of the optimizer
      • b1: < float > // first coefficient used for computing running averages of the gradient and its square (Adam only)
      • b2: < float > // second coefficient used for computing running averages of the gradient and its square (Adam only)
  • discriminator: < json > value which contains the details of the discriminator module. The available parameters and possible values are:

    • choice: ["gan", "cgan", "dcgan", "cycle_gan", "wgan", "wgan_gp", "seq_gan"] // choice of the discriminator module
    • input_shape: < int > // row size of the input image
    • channels: < int > // number of channels in the input image
    • input: "[(g_channels, g_input_shape, g_input_shape), g_latent_dim]" // input data in the given format
    • loss: ["Mean", "MSE", "BCE", "NLL"] // choice of the loss function
    • optimizer: < json > value of the optimizer and its parameters
      • choice: ["Adam", "RMSprop"]
      • learning_rate: < float > // learning rate of the optimizer
      • b1: < float > // first coefficient used for computing running averages of the gradient and its square (Adam only)
      • b2: < float > // second coefficient used for computing running averages of the gradient and its square (Adam only)
  • data_path: "path/of/data/in/local/system"

  • metric_evaluate: ["MMD", "FID"] // evaluation metric: maximum mean discrepancy (MMD) or Fréchet inception distance (FID)

  • GAN_model: < json > format providing the meta details for training the GAN model

    • epochs: < int > // number of epochs for training
    • mini_batch_size: < int > // size of each mini batch
    • clip_value: < float > // weight-clipping threshold used by wgan
    • n_critic: < int > // number of critic (discriminator) updates per generator update, used by wgan
    • lambda_gp: < int > // gradient-penalty coefficient for wgan_gp
    • data_label: < int > // the parameter required for cgan
    • classes: < int > // the number of classes in the given real data
    • seq: < binary > // 0 or 1 on whether the generation is single value or sequential. Used for seq_gan
  • result_path: "path/to/write/resulting/images"

  • save_model_path: "path/to/write/trained/model"

  • performance_log: "path/to/write/training/logs"

  • sample_interval: < int > // how often (in iterations) generated images are written
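The MMD option of metric_evaluate refers to maximum mean discrepancy, which compares generated and real samples through a kernel. A pure-Python sketch of the (biased) RBF-kernel estimator on scalar samples, for illustration only and unrelated to the toolkit's internal implementation:

```python
import math

def rbf_kernel(x, y, sigma=1.0):
    # Gaussian (RBF) kernel on scalars; real use operates on image vectors.
    return math.exp(-((x - y) ** 2) / (2.0 * sigma ** 2))

def mmd_squared(xs, ys, sigma=1.0):
    """Biased estimator of MMD^2 between sample sets xs and ys."""
    k_xx = sum(rbf_kernel(a, b, sigma) for a in xs for b in xs) / (len(xs) ** 2)
    k_yy = sum(rbf_kernel(a, b, sigma) for a in ys for b in ys) / (len(ys) ** 2)
    k_xy = sum(rbf_kernel(a, b, sigma) for a in xs for b in ys) / (len(xs) * len(ys))
    return k_xx + k_yy - 2.0 * k_xy

real = [0.0, 0.1, 0.2]
print(mmd_squared(real, real))             # 0.0: identical samples
print(mmd_squared(real, [5.0, 5.1, 5.2]))  # clearly positive: distant samples
```

A lower MMD between generated and real samples indicates the generator's distribution is closer to the data distribution.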

Comparison with Other Toolkits

Recognizing the importance of ease of training GAN models, a few other toolkits are available in the open-source domain, such as Keras-GAN, TF-GAN, and PyTorch-GAN. However, our gan-toolkit has the following advantages:

  • Highly modularized representation of GAN model for easy mix-and-match of components across architectures. For instance, one can use the generator component from DCGAN and the discriminator component from CGAN, with the training process of WGAN.

  • An abstract representation of GAN architecture to provide multi-library support. Currently, we provide PyTorch support for the provided config file; in the future, we plan to support Keras and TensorFlow as well. Thus, the abstract representation is library agnostic.

  • A code-free way of designing GAN models. A simple JSON file defines a GAN architecture, and there is no need to write any training code to train the GAN model.
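The mix-and-match example above (DCGAN generator, CGAN discriminator, WGAN-style training) corresponds to a config along these lines. This is a sketch: the documented keys come from "Config File Structure and Details", but the exact combination shown (including selecting WGAN training purely through its hyperparameters) is an assumption for illustration:

```python
import json

# Hypothetical mixed config: generator from DCGAN, discriminator from CGAN,
# with WGAN-style training knobs (n_critic, clip_value).
mixed = {
    "generator": {"choice": "dcgan"},
    "discriminator": {"choice": "cgan"},
    "GAN_model": {"n_critic": 5, "clip_value": 0.01},
    "data_path": "datasets/dataset1.p",
}

print(json.dumps(mixed, indent=4))
```

Because every component is selected by a "choice" key, swapping architectures is a one-line change rather than a code rewrite.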

TO-DO

Immediate tasks:

  • Improve the performance of seq-GAN
  • Implement a textGAN for text-based applications
  • Study and implement better transfer-learning approaches
  • Explore different weight initializations for GANs
  • Check whether the optimizer state also needs to be moved to CUDA
  • Check the input for generator and discriminator to conf_data
  • Find a smart way to check the size of the reward

Long term tasks:

  • Implement driver and support for Keras and TensorFlow
  • Implement more popular GAN models in this framework
  • Implement more metrics to evaluate different GAN models
  • Support multimodal data generation for GAN frameworks

Credits

We would like to thank Raunak Sinha (email) who interned with us during summer 2018 and contributed heavily to this toolkit.
