

Early-Bird-Tickets

ICLR 2020: spotlight | License: MIT

This is the PyTorch implementation of Drawing Early-Bird Tickets: Toward More Efficient Training of Deep Networks.

ICLR 2020 spotlight oral paper


Introduction

  • Lottery Ticket Hypothesis: (Frankle & Carbin, 2019) shows that there exist winning tickets (small but critical subnetworks) for dense, randomly initialized networks, that can be trained alone to achieve comparable accuracies to the latter in a similar number of iterations.

  • Limitation: However, the identification of these winning tickets still requires the costly train-prune-retrain process, limiting their practical benefits.

  • Our Contributions:

    • We discover for the first time that the winning tickets can be identified at the very early training stage, which we term early-bird (EB) tickets, via low-cost training schemes (e.g., early stopping and low-precision training) at large learning rates. Our finding of EB tickets is consistent with recently reported observations that the key connectivity patterns of neural networks emerge early.
    • Furthermore, we propose a mask distance metric that can be used to identify EB tickets with low computational overhead, without needing to know the true winning tickets that emerge after the full training.
    • Finally, we leverage the existence of EB tickets and the proposed mask distance to develop efficient training methods, which are achieved by first identifying EB tickets via low-cost schemes, and then continuing to train merely the EB tickets towards the target accuracy.

Experiments based on various deep networks and datasets validate: 1) the existence of EB tickets, and the effectiveness of mask distance in efficiently identifying them; and 2) that the proposed efficient training via EB tickets can achieve up to 4.7x energy savings while maintaining comparable or even better accuracy, demonstrating a promising and easily adopted method for tackling cost-prohibitive deep network training.

Early-Bird Tickets

Existence of Early-Bird Tickets

To articulate the Early-Bird (EB) ticket phenomenon, i.e., that winning tickets can be drawn very early in training, we perform ablation experiments using two representative deep models (VGG16 and PreResNet101) on two popular datasets (CIFAR10 and CIFAR100). Specifically, we follow the main idea of (Frankle & Carbin, 2019) but instead prune networks trained only up to earlier points, to see whether reliable tickets can be drawn. For all experiments we adopt the same channel pruning technique as (Liu et al., 2017), since it aligns with our end goal of efficient training. The figure below demonstrates the existence of EB tickets (p = 30% means 30% of the weights are pruned; a hollow star marks the retraining accuracy of the subnetwork drawn from the checkpoint with the best accuracy in the search stage).
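The channel pruning of (Liu et al., 2017) ranks channels by the magnitude of their batch-normalization scaling factors (gamma) and prunes the smallest ones under a single global threshold. A minimal pure-Python sketch of that mask-drawing step, assuming the scaling factors have already been extracted as one list per layer; the function name `draw_channel_masks` is hypothetical and not part of this repository, and here the pruning ratio counts channels:

```python
from typing import List, Sequence


def draw_channel_masks(bn_scales: Sequence[Sequence[float]], prune_ratio: float) -> List[List[int]]:
    """Return a 0/1 keep-mask per layer, pruning the globally smallest |gamma| channels.

    bn_scales: per-layer BN scaling factors (gamma), one list per layer (hypothetical input).
    prune_ratio: fraction of all channels to prune, e.g. 0.3 for p = 30%.
    """
    if not 0.0 <= prune_ratio <= 1.0:
        raise ValueError("prune_ratio must be in [0, 1]")
    # Pool |gamma| over all layers so the threshold is global, not per layer.
    magnitudes = sorted(abs(g) for layer in bn_scales for g in layer)
    num_pruned = int(len(magnitudes) * prune_ratio)
    # Channels strictly below the threshold are pruned (exactly num_pruned when there are no ties).
    threshold = magnitudes[num_pruned] if num_pruned < len(magnitudes) else float("inf")
    return [[1 if abs(g) >= threshold else 0 for g in layer] for layer in bn_scales]


# Hypothetical scaling factors for two small layers; prune 40% of the 5 channels.
masks = draw_channel_masks([[0.9, 0.1, 0.5], [0.05, 0.7]], prune_ratio=0.4)
print(masks)  # [[1, 0, 1], [0, 1]]
```

Because the threshold is global, layers whose scaling factors are uniformly small can lose most or all of their channels; this is the situation discussed in the "Pruning the 0 Layers" issue further down.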

Identify Early-Bird Tickets

We visualize how the distance among the tickets drawn at each epoch evolves. The figure below plots the pairwise mask distance matrices (160 x 160) of the VGG16 and PreResNet101 experiments on CIFAR100 at different pruning ratios p, where the (i, j)-th element of a matrix denotes the mask distance between the tickets drawn at epochs i and j in the corresponding experiment. A lower distance (close to 0) is colored warmer.

[Figure: pairwise mask distance matrices]
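One way to compute such a matrix, assuming each ticket is represented by its per-layer 0/1 channel masks concatenated into a single vector, and that the mask distance is the Hamming distance normalized by the mask length. The names `flatten_mask`, `mask_distance`, and `pairwise_distance_matrix` are illustrative, not the repository's API:

```python
from itertools import chain
from typing import List, Sequence


def flatten_mask(layer_masks: Sequence[Sequence[int]]) -> List[int]:
    """Concatenate per-layer channel masks (layers may differ in width) into one vector."""
    return list(chain.from_iterable(layer_masks))


def mask_distance(mask_a: Sequence[int], mask_b: Sequence[int]) -> float:
    """Normalized Hamming distance: fraction of positions where the two masks differ."""
    if len(mask_a) != len(mask_b) or not mask_a:
        raise ValueError("masks must be non-empty and of equal length")
    return sum(a != b for a, b in zip(mask_a, mask_b)) / len(mask_a)


def pairwise_distance_matrix(masks: Sequence[Sequence[int]]) -> List[List[float]]:
    """The (i, j)-th entry is the mask distance between the tickets drawn at epochs i and j."""
    return [[mask_distance(mi, mj) for mj in masks] for mi in masks]


# Hypothetical tickets from three epochs of a two-layer network (3 + 2 channels).
epochs = [flatten_mask(m) for m in ([[1, 1, 0], [1, 0]], [[1, 0, 0], [1, 0]], [[1, 0, 0], [1, 0]])]
matrix = pairwise_distance_matrix(epochs)
print(matrix[0][1], matrix[1][2])  # 0.2 0.0
```

Flattening first means wider layers contribute proportionally more positions to the distance; a layer-wise average would be an alternative design choice.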

Our observation that the ticket masks quickly become stable and hardly change in the early training stages supports drawing EB tickets. We therefore measure the mask distance between consecutive epochs and draw EB tickets when this distance is smaller than a threshold. In practice, to improve the reliability of EB tickets, we stop the search and draw EB tickets only when the last five recorded mask distances are all smaller than the given threshold.
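The stopping rule can be sketched as a small detector that keeps the last five consecutive-epoch distances in a fixed-size queue. The threshold of 0.1 is an assumed default, masks are assumed to be flattened 0/1 vectors, and the class name `EarlyBirdDetector` is hypothetical:

```python
from collections import deque
from typing import Optional, Sequence


class EarlyBirdDetector:
    """Signals an EB ticket once the last `window` consecutive-epoch mask distances are all below `threshold`."""

    def __init__(self, threshold: float = 0.1, window: int = 5) -> None:
        self.threshold = threshold
        self.distances = deque(maxlen=window)  # the oldest distance is dropped automatically
        self.prev_mask: Optional[list] = None

    def update(self, mask: Sequence[int]) -> bool:
        """Record this epoch's flattened mask; return True when the EB ticket can be drawn."""
        mask = list(mask)
        if self.prev_mask is not None:
            diff = sum(a != b for a, b in zip(mask, self.prev_mask))
            self.distances.append(diff / len(mask))  # normalized Hamming distance
        self.prev_mask = mask
        return (len(self.distances) == self.distances.maxlen
                and all(d < self.threshold for d in self.distances))


# Hypothetical run: the mask flips once between epochs 0 and 1, then stays fixed.
detector = EarlyBirdDetector()
history = [[1] * 10, [0] * 10] + [[0] * 10] * 6
flags = [detector.update(m) for m in history]
print(flags.index(True))  # 6
```

Requiring a full window of small distances delays detection by a few epochs but avoids drawing a ticket from a single lucky pair of similar masks.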

Efficient Training via Early-Bird Tickets

Instead of the usual three-step routine of 1) training a dense model, 2) pruning it, and 3) retraining the pruned model to restore performance (with these steps possibly iterated), we leverage the existence of EB tickets to develop the EB Train scheme, which replaces steps 1 and 2 with a lower-cost step of detecting EB tickets.

[Figure: eb-train]
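Put together, EB Train replaces dense training and pruning with an early search loop followed by retraining of the ticket only. A minimal control-flow sketch, where `train_one_epoch`, `get_mask`, and `prune_and_retrain` are hypothetical callbacks standing in for the repository's actual scripts, masks are flattened 0/1 vectors, and the threshold (0.1) and five-epoch window are assumed values:

```python
from collections import deque
from typing import Callable, List, Optional, Sequence, Tuple, TypeVar

R = TypeVar("R")


def eb_train(
    train_one_epoch: Callable[[int], None],
    get_mask: Callable[[], Sequence[int]],
    prune_and_retrain: Callable[[List[int]], R],
    max_search_epochs: int = 160,
    threshold: float = 0.1,
    window: int = 5,
) -> Tuple[int, R]:
    """Search for an EB ticket, then prune and retrain only that subnetwork.

    Returns the epoch at which the ticket was drawn and the result of prune_and_retrain.
    """
    if max_search_epochs < 1:
        raise ValueError("max_search_epochs must be >= 1")
    recent = deque(maxlen=window)
    prev: Optional[List[int]] = None
    for epoch in range(max_search_epochs):
        train_one_epoch(epoch)       # low-cost training of the dense network
        mask = list(get_mask())      # ticket drawn from the current weights
        if prev is not None:
            recent.append(sum(a != b for a, b in zip(mask, prev)) / len(mask))
        prev = mask
        if len(recent) == window and max(recent) < threshold:
            return epoch, prune_and_retrain(mask)  # stop the dense search early
    # Assumption: if no EB ticket emerged within the budget, fall back to the last mask.
    return max_search_epochs - 1, prune_and_retrain(prev)


# Hypothetical run: half the channels are dropped after epoch 0, then the mask is stable.
masks = [[1] * 10, [1] * 5 + [0] * 5] + [[1] * 5 + [0] * 5] * 10
trained: List[int] = []
epoch, result = eb_train(
    train_one_epoch=trained.append,
    get_mask=lambda: masks[trained[-1]],
    prune_and_retrain=sum,  # stand-in "retraining" that just counts kept channels
    max_search_epochs=len(masks),
)
print(epoch, result)  # 6 5
```

The saving comes from the early `return`: dense training stops at the detection epoch instead of running the full schedule.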

Basic Usage

Prerequisites

The code has the following dependencies:

  • python 3.7
  • pytorch 1.1.0
  • torchvision 0.3.0
  • Pillow (PIL) 5.4.1
  • scipy 1.2.1
  • qtorch 0.1.1 (for low precision)
  • GCC >= 4.9 on Linux (for low precision)

Core Training Options

  • dataset: dataset to use; CIFAR10/CIFAR100 by default
  • data: path to the raw data, required if you want to use ImageNet
  • batch-size: 256 by default, as used for all experiments in the paper
  • epochs: total number of training epochs, 160 by default
  • schedule: epochs at which the learning rate is decayed, [80, 120] by default
  • lr: initial learning rate, 0.1 by default
  • save: directory in which to save checkpoints
  • arch: model architecture; vgg and resnet are currently supported
  • depth: model depth
  • filter: filter to apply to the dataset, none by default
  • sparsify_gt: sparsify the dataset with the given percentage
  • gpu_ids: GPU IDs to use; multi-GPU training is supported
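With the defaults above (lr 0.1, 160 epochs, schedule [80, 120]), the learning rate follows a step schedule. A sketch of the resulting per-epoch value, assuming the rate is multiplied by 0.1 at each milestone (the decay factor is not stated here, so `gamma=0.1` is an assumption):

```python
from typing import Sequence


def lr_at_epoch(epoch: int, base_lr: float = 0.1,
                schedule: Sequence[int] = (80, 120), gamma: float = 0.1) -> float:
    """Step schedule: multiply base_lr by gamma once for every milestone already reached."""
    num_decays = sum(1 for milestone in schedule if epoch >= milestone)
    return base_lr * gamma ** num_decays


for epoch in (0, 79, 80, 120, 159):
    print(epoch, round(lr_at_epoch(epoch), 6))
```

In PyTorch, the same behavior is available through `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 120], gamma=0.1)`, stepped once per epoch.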

Standard Train for Identifying Early-Bird Tickets

Example: Reproduce early-bird (EB) tickets on CIFAR-100

  • Step 1: Standard training to find EB tickets at different pruning ratios. Note that one can stop training as soon as EB tickets are identified; here we keep training in order to compare the subnetworks drawn at different training stages.
bash ./scripts/standard-train/search.sh
  • Step 2: Perform the actual pruning on the saved checkpoints (checkpoints containing EB tickets are saved in the EB-{pruning ratio}-{drawing epoch}.pth.tar format).
bash ./scripts/standard-train/prune.sh
  • Optional: Pairwise mask distance matrix visualization.
bash ./scripts/standard-train/mask_distance.sh
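Checkpoints that follow the EB-{pruning ratio}-{drawing epoch}.pth.tar naming can be indexed by parsing their file names. A hypothetical helper, assuming the ratio is written as a plain number (e.g. `30` or `0.5`) and the epoch as an integer; the example file names are made up:

```python
import os
import re
from typing import Dict, Iterable, Optional, Tuple

# Matches names such as "EB-30-11.pth.tar" or "EB-0.5-20.pth.tar" (assumed formats).
EB_CKPT = re.compile(r"^EB-(?P<ratio>\d+(?:\.\d+)?)-(?P<epoch>\d+)\.pth\.tar$")


def parse_eb_checkpoint(path: str) -> Optional[Tuple[float, int]]:
    """Return (pruning ratio, drawing epoch) for an EB checkpoint path, else None."""
    match = EB_CKPT.match(os.path.basename(path))
    if match is None:
        return None
    return float(match.group("ratio")), int(match.group("epoch"))


def earliest_ticket_per_ratio(paths: Iterable[str]) -> Dict[float, int]:
    """Map each pruning ratio to the earliest epoch at which its EB ticket was saved."""
    earliest: Dict[float, int] = {}
    for path in paths:
        parsed = parse_eb_checkpoint(path)
        if parsed is None:
            continue  # skip non-EB checkpoints (e.g. the hypothetical "model_best.pth.tar")
        ratio, epoch = parsed
        earliest[ratio] = min(epoch, earliest.get(ratio, epoch))
    return earliest


names = ["EB-30-11.pth.tar", "EB-30-25.pth.tar", "EB-50-13.pth.tar", "model_best.pth.tar"]
print(earliest_ticket_per_ratio(names))  # {30.0: 11, 50.0: 13}
```

Against a real save directory, `earliest_ticket_per_ratio(os.listdir(save_dir))` would apply the same logic.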

After the mask distance matrix is computed (it is automatically saved as overlap-0.5.npy), you can run plot_overlap.py to draw the figures.

Retrain to Restore Accuracy

Example: Retrain drawn EB tickets (e.g., VGG16 for CIFAR-100) to restore accuracy

  • Fine-tune EB tickets from the emergence epoch. Note that we keep the sparsity regularization for the underlying iterative pruning.
bash ./scripts/standard-train/retrain_continue.sh
  • Retrain re-initialized EB tickets from scratch (see EB Train (re-init) in Sec. 4.3 of the paper).
bash ./scripts/standard-train/retrain_scratch.sh
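In the network-slimming approach of (Liu et al., 2017), sparsity regularization is an L1 penalty s * sum(|gamma|) on batch-normalization scaling factors, whose subgradient adds s * sign(gamma) to each factor's gradient. A framework-free sketch of that update; the coefficient s = 1e-4 is an assumed value and the helper names are hypothetical:

```python
from typing import List, Sequence


def _sign(x: float) -> int:
    """Return +1, -1, or 0 (the subgradient of |x| chosen at 0)."""
    return (x > 0) - (x < 0)


def add_l1_subgradient(gammas: Sequence[float], grads: Sequence[float],
                       s: float = 1e-4) -> List[float]:
    """Return grads + s * sign(gamma): the gradient including the penalty s * sum(|gamma|)."""
    return [g + s * _sign(w) for w, g in zip(gammas, grads)]


def sgd_step(gammas: Sequence[float], grads: Sequence[float],
             lr: float = 0.1, s: float = 1e-4) -> List[float]:
    """One plain SGD step on BN scaling factors with the L1 term included."""
    return [w - lr * g for w, g in zip(gammas, add_l1_subgradient(gammas, grads, s))]


# With a zero task gradient, the penalty alone nudges every factor toward zero.
print(sgd_step([1.0, -0.5], [0.0, 0.0], lr=0.1, s=0.01))
```

In PyTorch, a roughly equivalent step is `m.weight.grad.add_(s * torch.sign(m.weight.data))` for each `BatchNorm2d` module `m`, applied after `loss.backward()` and before `optimizer.step()`.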

Low Precision Search and Retrain

We apply the low-precision method SWALP to both the search and retrain stages (see EB Train LL in Sec. 4.3 of the paper). The steps below use VGG16 on CIFAR-10 as an example:

  • Step 1: Standard training to find EB tickets at different pruning ratios.
bash ./scripts/low-precision/search.sh
  • Step 2: Perform the actual pruning on the saved checkpoints.
bash ./scripts/low-precision/prune.sh
  • Step 3: Fine-tune EB tickets from the emergence epoch.
bash ./scripts/low-precision/retrain_continue.sh
  • Comparison example:
[Figure: eb-train]
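Low-precision training methods such as SWALP rely on stochastic rounding so that quantization is unbiased in expectation. The sketch below illustrates only that idea, on a simple signed fixed-point format; it is not SWALP's actual block floating point quantizer or the qtorch API, and the 8-bit word with 4 fractional bits is an arbitrary assumption:

```python
import math
import random
from typing import Optional


def stochastic_round_fixed(x: float, word_bits: int = 8, frac_bits: int = 4,
                           rng: Optional[random.Random] = None) -> float:
    """Quantize x to a signed fixed-point grid with spacing 2**-frac_bits.

    Rounds up with probability equal to the fractional distance above the lower grid
    point, so the expected result equals x for values inside the representable range.
    """
    rng = rng if rng is not None else random.Random()
    step = 2.0 ** -frac_bits
    lowest = -(2 ** (word_bits - 1)) * step
    highest = (2 ** (word_bits - 1) - 1) * step
    scaled = x / step
    lower = math.floor(scaled)
    round_up = rng.random() < (scaled - lower)
    quantized = (lower + int(round_up)) * step
    return min(max(quantized, lowest), highest)  # saturate instead of overflowing


rng = random.Random(0)
samples = [stochastic_round_fixed(0.3, rng=rng) for _ in range(20000)]
print(sorted(set(samples)))         # only the two neighboring grid points
print(sum(samples) / len(samples))  # close to 0.3 on average
```

Deterministic round-to-nearest would always map 0.3 to 0.3125 here; stochastic rounding trades per-step noise for an unbiased average, which matters when small gradient updates would otherwise be rounded away.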

ImageNet Experiments

All pretrained checkpoints for the different pruning ratios are available on Google Drive. To evaluate inference accuracy on the test set, we provide evaluation scripts (EVAL_ResNet18_ImageNet.py and EVAL_ResNet50_ImageNet.py) and the corresponding commands below for your convenience.

bash ./scripts/resnet18-imagenet/evaluation.sh
bash ./scripts/resnet50-imagenet/evaluation.sh
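ImageNet inference accuracy is commonly reported as top-1 and top-5 accuracy. Assuming the evaluation scripts report these metrics, the computation itself can be sketched in pure Python as follows; the function name `topk_accuracy` and the toy scores are hypothetical:

```python
from typing import Dict, Sequence


def topk_accuracy(scores: Sequence[Sequence[float]], labels: Sequence[int],
                  ks: Sequence[int] = (1, 5)) -> Dict[int, float]:
    """Percentage of samples whose true label is among the k highest-scoring classes."""
    if len(scores) != len(labels) or not labels:
        raise ValueError("need one non-empty score row per label")
    results: Dict[int, float] = {}
    for k in ks:
        hits = 0
        for row, label in zip(scores, labels):
            # Class indices sorted by descending score, truncated to the top k.
            top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
            hits += label in top
        results[k] = 100.0 * hits / len(labels)
    return results


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.45, 0.1, 0.5]]
print(topk_accuracy(scores, labels=[1, 2, 0], ks=(1, 2)))
```

With 1000 ImageNet classes the same logic applies with `ks=(1, 5)`; a tensor implementation would typically use `torch.topk` instead of sorting each row.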

ResNet18 on ImageNet

  • Step 1: Standard training to find EB tickets at different pruning ratios.
bash ./scripts/resnet18-imagenet/search.sh
  • Step 2: Perform the actual pruning on the saved checkpoints.
bash ./scripts/resnet18-imagenet/prune.sh
  • Step 3: Fine-tune EB tickets from the emergence epoch.
bash ./scripts/resnet18-imagenet/retrain_continue.sh
  • Comparison results:
[Figure: eb-train]

ResNet50 on ImageNet

  • Step 1: Standard training to find EB tickets at different pruning ratios.
bash ./scripts/resnet50-imagenet/search.sh
  • Step 2: Perform the actual pruning on the saved checkpoints.
bash ./scripts/resnet50-imagenet/prune.sh
  • Step 3: Fine-tune EB tickets from the emergence epoch.
bash ./scripts/resnet50-imagenet/retrain_continue.sh

Citation

If you find this code useful for your research, please cite:

@inproceedings{
you2020drawing,
title={Drawing Early-Bird Tickets: Toward More Efficient Training of Deep Networks},
author={Haoran You and Chaojian Li and Pengfei Xu and Yonggan Fu and Yue Wang and Xiaohan Chen and Yingyan Lin and Zhangyang Wang and Richard G. Baraniuk},
booktitle={International Conference on Learning Representations},
year={2020},
url={https://openreview.net/forum?id=BJxsrgStvr}
}


Contributors

licj15, ranery, santosh-b


Issues

Discussion on Pruning the 0 Layers

Hello! I may have found a bug in pruning.

Take vgg16 as an example. I chose a pruned_ratio of 0.6, and obtained the following cfg_mask:
[55, 64, 'M', 128, 128, 'M', 256, 253, 239, 'M', 142, 28, 8, 'M', 0, 1, 387]
One conv layer is pruned to zero. Based on my understanding, you skip the zero layer and build a ‘vgg15’ network from the remaining 12 conv layers. Then you conduct the pruning by filling in the new network’s weights from the original network’s weights.

In the code, you use start_mask and end_mask to mark the selected input and output channels. End_mask is actually an element of cfg_mask, which marks the original network’s selected channels. When you encounter a layer with 0 channels, you have to skip it. So, I guess your code:

if torch.sum(end_mask) == 0:
    continue

is designed to do the job. This line of code is placed right after the check for a conv/BN layer. However, end_mask is updated only at the end of BN layer pruning. This results in the following behavior: once end_mask is all 0, it never changes, and the pruning algorithm skips all remaining layers and ends.

Actually, I don’t think pruning is conducted on the layers after the 0 layer. This can be seen by breakpoint debugging on the last BN layer, whose weights are all 0.5 (the initial value). In addition, you print each conv layer right after pruning it, but the layers after the 0 layer print nothing. I think those layers keep their initial weights.

Hope to receive your reply. Maybe I have misunderstood the code; if so, please point it out. Thank you very much for your attention.
Best regards

mask distance and distance threshold

Thanks for the great work!
I have two questions.

  1. About calculating mask distance: some layers of VGGNet have different numbers of filters, so how is the Hamming distance computed between masks of different lengths? Or is the distance calculated layer-wise?
  2. Does threshold=0.1 hold for different architectures? How does one find a proper threshold for a new network?

Thank you!

Issues with resprune_50

Hi,

I followed the instructions for training and pruning the resnet50 model on the ImageNet dataset. However, I get an error in the pruning stage. I notice that no channel selection layer is introduced in the resnet50_official model. Could you please help me reproduce the results from the paper on the ResNet-50 model? Thank you.
