Image GPT

PyTorch implementation of Image GPT, based on the paper Generative Pretraining from Pixels (Chen et al.) and its accompanying code.


Model-generated completions of half-images from the test set. The first column is the input; the last column is the original image.


iGPT-S pretrained on CIFAR10. Completions are fairly poor as the model was only trained on CIFAR10, not all of ImageNet.

WIP

  • Batched k-means on GPU for quantization of larger datasets (currently using sklearn.cluster.MiniBatchKMeans.)
  • BERT-style pretraining (currently only generative is supported.)
  • Load pretrained models from OpenAI.
  • Reproduce at least iGPT-S results.

According to their blog post, the largest model, iGPT-L (1.4 B parameters), was trained for 2500 V100-days. By greatly reducing the number of attention heads, the number of layers, and the input size (which affects model size quadratically), we can train our own model (26 K parameters) on Fashion-MNIST on a single NVIDIA 2070 in less than 2 hours.
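One way to read the quadratic dependence on input size is through self-attention, where every token attends to every other token. The sketch below counts the entries in a single attention score matrix for a few image sizes. It assumes one token per pixel and ignores batch size, heads, and layers; `attention_entries` is an illustrative helper, not part of this repository.

```python
def attention_entries(height: int, width: int) -> int:
    """Entries in one (seq_len x seq_len) attention score matrix.

    Assumes one token per pixel, so seq_len = height * width.
    """
    seq_len = height * width
    return seq_len * seq_len


for side in (14, 28, 32):
    entries = attention_entries(side, side)
    print(f"{side}x{side}: seq_len={side * side}, entries={entries}")
```

Doubling the image side length quadruples the sequence length and multiplies the score matrix size by 16, which is why shrinking the input helps so much on a single GPU.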

Usage

Pre-trained Models

Some pre-trained models are located in the models directory. Run ./download.sh to download the CIFAR10 pre-trained iGPT-S model.

Compute Centroids

Images are downloaded, and centroids are computed using k-means with num_clusters clusters. These centroids are used to quantize the images before they are fed into the model.

# options: mnist, fmnist, cifar10
python src/compute_centroids.py --dataset mnist --num_clusters=8

# creates data/<dataset>_centroids.npy

Note: Use the same num_clusters as num_vocab in your model.
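As a rough illustration of the quantization step, the sketch below maps each pixel to the index of its nearest centroid and checks that every index fits in the model vocabulary. It is written in plain Python for readability and uses made-up centroid values; the repository's own code may well be vectorized, and `quantize` is a hypothetical name.

```python
def quantize(pixels, centroids):
    """Map each pixel (a tuple of channel values) to its nearest centroid index."""

    def sq_dist(a, b):
        # Squared Euclidean distance between two channel tuples.
        return sum((x - y) ** 2 for x, y in zip(a, b))

    return [
        min(range(len(centroids)), key=lambda k: sq_dist(p, centroids[k]))
        for p in pixels
    ]


# Hypothetical grayscale centroids, as if computed with --num_clusters=4.
centroids = [(0.0,), (0.33,), (0.66,), (1.0,)]
pixels = [(0.05,), (0.4,), (0.9,), (0.6,)]

tokens = quantize(pixels, centroids)

# num_vocab in the model config must equal the number of centroids,
# otherwise some token indices would fall outside the embedding table.
num_vocab = len(centroids)
assert all(0 <= t < num_vocab for t in tokens)
print(tokens)
```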

Training

Models can be trained using src/run.py with the train subcommand.

Generative Pre-training

Models can be pre-trained by specifying a dataset and a model config. configs/s_gen.yml corresponds to iGPT-S from the paper; configs/xxs_gen.yml is an extra-small model for trying on toy datasets with limited compute.

python src/run.py --dataset mnist train configs/xxs_gen.yml
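For intuition, generative pre-training in the paper is next-pixel prediction over the quantized image, flattened in raster order. The sketch below shows how inputs and targets line up for that objective. It is an assumption about the general setup rather than this repository's exact code (which may, for example, prepend a start-of-sequence token), and `next_token_pairs` is a hypothetical helper.

```python
def next_token_pairs(tokens):
    """Split a flattened token sequence into (inputs, targets).

    The target at position i is the token that follows inputs[i].
    """
    return tokens[:-1], tokens[1:]


# A tiny 2x2 "image" of centroid indices.
image_tokens = [[0, 1], [3, 2]]

# Flatten row by row (raster order).
flat = [t for row in image_tokens for t in row]

inputs, targets = next_token_pairs(flat)
print(inputs, targets)
```

A model trained this way is scored with cross-entropy between its predicted distribution at each position and the corresponding target token.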

Classification Fine-tuning

Pre-trained models can be fine-tuned by passing the path to the pre-trained checkpoint to --pretrained, along with the config file and dataset.

python src/run.py --dataset mnist train configs/xxs_clf.yml --pretrained=models/mnist_gen.ckpt

Sampling

Figures like those seen above can be created using random images from the test set:

# outputs to figure.png
python src/sample.py models/mnist_gen.ckpt
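Conceptually, a half-image completion keeps the first half of the token sequence and samples the remaining tokens one at a time. The sketch below shows that loop with a stand-in model that returns a uniform distribution; `complete` and `uniform_model` are hypothetical names, and a trained checkpoint would supply real probabilities instead.

```python
import random


def complete(prefix, total_len, next_token_probs, rng):
    """Extend prefix to total_len tokens by sampling one token at a time."""
    tokens = list(prefix)
    while len(tokens) < total_len:
        probs = next_token_probs(tokens)
        # Draw the next token index according to the predicted distribution.
        tokens.append(rng.choices(range(len(probs)), weights=probs, k=1)[0])
    return tokens


def uniform_model(tokens, num_vocab=8):
    """Stand-in for a trained model: every token is equally likely."""
    return [1.0 / num_vocab] * num_vocab


rng = random.Random(0)
full = [0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]  # a 4x4 image, flattened
half = full[: len(full) // 2]  # top half of the image in raster order

completed = complete(half, len(full), uniform_model, rng)
print(completed)
```

The sampled tokens can then be mapped back to pixel values through the centroids to render the completed image.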

Gifs like the one seen in my tweet can be made like so:

# outputs to out.gif
python src/gif.py models/mnist_gen.ckpt
