The aim of the toolkit is to provide a highly flexible, no-code way of implementing GAN models. Given the details of a GAN model, either in an intuitive config file or as command-line arguments, the toolkit generates the code needed to train that model. With little or no prior knowledge of GANs, users can experiment with different GAN formulations and mix and match GAN modules to create novel models.
-
(Optional) If you want to set up an Anaconda environment:
a. Install Anaconda.
b. Create a conda environment:
$ conda create -n gantoolkit python=3.6 anaconda
c. Activate the conda environment (on conda 4.4 and later, conda activate gantoolkit also works):
$ source activate gantoolkit
-
Clone the code and move into the repository:
$ git clone https://github.com/IBM/gan-toolkit
$ cd gan-toolkit
-
Install the requirements (tested with Python 3.5.x and later):
$ pip install -r requirements.txt
-
Train the model using a configuration file (many samples are provided in the configs folder):
$ cd agant
$ python main.py --config configs/gan_gan.json
-
Default input and output paths (override these paths in the config file):
logs/ : training logs
saved_models/ : saved trained models
train_results/ : all intermediate generated images
datasets/ : input dataset path
-
Vanilla GAN: Generative Adversarial Nets (Goodfellow et al., 2014)
-
C-GAN: Conditional Generative Adversarial Nets (Mirza and Osindero, 2014)
-
DC-GAN: Deep Convolutional Generative Adversarial Network (Radford et al., 2016)
-
Cycle-GAN: Cycle-Consistent Adversarial Networks (Zhu et al., 2017)
-
W-GAN: Wasserstein GAN (Arjovsky et al., 2017)
-
W-GAN-GP: Improved Training of Wasserstein GANs (Gulrajani et al., 2017)
The config file is a set of key-value pairs in JSON format. A collection of sample config files is provided in the configs folder.
The basic structure of the config JSON file is as follows:
{
"generator":{
"choice":"gan"
},
"discriminator":{
"choice":"gan"
},
"data_path":"datasets/dataset1.p",
"metric_evaluate":"MMD"
}
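To illustrate how such a file can be consumed through a --config flag (mirroring the python main.py --config command above), here is a minimal sketch. It is not the toolkit's own loader: the helper names load_config and main, and the choice of generator, discriminator and data_path as mandatory keys, are assumptions made for this example.

```python
import argparse
import json
from pathlib import Path

# Keys treated as mandatory in this sketch (an assumption, not the toolkit's rule).
REQUIRED_KEYS = ("generator", "discriminator", "data_path")


def load_config(path):
    """Read a JSON config file and return it as a dict (hypothetical helper)."""
    config = json.loads(Path(path).read_text(encoding="utf-8"))
    missing = [key for key in REQUIRED_KEYS if key not in config]
    if missing:
        raise KeyError(f"config is missing required keys: {missing}")
    return config


def main(argv=None):
    """Parse --config in the style of `python main.py --config <file>`."""
    parser = argparse.ArgumentParser(description="Load a GAN config file")
    parser.add_argument("--config", required=True, help="path to the JSON config file")
    args = parser.parse_args(argv)
    config = load_config(args.config)
    # Report which generator and discriminator modules were selected.
    print("generator:", config["generator"]["choice"])
    print("discriminator:", config["discriminator"]["choice"])
    return config
```

For example, main(["--config", "configs/gan_gan.json"]) would load one of the sample files. Checking for required keys at load time surfaces a malformed config immediately rather than partway through training.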
The config keys and their possible values are:
-
generator : <json> value containing the details of the generator module. The available parameters and possible values are:
    choice : ["gan", "cgan", "dcgan", "cycle_gan", "wgan", "wgan_gp"] // choice of the generator module
    input_shape : <int> // row size of the input image
    channels : <int> // number of channels in the input image
    latent_dim : <int> // size of the input random vector
    input : "[(g_channels, g_input_shape, g_input_shape), g_latent_dim]" // format of the input data
    loss : ["Mean", "MSE", "BCE", "NLL"] // choice of the loss function
    optimizer : <json> value describing the optimizer and its parameters
        choice : ["Adam", "RMSprop"]
        learning_rate : <float> // learning rate of the optimizer
        b1 : <float> // coefficient for the running average of the gradient (Adam's beta1); used only by Adam
        b2 : <float> // coefficient for the running average of the squared gradient (Adam's beta2); used only by Adam
-
discriminator : <json> value containing the details of the discriminator module. The available parameters and possible values are:
    choice : ["gan", "cgan", "dcgan", "cycle_gan", "wgan", "wgan_gp", "seq_gan"] // choice of the discriminator module
    input_shape : <int> // row size of the input image
    channels : <int> // number of channels in the input image
    input : "[(g_channels, g_input_shape, g_input_shape), g_latent_dim]" // format of the input data
    loss : ["Mean", "MSE", "BCE", "NLL"] // choice of the loss function
    optimizer : <json> value describing the optimizer and its parameters
        choice : ["Adam", "RMSprop"]
        learning_rate : <float> // learning rate of the optimizer
        b1 : <float> // coefficient for the running average of the gradient (Adam's beta1); used only by Adam
        b2 : <float> // coefficient for the running average of the squared gradient (Adam's beta2); used only by Adam
-
data_path : "path/of/data/in/local/system"
-
metric_evaluate : ["MMD", "FID"] // maximum mean discrepancy (MMD) or Fréchet Inception Distance (FID)
-
GAN_model : <json> value providing the meta details for training the GAN model
    epochs : <int> // number of training epochs
    mini_batch_size : <int> // size of each mini-batch
    clip_value : <float> // maximum absolute value for weight clipping (as in WGAN)
    n_critic : <int> // number of critic (discriminator) updates per generator update, used by wgan
    lambda_gp : <float> // gradient-penalty coefficient, used by wgan_gp
    data_label : <int> // parameter required for cgan
    classes : <int> // number of classes in the given real data
    seq : <binary> // 0 or 1: whether the generated output is a single value or a sequence; used for seq_gan
-
result_path : "path/to/write/resulting/images"
-
save_model_path : "path/to/write/trained/model"
-
performance_log : "path/to/write/training/logs"
-
sample_interval : "frequency/to/write/resulting/images" // how often resulting images are written
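Because every module and optimizer key above has a closed set of allowed values, a config can be sanity-checked before training starts. The following is a minimal, standard-library-only sketch and not part of the toolkit: the function names are hypothetical, and optimizer_kwargs assumes the optimizer section maps onto PyTorch's torch.optim.Adam (lr, betas) and torch.optim.RMSprop (lr) keyword arguments in the obvious way.

```python
# Allowed values copied from the config reference above.
GENERATOR_CHOICES = {"gan", "cgan", "dcgan", "cycle_gan", "wgan", "wgan_gp"}
DISCRIMINATOR_CHOICES = GENERATOR_CHOICES | {"seq_gan"}
LOSSES = {"Mean", "MSE", "BCE", "NLL"}
OPTIMIZERS = {"Adam", "RMSprop"}
METRICS = {"MMD", "FID"}


def validate_module(section, allowed_choices, name):
    """Return a list of problems found in a generator/discriminator section."""
    problems = []
    if section.get("choice") not in allowed_choices:
        problems.append(f"{name}.choice must be one of {sorted(allowed_choices)}")
    if "loss" in section and section["loss"] not in LOSSES:
        problems.append(f"{name}.loss must be one of {sorted(LOSSES)}")
    opt = section.get("optimizer")
    if opt is not None and opt.get("choice") not in OPTIMIZERS:
        problems.append(f"{name}.optimizer.choice must be one of {sorted(OPTIMIZERS)}")
    return problems


def validate_config(config):
    """Check the module, loss, optimizer and metric choices of a whole config."""
    problems = validate_module(config.get("generator", {}), GENERATOR_CHOICES, "generator")
    problems += validate_module(
        config.get("discriminator", {}), DISCRIMINATOR_CHOICES, "discriminator"
    )
    metric = config.get("metric_evaluate")
    if metric is not None and metric not in METRICS:
        problems.append(f"metric_evaluate must be one of {sorted(METRICS)}")
    return problems


def optimizer_kwargs(opt):
    """Map an optimizer section to torch.optim-style keyword arguments (hypothetical)."""
    kwargs = {"lr": float(opt["learning_rate"])}
    if opt["choice"] == "Adam":
        # b1/b2 become Adam's betas tuple; fall back to PyTorch's defaults (0.9, 0.999).
        kwargs["betas"] = (float(opt.get("b1", 0.9)), float(opt.get("b2", 0.999)))
    return kwargs
```

Returning a list of problems rather than raising on the first one lets a user fix every mistake in a config in a single pass.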
Recognizing how important it is to make GAN models easy to train, several other open-source toolkits are available, such as Keras-GAN, TF-GAN, and PyTorch-GAN. However, our gan-toolkit has the following advantages:
-
A highly modularized representation of GAN models for easy mix-and-match of components across architectures. For instance, one can use the generator component from DCGAN and the discriminator component from CGAN, with the training process of WGAN.
-
An abstract representation of the GAN architecture to provide multi-library support. Currently, we provide PyTorch support for the given config file, and in the future we plan to support Keras and TensorFlow as well. Thus, the abstract representation is library-agnostic.
-
A code-free way of designing GAN models. A simple JSON file is enough to define a GAN architecture, and no training code needs to be written to train the GAN model.
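As a concrete illustration of the mix-and-match idea, the following sketch builds a config with a DCGAN generator and a CGAN discriminator plus WGAN-style training settings. It assumes the WGAN training process is driven by the wgan-specific GAN_model keys (n_critic, clip_value) together with RMSprop; the keys the toolkit actually uses to select a training process may differ, so compare against the sample files in the configs folder. The numeric values for n_critic, clip_value, learning rate and batch size follow the defaults in the WGAN paper; image size, latent size, epochs and classes are illustrative, and data_label (listed as required for cgan) is omitted because its expected value is not documented here.

```python
import json


def mixed_config(data_path="datasets/dataset1.p"):
    """Build a hypothetical mix-and-match config: DCGAN generator,
    CGAN discriminator, WGAN-style training settings."""
    # RMSprop with lr = 5e-5, as in the WGAN paper's default setting.
    rmsprop = {"choice": "RMSprop", "learning_rate": 0.00005}
    return {
        "generator": {
            "choice": "dcgan",
            "input_shape": 28,  # illustrative: 28x28 images
            "channels": 1,
            "latent_dim": 100,
            "optimizer": dict(rmsprop),  # separate copy per module
        },
        "discriminator": {
            "choice": "cgan",
            "input_shape": 28,
            "channels": 1,
            "optimizer": dict(rmsprop),
        },
        "data_path": data_path,
        "metric_evaluate": "MMD",
        "GAN_model": {
            "epochs": 50,
            "mini_batch_size": 64,
            "n_critic": 5,       # critic updates per generator update
            "clip_value": 0.01,  # weight-clipping bound
            "classes": 10,       # assumed number of classes, used by cgan
        },
    }


def write_config(config, path):
    """Write the config as pretty-printed JSON so it can be passed via --config."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
```

The written file could then be passed to the trainer, e.g. python main.py --config configs/dcgan_cgan_wgan.json, where the file name is hypothetical.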
Immediate tasks:
- Improve the performance of seq-GAN
- Implement a textGAN for text-based applications
- Study and implement better transfer learning approaches
- Try different weight initializations for GANs
- Check whether moving the optimizer to CUDA is also necessary
- Check the input for generator and discriminator to conf_data
- Find a smart way to check the size of the reward
Long term tasks:
- Implement drivers and support for Keras and TensorFlow
- Implement more popular GAN models in this framework
- Implement more metrics to evaluate different GAN models
- Support multimodal data generation for GAN frameworks
We would like to thank Raunak Sinha, who interned with us during summer 2018 and contributed heavily to this toolkit.