

Population Based Augmentation (PBA)

Table of Contents

  1. Introduction
  2. Getting Started
  3. Reproduce Results
  4. Run PBA Search
  5. Citation

Introduction

Population Based Augmentation (PBA) is an algorithm that quickly and efficiently learns data augmentation functions for neural network training. PBA matches state-of-the-art results on CIFAR with one thousand times less compute, enabling researchers and practitioners to effectively learn new augmentation policies using a single workstation GPU.
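The population-based idea can be illustrated with one exploit-and-explore step in the style of Population Based Training: the weakest trials copy the augmentation hyperparameters of stronger trials and then perturb them. This is a minimal sketch only; the dictionary layout, the integer ranges (probability 0-10, magnitude 0-9), the 25% truncation fraction, and the perturbation rule are assumptions, not necessarily what this repository implements.

```python
import random

# Assumed encodings (illustrative only): a probability is an integer 0..10
# (meaning 0.0-1.0 in steps of 0.1) and a magnitude is an integer 0..9.
PROB_RANGE = (0, 10)
MAG_RANGE = (0, 9)


def explore(policy, rng, resample_prob=0.2, max_step=3):
    """Return a perturbed copy of `policy` (the input is not modified).

    policy: dict mapping an operation name to [prob, mag] (hypothetical layout).
    Each value is either resampled uniformly (with probability resample_prob)
    or shifted by up to max_step in either direction, then clipped to range.
    """
    new_policy = {}
    for op, (prob, mag) in policy.items():
        new_values = []
        for value, (lo, hi) in ((prob, PROB_RANGE), (mag, MAG_RANGE)):
            if rng.random() < resample_prob:
                value = rng.randint(lo, hi)
            else:
                value += rng.choice([-1, 1]) * rng.randint(0, max_step)
            new_values.append(min(hi, max(lo, value)))
        new_policy[op] = new_values
    return new_policy


def exploit_and_explore(population, rng, fraction=0.25):
    """One PBT-style update over a list of {"score": ..., "policy": ...} trials.

    Each trial in the bottom `fraction` (by score) copies the policy of a
    random trial from the top `fraction` and perturbs it with `explore`.
    A full implementation would also copy the model weights (omitted here).
    Returns the trials sorted from best to worst score.
    """
    ranked = sorted(population, key=lambda t: t["score"], reverse=True)
    cutoff = max(1, int(len(ranked) * fraction))
    top, bottom = ranked[:cutoff], ranked[-cutoff:]
    for trial in bottom:
        donor = rng.choice(top)
        trial["policy"] = explore(donor["policy"], rng)
    return ranked


# Usage: four trials sharing one starting policy, scored after some training.
rng = random.Random(0)
population = [
    {"score": s, "policy": {"Rotate": [5, 5], "Cutout": [5, 5]}}
    for s in (0.9, 0.8, 0.7, 0.6)
]
ranked = exploit_and_explore(population, rng)
```

Because only the bottom trials are overwritten, the best policies survive each step while the population keeps exploring nearby settings, which is what lets a schedule (rather than a single fixed policy) emerge over training.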

This repository contains code for the work "Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules" (http://arxiv.org/abs/1905.05393), implemented in TensorFlow and Python 2. It supports both training models with the reported augmentation schedules and discovering new augmentation policy schedules.

[Figure: visualization of the PBA augmentation strategy]
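The augmentation strategy can be sketched as a per-image procedure: draw how many operations to apply (0, 1 or 2), visit the policy's operations in random order, and apply each one with its own probability and magnitude until that count is used up. This is a hedged sketch of that scheme, not the repository's code; the count weights, the operation names, and the toy operations on a list of pixel values are hypothetical stand-ins for the real image transforms.

```python
import random


# Hypothetical toy operations on a flat list of 8-bit pixel values. Each takes
# the image and an integer magnitude 0..9, mirroring the policy format.
def brightness(img, mag):
    return [min(255, p + 10 * mag) for p in img]


def invert(img, mag):
    return [255 - p for p in img]  # magnitude is ignored for this op


OPS = {"Brightness": brightness, "Invert": invert}


def apply_policy(img, policy, rng, count_weights=(0.2, 0.3, 0.5)):
    """Apply up to 0, 1 or 2 operations from `policy` to `img`.

    policy: dict op_name -> (probability in [0, 1], magnitude 0..9).
    count_weights: assumed sampling weights for applying 0, 1 or 2 ops.
    Operations are visited in random order; each visited op is applied with
    its own probability until the sampled count is used up.
    """
    count = rng.choices([0, 1, 2], weights=count_weights)[0]
    names = list(policy)
    rng.shuffle(names)
    for name in names:
        if count == 0:
            break
        prob, mag = policy[name]
        if rng.random() < prob:
            img = OPS[name](img, mag)
            count -= 1
    return img


rng = random.Random(0)
augmented = apply_policy([0, 100, 200], {"Brightness": (0.5, 5), "Invert": (0.3, 0)}, rng)
```

Capping the number of applied operations keeps a policy with many high-probability operations from distorting every image beyond recognition.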

Getting Started

Install requirements

pip install -r requirements.txt

Download CIFAR-10/CIFAR-100 datasets

bash datasets/cifar10.sh
bash datasets/cifar100.sh

Reproduce Results

Dataset            Model                     Test Error (%)
CIFAR-10           Wide-ResNet-28-10         2.58
CIFAR-10           Shake-Shake (26 2x32d)    2.54
CIFAR-10           Shake-Shake (26 2x96d)    2.03
CIFAR-10           Shake-Shake (26 2x112d)   2.03
CIFAR-10           PyramidNet+ShakeDrop      1.46
Reduced CIFAR-10   Wide-ResNet-28-10         12.82
Reduced CIFAR-10   Shake-Shake (26 2x96d)    10.64
CIFAR-100          Wide-ResNet-28-10         16.73
CIFAR-100          Shake-Shake (26 2x96d)    15.31
CIFAR-100          PyramidNet+ShakeDrop      10.94
SVHN               Wide-ResNet-28-10         1.18
SVHN               Shake-Shake (26 2x96d)    1.13
Reduced SVHN       Wide-ResNet-28-10         7.83
Reduced SVHN       Shake-Shake (26 2x96d)    6.46

Scripts to reproduce the results are located in scripts/table_*.sh. Each script requires one argument, the model name. The available options for each dataset are the models reported for it in Tables 1-4 of the paper, chosen from: wrn_28_10, ss_32, ss_96, ss_112, pyramid_net (the Reduced SVHN example below uses the prefixed name rsvhn_ss_96). Hyperparameters are set inside each script file.

For example, to reproduce CIFAR-10 results on Wide-ResNet-28-10:

bash scripts/table_1_cifar10.sh wrn_28_10

To reproduce Reduced SVHN results on Shake-Shake (26 2x96d):

bash scripts/table_4_svhn.sh rsvhn_ss_96

A good place to start is Reduced SVHN on Wide-ResNet-28-10, which can complete in under 10 minutes on a Titan XP GPU and reaches 91%+ test accuracy.
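The results table reports test error, whereas this paragraph quotes test accuracy; the two are related by accuracy = 100 - error. A small sketch using the two Reduced SVHN rows from the results table (the helper name is illustrative):

```python
# Test error (%) for Reduced SVHN, copied from the results table.
REDUCED_SVHN_ERROR = {
    "Wide-ResNet-28-10": 7.83,
    "Shake-Shake (26 2x96d)": 6.46,
}


def accuracy_from_error(error_pct):
    """Convert a test error percentage into a test accuracy percentage."""
    return round(100.0 - error_pct, 2)


for model, err in REDUCED_SVHN_ERROR.items():
    print(f"{model}: {accuracy_from_error(err):.2f}% accuracy")
```

With these numbers, Wide-ResNet-28-10 corresponds to 92.17% accuracy and Shake-Shake (26 2x96d) to 93.54%.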

Running the larger models for 1800 epochs may take multiple days of training. For example, CIFAR-10 PyramidNet+ShakeDrop takes around 9 days on a Tesla V100 GPU.
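To budget a long run, it helps to turn the total time into an average per-epoch time: 9 days over 1800 epochs is 9 x 24 x 60 / 1800 = 7.2 minutes per epoch. The hypothetical helpers below do this conversion in both directions; they assume a constant time per epoch, which ignores evaluation and checkpointing overhead.

```python
MINUTES_PER_DAY = 24 * 60


def minutes_per_epoch(total_days, epochs=1800):
    """Average minutes per epoch implied by a total training time."""
    return total_days * MINUTES_PER_DAY / epochs


def estimate_days(minutes_per_epoch_value, epochs=1800):
    """Wall-clock days for `epochs` epochs at a constant per-epoch time."""
    return minutes_per_epoch_value * epochs / MINUTES_PER_DAY


# CIFAR-10 PyramidNet+ShakeDrop: ~9 days for 1800 epochs on a Tesla V100.
print(round(minutes_per_epoch(9), 2))  # 7.2
```

Timing the first few epochs on your own hardware and passing that value to estimate_days gives a rough total before committing to a multi-day run.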

Run PBA Search

To run PBA search on Wide-ResNet-40-2, use scripts/search.sh. It requires one argument, the dataset name: rsvhn (Reduced SVHN) or rcifar10 (Reduced CIFAR-10).

The script requests a fractional GPU per trial so that multiple trials can run on the same GPU. Search takes around an hour on Reduced SVHN and around 5 hours on Reduced CIFAR-10, both on a Titan XP GPU.
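The effect of a fractional GPU request can be estimated with simple arithmetic: if each trial asks for a fraction f of a GPU, floor(1/f) trials fit on one device. The helper below is a hypothetical sketch; the fractions shown are examples, not the value set in scripts/search.sh. In older Ray Tune releases such a request is expressed per trial, for example resources_per_trial={"gpu": 0.25} in tune.run.

```python
import math
from fractions import Fraction


def max_concurrent_trials(gpu_per_trial, num_gpus=1):
    """Number of trials that can run at once given a fractional GPU request.

    A fractional request is assumed to be packed within a single GPU, so each
    GPU holds floor(1 / gpu_per_trial) trials.
    """
    # limit_denominator turns float inputs such as 0.1 into the exact 1/10.
    frac = Fraction(gpu_per_trial).limit_denominator(1000)
    if not 0 < frac <= 1:
        raise ValueError("gpu_per_trial must be in (0, 1]")
    return num_gpus * math.floor(1 / frac)


for fraction in (0.5, 0.25, 0.1):
    print(fraction, max_concurrent_trials(fraction))
```

Exact fractions avoid off-by-one results from floating-point division. Note that Ray uses the fraction only for scheduling; the concurrent trials must still fit in the GPU's memory.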

CUDA_VISIBLE_DEVICES=0 bash scripts/search.sh rsvhn

The schedules found during search can be retrieved from the Ray results directory, and the log files can be converted into policy schedules with the parse_log() function in pba/utils.py. For example, the policy schedule learned on Reduced CIFAR-10 over 200 epochs is split into probability and magnitude hyperparameter values (the two values for each augmentation operation are merged) and visualized below:

[Figures: Probability Hyperparameters over Time; Magnitude Hyperparameters over Time]
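The splitting step described above can be sketched as follows, assuming each schedule entry pairs an epoch with a flat list that holds, for every augmentation operation, two (probability, magnitude) pairs in the order [p1, m1, p2, m2]. That layout, the operation names, and merging the two values by summation are assumptions for illustration; the actual format is whatever parse_log() in pba/utils.py produces.

```python
def split_schedule(schedule, op_names):
    """Split a policy schedule into per-operation probability and magnitude series.

    schedule: list of (epoch, flat_policy) pairs, where flat_policy holds
        [p1, m1, p2, m2] for each operation in op_names order (assumed layout).
    Returns two dicts op -> list of merged values over time; the two values
    of each operation are merged by summing them (an assumption).
    """
    probs = {op: [] for op in op_names}
    mags = {op: [] for op in op_names}
    for _epoch, flat in schedule:
        if len(flat) != 4 * len(op_names):
            raise ValueError("expected 4 values per operation")
        for i, op in enumerate(op_names):
            p1, m1, p2, m2 = flat[4 * i : 4 * i + 4]
            probs[op].append(p1 + p2)
            mags[op].append(m1 + m2)
    return probs, mags


# Hypothetical two-operation schedule with two checkpoints.
schedule = [
    (0, [0.0, 0, 0.0, 0, 0.1, 1, 0.0, 0]),
    (100, [0.3, 2, 0.2, 4, 0.5, 3, 0.1, 1]),
]
probs, mags = split_schedule(schedule, ["Rotate", "Cutout"])
```

Each list in probs and mags is one line in a "hyperparameters over time" plot, with the schedule's epochs on the x-axis.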

Citation

If you use PBA in your research, please cite:

@inproceedings{ho2019pba,
  title     = {Population Based Augmentation: Efficient Learning of Augmentation Policy Schedules},
  author    = {Daniel Ho and
               Eric Liang and
               Ion Stoica and
               Pieter Abbeel and
               Xi Chen
  },
  booktitle = {ICML},
  year      = {2019}
}
