pro_gan_pytorch-examples

This repository contains examples trained using the Python package pro-gan-pth. You can find the GitHub repo for the project at github-repository and the PyPI package at pypi.

Two examples are presented here: one for the LFW dataset and one for the MNIST dataset. Please refer to the following sections for how to train these models and / or load the provided trained weights.

Prior Setup

Before running any of the following training experiments, please set up your VirtualEnv with the required packages for this project. Importantly, please install the progan package using $ pip install pro-gan-pth, along with the appropriate GPU / CPU version of PyTorch 0.4.0. Once this is done, you can proceed with the following experiments.
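Before starting a long training run, it can help to confirm that both packages are importable from the active environment. The following is a minimal sketch using only the standard library; the helper name check_packages is hypothetical and not part of pro-gan-pth, and the module names are taken from the import statements used later in this README.

```python
import importlib.util


def check_packages(names=("torch", "pro_gan_pytorch")):
    """Return a dict mapping each top-level module name to True if it is importable."""
    return {name: importlib.util.find_spec(name) is not None for name in names}


# Report which of the required modules the current environment can see.
for name, found in check_packages().items():
    print(f"{name}: {'found' if found else 'MISSING'}")
```

If either module is reported as MISSING, the VirtualEnv is probably not activated or the pip install step above has not been run in it.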

LFW Experiment

The configuration used for the LFW training experiment can be found in implementation/configs/lfw.conf in this repository. The training was performed using the wgan-gp loss function.

Examples:


Sample loss plot:


MNIST Experiment

The configuration used for the MNIST training experiment can be found in implementation/configs/mnist.conf in this repository. The training was performed using the lsgan loss function.

Examples:


Sample loss plot:


How to use:

Running the training script:

For running the training script, simply use the following procedure:
$ cd implementation
$ python train_network.py --config=configs/mnist.conf

You can tinker with the configuration for your desired behaviour. This training script also exposes some of the use cases of the package's api.

Generating loss plots:

You can generate the loss plots from the `loss-logs` by using the provided script. The logs get generated while the training is in progress.
$ python generate_loss_plots --logdir=training_runs/mnist/losses/ \
                             --plotdir=training_runs/mnist/losses/loss_plots/

Using trained model:

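Please refer to the following code snippet if you just wish to use the trained model for generating samples:
import torch as th
import pro_gan_pytorch.PRO_GAN as pg
import matplotlib.pyplot as plt

# use the GPU when one is available
device = th.device("cuda" if th.cuda.is_available()
                   else "cpu")

# the architecture must match the one used for training
gen = pg.Generator(depth=4, latent_size=128,
                   use_eql=False).to(device)

# load the weights onto the CPU first; load_state_dict then
# copies them to whichever device the model lives on
gen.load_state_dict(
    th.load("training_runs/saved_models/GAN_GEN_3.pth",
            map_location="cpu")
)

noise = th.randn(1, 128).to(device)

# move the result back to the CPU so that matplotlib can read it
sample_image = gen(noise, depth=3, alpha=1).detach().cpu()

# shift values from [-1, 1] to [0, 1] and put the channels last
plt.imshow(sample_image[0].permute(1, 2, 0) / 2 + 0.5)
plt.show()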

The trained weights can be found in the saved_models directory present in respective training_runs.
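To avoid hard-coding the checkpoint name, the saved_models directory can be scanned for generator weights. This is a rough sketch, not part of the package: the helper latest_generator_checkpoint is hypothetical, and it assumes that generator files follow the GAN_GEN_<n>.pth naming seen in the snippet above, with larger n meaning a later (deeper) checkpoint.

```python
import re
from pathlib import Path

# Assumed naming scheme, inferred from "GAN_GEN_3.pth" in the snippet above.
_GEN_PATTERN = re.compile(r"GAN_GEN_(\d+)\.pth$")


def latest_generator_checkpoint(saved_models_dir):
    """Return (path, n) for the GAN_GEN_<n>.pth file with the largest n, or None."""
    candidates = []
    for path in Path(saved_models_dir).glob("GAN_GEN_*.pth"):
        match = _GEN_PATTERN.match(path.name)
        if match:
            candidates.append((int(match.group(1)), path))
    if not candidates:
        return None
    n, path = max(candidates, key=lambda item: item[0])
    return path, n
```

The returned path can be passed to th.load in place of the literal file name. Whether the returned number is the right value for the depth argument of the generator is an assumption that should be checked against the training configuration.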

How to use on Google Colab Notebook:

This code can be run on Google Colaboratory with GPU acceleration. Colab offers a free Tesla K80 GPU with up to ~12GB of VRAM. However, each instance has a limited lifetime and is shut down after a certain time, at which point all installed libraries and saved files are reset. A workaround is to save training results to Google Drive. The packages need to be reinstalled after every instance reset.

Step-by-step instructions for running this on Google Colab are in the ProGAN Colaboratory Notebook.
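The Google Drive workaround can be sketched as a small copy helper that is called periodically, or once training finishes. The function backup_run and the backup folder name are hypothetical. The Drive path in the comments assumes the usual Colab mount point, and training_runs/mnist is assumed from the paths used earlier in this README.

```python
import shutil
from pathlib import Path


def backup_run(run_dir, backup_root):
    """Copy run_dir into backup_root/<name of run_dir>, overwriting files copied earlier."""
    src = Path(run_dir)
    dst = Path(backup_root) / src.name
    # dirs_exist_ok (Python 3.8+) allows repeated backups into the same folder.
    shutil.copytree(src, dst, dirs_exist_ok=True)
    return dst


# In a Colab notebook, Drive is mounted first (not runnable outside Colab):
#   from google.colab import drive
#   drive.mount("/content/drive")
#   backup_run("training_runs/mnist", "/content/drive/MyDrive/pro_gan_backups")
```

After an instance reset, copying the folder back from Drive restores the saved models and loss logs, so that only the packages need reinstalling.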

Thanks:

Please feel free to open PRs here if you train on other datasets using this package.

Best regards,
@akanimax :)
