
STR: Soft Threshold Weight Reparameterization for Learnable Sparsity

Aditya Kusupati, Vivek Ramanujan*, Raghav Somani*, Mitchell Wortsman*, Prateek Jain, Sham Kakade and Ali Farhadi

This repository contains the code for the CNN experiments presented in the ICML 2020 paper, along with additional functionality.

This codebase is built on the hidden-networks repository, modified for the STR, DNW and GMP experiments.

The RNN experiments in the paper are done by modifying FastGRNNCell in EdgeML using the methods discussed in the paper.

Set Up

  1. Clone this repository.
  2. Using Python 3.6, create a venv with python -m venv myenv and run source myenv/bin/activate. You can also use conda to create a virtual environment.
  3. Install requirements with pip install -r requirements.txt for venv, or with the appropriate conda commands for a conda environment.
  4. Create a data directory <data-dir>. To run the ImageNet experiments, there must be a folder <data-dir>/imagenet containing the ImageNet train and val folders, with the images of each class in a separate subfolder (see the example layout below).
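
     The expected layout is the standard torchvision ImageFolder structure, for example:

       <data-dir>/imagenet/train/<class-folder>/<image>.JPEG
       <data-dir>/imagenet/val/<class-folder>/<image>.JPEG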

STRConv

STRConv, along with other custom convolution modules, can be found in utils/conv_type.py. STRConv can be dropped into most PyTorch-based models because it inherits from nn.Conv2d (also referred to in this codebase as DenseConv).
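
A minimal usage sketch (assuming STRConv's constructor mirrors nn.Conv2d, as the inheritance suggests; note that STRConv may also read options from the repo's global argument parser, so in practice it is configured through the config files):

import torch
import torch.nn as nn
from utils.conv_type import STRConv  # STR's soft-threshold convolution

# Hypothetical toy model: swap nn.Conv2d for STRConv so that each
# convolution learns its own pruning threshold during training.
class TinyNet(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = STRConv(3, 64, kernel_size=3, padding=1, bias=False)
        self.conv2 = STRConv(64, 128, kernel_size=3, padding=1, bias=False)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        return self.fc(self.pool(x).flatten(1))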

Vanilla Training

This codebase contains model architectures for ResNet18, ResNet50 and MobileNetV1, with support to train them on ImageNet-1K. Config files are provided for training ResNet50 and MobileNetV1, and they can be modified for other architectures and datasets. To support more datasets, add new dataloaders to the data folder (a hypothetical sketch follows).
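
A rough, hypothetical sketch of such a dataloader (the class name, constructor arguments, and attribute names are illustrative; match them to the existing loaders in the data folder):

import os
import torch
from torchvision import datasets, transforms

# Hypothetical loader exposing train_loader and val_loader built from an
# ImageFolder-style directory, mirroring the usual ImageNet preprocessing.
class MyDataset:
    def __init__(self, data_dir, batch_size=256, workers=8):
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_set = datasets.ImageFolder(
            os.path.join(data_dir, "train"),
            transforms.Compose([transforms.RandomResizedCrop(224),
                                transforms.RandomHorizontalFlip(),
                                transforms.ToTensor(), normalize]))
        val_set = datasets.ImageFolder(
            os.path.join(data_dir, "val"),
            transforms.Compose([transforms.Resize(256),
                                transforms.CenterCrop(224),
                                transforms.ToTensor(), normalize]))
        self.train_loader = torch.utils.data.DataLoader(
            train_set, batch_size=batch_size, shuffle=True, num_workers=workers)
        self.val_loader = torch.utils.data.DataLoader(
            val_set, batch_size=batch_size, shuffle=False, num_workers=workers)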

Training across multiple GPUs is supported; however, check the minimum number of GPUs required to scale to ImageNet-1K.

Train dense models on ImageNet-1K:

ResNet50: python main.py --config configs/largescale/resnet50-dense.yaml --multigpu 0,1,2,3

MobileNetV1: python main.py --config configs/largescale/mobilenetv1-dense.yaml --multigpu 0,1,2,3

Train models with STR on ImageNet-1K:

ResNet50: python main.py --config configs/largescale/resnet50-str.yaml --multigpu 0,1,2,3

MobileNetV1: python main.py --config configs/largescale/mobilenetv1-str.yaml --multigpu 0,1,2,3

To reproduce the results in the paper, please modify the config files appropriately using the hyperparameters from the appendix of the STR paper.

Train ResNet50 models with DNW and GMP on ImageNet-1K:

DNW: python main.py --config configs/largescale/resnet50-dnw.yaml --multigpu 0,1,2,3

GMP: python main.py --config configs/largescale/resnet50-gmp.yaml --multigpu 0,1,2,3

Please note that the GMP implementation is not thoroughly tested, so caution is advised.

Modify the config files to tweak the performance and sparsity levels in both DNW and GMP.

Models and Logging

STR models are not directly compatible with traditional dense models for simple evaluation or use as transfer-learning backbones. DNW and GMP models are compatible with the dense model format.
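
The incompatibility comes from the reparameterization itself: an STR checkpoint stores unthresholded weights plus learnable threshold parameters, and the sparse weight is only materialized by the soft-threshold function from the paper. A minimal sketch of that mapping (attribute and parameter names here are illustrative, not the repo's exact ones):

import torch
import torch.nn.functional as F

def str_effective_weight(weight, s):
    # Soft-threshold reparameterization from the STR paper:
    # W_sparse = sign(W) * ReLU(|W| - sigmoid(s)), with s the learnable
    # per-layer threshold parameter.
    threshold = torch.sigmoid(s)
    return torch.sign(weight) * F.relu(torch.abs(weight) - threshold)

# Weights whose magnitude falls below sigmoid(s) become exactly zero.
w = torch.tensor([0.30, -0.05, 0.01, -0.80])
print(str_effective_weight(w, torch.tensor(-2.0)))

Converting to a dense model amounts to baking this thresholded weight back into a standard nn.Conv2d, which is presumably what the --dense-conv-model conversion below does.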

Every experiment creates a directory inside the runs folder (created automatically) containing the TensorBoard logs, the initial model state (for LTH experiments) and the best model (model_best.pth).

The runs folder also contains CSV dumps with the final and best accuracies, along with layer-wise sparsity distributions and thresholds in the case of STR. The code checkpoints after every epoch, so training can be resumed when pre-empted; the extra functionality can be explored through python main.py -h.

Convert STR model to dense model:

ResNet50: python main.py --config configs/largescale/resnet50-dense.yaml --multigpu 0,1,2,3 --pretrained <ResNet50-STR-Model> --dense-conv-model

MobileNetV1: python main.py --config configs/largescale/mobilenetv1-dense.yaml --multigpu 0,1,2,3 --pretrained <MobileNetV1-STR-Model> --dense-conv-model

The converted models use the names provided in the corresponding config files, which can be overridden with the --name command-line argument.

Evaluating models on ImageNet-1K:

If you want to evaluate a pretrained STR model provided below, you can either use the model as is or convert it to a dense model and use the dense model evaluation. To encourage uniformity, please convert STR models to dense or use the dense-compatible models where provided.

Dense Model Evaluation: python main.py --config configs/largescale/<arch>-dense.yaml --multigpu 0,1,2,3 --pretrained <Dense-Compatible-Model> --evaluate

STR Model Evaluation: python main.py --config configs/largescale/<arch>-str.yaml --multigpu 0,1,2,3 --pretrained <STR-Model> --evaluate

Sparsity Budget Transfer

If it is hard to hand-code all the budgets into a method like DNW, you can use the repo's budget-transfer functionality. The pretrained models supplied must be in the native STR model format, not the converted/dense-compatible format; the relevant piece of code would need to be changed to support the dense format as well (see the budget-extraction sketch after the commands below).

Transfer to DNW: python main.py --config configs/largescale/<arch>-dnw.yaml --multigpu 0,1,2,3 --pretrained <STR-Model> --ignore-pretrained-weights --use-budget

Transfer to GMP: python main.py --config configs/largescale/<arch>-gmp.yaml --multigpu 0,1,2,3 --pretrained <STR-Model> --ignore-pretrained-weights --use-budget

You should modify the corresponding config files for DNW and GMP to increase accuracy by changing the hyperparameters.
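
If you need the budgets explicitly (for example, to support the dense-compatible format mentioned above), here is a minimal sketch that reads per-layer sparsity from a dense-compatible checkpoint, where pruned weights are stored as exact zeros (the file name and state-dict layout are assumptions; adjust to the actual checkpoint):

import torch

ckpt = torch.load("resnet50-str-dense.pth", map_location="cpu")  # hypothetical path
state_dict = ckpt.get("state_dict", ckpt)

budgets = {}
for name, tensor in state_dict.items():
    # Consider convolution weights only (4-D tensors).
    if name.endswith("weight") and tensor.dim() == 4:
        budgets[name] = float((tensor == 0).sum()) / tensor.numel()

for name, sparsity in budgets.items():
    print(f"{name}: {100 * sparsity:.2f}% sparse")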

Pretrained Models

All the models provided here are trained on ImageNet-1K according to the settings in the paper.

Fully Dense Models:

These models are straightforward to train with this repo, and pretrained versions are available in most popular frameworks. For the sake of reproducibility, pretrained dense models are provided here as well.

Architecture Params Sparsity (%) Top-1 Acc (%) FLOPs Model Links
ResNet50 25.6M 0.00 77.01 4.09G Dense
MobileNetV1 4.21M 0.00 71.95 569M Dense

STR Sparse Models:

Links are provided to 6 models for ResNet50 and 2 models for MobileNetV1, representative of the sparsity regimes they belong to. Each model has two download links: the first is the vanilla STR model, and the second is the STR model converted to be compatible with dense models and usable for transfer learning. Please contact Aditya Kusupati if you need a specific model and are not able to train it from scratch. All the sparsity budgets for every model in the paper are listed in the appendix, in case all you need is the non-uniform sparsity budget.

ResNet50:

No. Params Sparsity (%) Top-1 Acc (%) FLOPs Model Links
1 4.47M 81.27 76.12 705M STR, Dense
2 2.49M 90.23 74.31 343M STR, Dense
3 1.24M 95.15 70.23 162M STR, Dense
4 0.99M 96.11 67.78 127M STR, Dense
5 0.50M 98.05 61.46 73M STR, Dense
6 0.26M 98.98 51.82 47M STR, Dense

MobileNetV1:

No. Params Sparsity (%) Top-1 Acc (%) FLOPs Model Links
1 1.04M 75.28 68.35 101M STR, Dense
2 0.46M 89.01 62.10 42M STR, Dense

Note: If an STR model is about 2x the size of its dense-compatible counterpart, it is likely due to an old implementation that duplicated the weights.

Sparsity Budgets

The budgets folder contains CSV files with all the non-uniform sparsity budgets STR learnt for ResNet50 on ImageNet-1K across all sparsity regimes, along with baseline budgets for 90% sparse ResNet50 on ImageNet-1K. If you cannot use the pretrained models to extract sparsity budgets, you can import the same budgets directly from these files.
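
A small sketch for loading one of these budget files (assuming each row pairs a layer name with its sparsity; check the actual column layout of the CSVs):

import csv

budget = {}
with open("budgets/resnet50_90.csv", newline="") as f:  # hypothetical file name
    for row in csv.reader(f):
        layer, sparsity = row[0], float(row[1])
        budget[layer] = sparsity

print(budget)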

Citation

If you find this project useful in your research, please consider citing:

@inproceedings{Kusupati20,
  author    = {Kusupati, Aditya and Ramanujan, Vivek and Somani, Raghav and Wortsman, Mitchell and Jain, Prateek and Kakade, Sham and Farhadi, Ali},
  title     = {Soft Threshold Weight Reparameterization for Learnable Sparsity},
  booktitle = {Proceedings of the International Conference on Machine Learning},
  month     = {July},
  year      = {2020},
}


Issues

No Sparsity_ResNet18 CIFAR10

Thanks for your great work and clean code.
I ran STR on ResNet18 for CIFAR10, but the generated JSON file for sparsity says that almost no pruning happened:
{"module.conv1": 0.0, "module.layer1.0.conv1": 0.0, "module.layer1.0.conv2": 0.0, "module.layer1.1.conv1": 0.0, "module.layer1.1.conv2": 0.0, "module.layer2.0.conv1": 0.0, "module.layer2.0.conv2": 0.0, "module.layer2.0.downsample.0": 0.0, "module.layer2.1.conv1": 0.0, "module.layer2.1.conv2": 0.0, "module.layer3.0.conv1": 0.0, "module.layer3.0.conv2": 0.0, "module.layer3.0.downsample.0": 0.0, "module.layer3.1.conv1": 0.0, "module.layer3.1.conv2": 0.0, "module.layer4.0.conv1": 0.0, "module.layer4.0.conv2": 0.0, "module.layer4.0.downsample.0": 0.0, "module.layer4.1.conv1": 0.0, "module.layer4.1.conv2": 4.172325134277344e-05, "module.fc": 0.0, "total": 8.562441436765766e-06}

Could the choice of hyperparameters have this much of an effect (no sparsity at all)? I used the same YAML config as for ResNet50 on ImageNet.
I found that you also experimented with ResNet18 on CIFAR10, CIFAR100 and TinyImageNet. Could you please share the hyperparameter choices for 90% sparsity?

view size error

Hi, Thanks for your great work!

I just ran the evaluation code, but I found this error:

...
  File "/home/goqhadl9298/STR/utils/eval_utils.py", line 18, in accuracy
    correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Can you please check it?
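
As the error message suggests, replacing .view with .reshape (or calling .contiguous() first) in utils/eval_utils.py should resolve this, e.g.:

# utils/eval_utils.py, in accuracy(): reshape handles non-contiguous tensors
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)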

Issues with training using GMP

I hope you are doing well today. I have run into some issues when training models using GMP.

First, when the pruning is supposed to begin at the epoch args.init_prune_epoch, the following error is thrown from main.py:

Traceback (most recent call last):
  File "main.py", line 494, in <module>
    main()
  File "main.py", line 43, in main
    main_worker(args)
  File "main.py", line 143, in main_worker
    prune_decay = (1 - ((args.curr_prune_epoch - args.init_prune_epoch)/total_prune_epochs))**3
AttributeError: 'Namespace' object has no attribute 'curr_prune_epoch'

I performed a grep search for curr_prune_epoch and this is the only place it appears in the entire STR repo. Looking at older versions of the code, I also did not see any references to curr_prune_epoch. I suspect replacing curr_prune_epoch with epoch will result in the intended level of gradual magnitude pruning. However, even if I make this substitution there is a second issue. Even before the pruning begins (at epoch args.init_prune_epoch), the model does not appear to be learning. I attempted training smaller models on CIFAR-10 (Conv2/4/6 architectures from the hidden-networks repo that are compatible with the STR code) and, after trying 30+ initializations, I have not observed any model learning. I'm not sure whether this is a phenomenon you have observed, but based on the first issue it is unclear to me whether this publicly available version of GMP was successfully tested, so I wanted to ask about these issues before trying to debug further. Below is the output of the 0th epoch for one such run (the loss at subsequent epochs remains around 2.303):

Epoch: [0][  0/391]     Time  1.647 ( 1.647)    Data  0.371 ( 0.371)    Loss 3.793 (3.793)      Acc@1   7.81 (  7.81)   Acc@5  47.66 ( 47.66)                 
Epoch: [0][ 10/391]     Time  0.012 ( 0.162)    Data  0.000 ( 0.034)    Loss 2.303 (12.470)     Acc@1   7.81 ( 10.23)   Acc@5  56.25 ( 48.37)                 
Epoch: [0][ 20/391]     Time  0.012 ( 0.091)    Data  0.000 ( 0.018)    Loss 2.303 (7.628)      Acc@1   8.59 ( 10.16)   Acc@5  50.00 ( 48.07)                 
Epoch: [0][ 30/391]     Time  0.012 ( 0.066)    Data  0.000 ( 0.013)    Loss 2.303 (5.910)      Acc@1  10.94 ( 10.08)   Acc@5  50.78 ( 48.39)                 
Epoch: [0][ 40/391]     Time  0.012 ( 0.053)    Data  0.000 ( 0.011)    Loss 2.303 (5.030)      Acc@1  11.72 ( 10.33)   Acc@5  60.16 ( 48.91)                 
Epoch: [0][ 50/391]     Time  0.012 ( 0.046)    Data  0.000 ( 0.009)    Loss 2.303 (4.496)      Acc@1   7.81 (  9.99)   Acc@5  48.44 ( 49.53)                 
Epoch: [0][ 60/391]     Time  0.012 ( 0.041)    Data  0.000 ( 0.008)    Loss 2.303 (4.136)      Acc@1  12.50 ( 10.11)   Acc@5  46.09 ( 49.49)                 
Epoch: [0][ 70/391]     Time  0.012 ( 0.037)    Data  0.000 ( 0.007)    Loss 2.303 (3.878)      Acc@1   5.47 (  9.97)   Acc@5  50.00 ( 49.35)                 
Epoch: [0][ 80/391]     Time  0.014 ( 0.034)    Data  0.000 ( 0.007)    Loss 2.303 (3.683)      Acc@1  10.94 ( 10.03)   Acc@5  49.22 ( 49.15)                 
Epoch: [0][ 90/391]     Time  0.012 ( 0.032)    Data  0.000 ( 0.007)    Loss 2.303 (3.532)      Acc@1   7.03 ( 10.01)   Acc@5  47.66 ( 49.43)                 
Epoch: [0][100/391]     Time  0.012 ( 0.031)    Data  0.000 ( 0.006)    Loss 2.303 (3.410)      Acc@1   5.47 ( 10.00)   Acc@5  48.44 ( 49.64)                 
Epoch: [0][110/391]     Time  0.012 ( 0.029)    Data  0.000 ( 0.006)    Loss 2.303 (3.310)      Acc@1   8.59 (  9.93)   Acc@5  46.88 ( 49.61)                 
Epoch: [0][120/391]     Time  0.012 ( 0.028)    Data  0.000 ( 0.006)    Loss 2.303 (3.227)      Acc@1  17.19 ( 10.00)   Acc@5  60.94 ( 49.84)                 
Epoch: [0][130/391]     Time  0.012 ( 0.027)    Data  0.000 ( 0.006)    Loss 2.303 (3.156)      Acc@1   7.81 ( 10.03)   Acc@5  50.00 ( 49.93)                 
Epoch: [0][140/391]     Time  0.012 ( 0.026)    Data  0.000 ( 0.005)    Loss 2.303 (3.096)      Acc@1  10.16 (  9.98)   Acc@5  51.56 ( 49.92)                 
Epoch: [0][150/391]     Time  0.012 ( 0.025)    Data  0.000 ( 0.005)    Loss 2.303 (3.043)      Acc@1   8.59 ( 10.00)   Acc@5  50.00 ( 50.06)                 
Epoch: [0][160/391]     Time  0.012 ( 0.025)    Data  0.000 ( 0.005)    Loss 2.303 (2.997)      Acc@1  12.50 ( 10.11)   Acc@5  53.91 ( 50.18)                 
Epoch: [0][170/391]     Time  0.012 ( 0.024)    Data  0.000 ( 0.005)    Loss 2.303 (2.957)      Acc@1  10.94 ( 10.00)   Acc@5  47.66 ( 50.09)                 
Epoch: [0][180/391]     Time  0.012 ( 0.024)    Data  0.000 ( 0.005)    Loss 2.303 (2.920)      Acc@1   7.81 ( 10.01)   Acc@5  50.00 ( 50.08)                 
Epoch: [0][190/391]     Time  0.012 ( 0.023)    Data  0.000 ( 0.005)    Loss 2.303 (2.888)      Acc@1   9.38 ( 10.02)   Acc@5  53.12 ( 50.09)                 
Epoch: [0][200/391]     Time  0.012 ( 0.023)    Data  0.000 ( 0.005)    Loss 2.303 (2.859)      Acc@1   9.38 ( 10.09)   Acc@5  55.47 ( 50.13)                 
Epoch: [0][210/391]     Time  0.012 ( 0.022)    Data  0.000 ( 0.005)    Loss 2.303 (2.833)      Acc@1   5.47 ( 10.10)   Acc@5  46.09 ( 50.11)                 
Epoch: [0][220/391]     Time  0.013 ( 0.022)    Data  0.000 ( 0.005)    Loss 2.303 (2.809)      Acc@1   7.03 ( 10.12)   Acc@5  49.22 ( 50.23)                 
Epoch: [0][230/391]     Time  0.012 ( 0.022)    Data  0.000 ( 0.004)    Loss 2.303 (2.787)      Acc@1  13.28 ( 10.17)   Acc@5  53.12 ( 50.28)                 
Epoch: [0][240/391]     Time  0.012 ( 0.021)    Data  0.000 ( 0.004)    Loss 2.303 (2.767)      Acc@1   9.38 ( 10.14)   Acc@5  45.31 ( 50.25)                 
Epoch: [0][250/391]     Time  0.012 ( 0.021)    Data  0.000 ( 0.004)    Loss 2.303 (2.748)      Acc@1  12.50 ( 10.15)   Acc@5  50.78 ( 50.25)                 
Epoch: [0][260/391]     Time  0.012 ( 0.021)    Data  0.000 ( 0.004)    Loss 2.303 (2.731)      Acc@1   7.03 ( 10.13)   Acc@5  48.44 ( 50.17)                 
Epoch: [0][270/391]     Time  0.012 ( 0.021)    Data  0.000 ( 0.004)    Loss 2.303 (2.715)      Acc@1  11.72 ( 10.12)   Acc@5  57.03 ( 50.13)                 
Epoch: [0][280/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.701)      Acc@1  12.50 ( 10.12)   Acc@5  57.81 ( 50.14)                 
Epoch: [0][290/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.687)      Acc@1   6.25 ( 10.08)   Acc@5  44.53 ( 50.11)                 
Epoch: [0][300/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.674)      Acc@1   9.38 ( 10.07)   Acc@5  51.56 ( 50.11)                 
Epoch: [0][310/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.662)      Acc@1  14.84 ( 10.07)   Acc@5  56.25 ( 50.11)                 
Epoch: [0][320/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.651)      Acc@1  11.72 ( 10.06)   Acc@5  47.66 ( 50.07)                 
Epoch: [0][330/391]     Time  0.012 ( 0.020)    Data  0.000 ( 0.004)    Loss 2.303 (2.640)      Acc@1   8.59 ( 10.05)   Acc@5  47.66 ( 50.09)                 
Epoch: [0][340/391]     Time  0.012 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.631)      Acc@1   7.03 ( 10.03)   Acc@5  38.28 ( 50.04)                 
Epoch: [0][350/391]     Time  0.012 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.621)      Acc@1  10.16 ( 10.00)   Acc@5  52.34 ( 50.05)                 
Epoch: [0][360/391]     Time  0.012 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.612)      Acc@1   9.38 ( 10.00)   Acc@5  49.22 ( 50.06)                 
Epoch: [0][370/391]     Time  0.012 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.604)      Acc@1  11.72 ( 10.05)   Acc@5  50.00 ( 50.10)                 
Epoch: [0][380/391]     Time  0.012 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.596)      Acc@1   6.25 ( 10.02)   Acc@5  46.09 ( 50.05)                 
Epoch: [0][390/391]     Time  0.115 ( 0.019)    Data  0.000 ( 0.004)    Loss 2.303 (2.589)      Acc@1   6.25 ( 10.00)   Acc@5  48.75 ( 50.01)                 
100%|#######################################################################################################################| 391/391 [00:07<00:00, 52.66it/s]
Test: [ 0/79]   Time  0.145 ( 0.145)    Loss 2.303 (2.303)      Acc@1   7.81 (  7.81)   Acc@5  46.09 ( 46.09)                                                 
Test: [10/79]   Time  0.008 ( 0.022)    Loss 2.303 (2.303)      Acc@1  11.72 (  9.80)   Acc@5  47.66 ( 49.01)                                                 
Test: [20/79]   Time  0.017 ( 0.017)    Loss 2.303 (2.303)      Acc@1   7.81 (  9.60)   Acc@5  55.47 ( 49.70)                                                 
Test: [30/79]   Time  0.009 ( 0.014)    Loss 2.303 (2.303)      Acc@1  10.94 (  9.80)   Acc@5  44.53 ( 49.97)                                                 
Test: [40/79]   Time  0.010 ( 0.014)    Loss 2.303 (2.303)      Acc@1  11.72 ( 10.12)   Acc@5  57.81 ( 50.21)                                                 
Test: [50/79]   Time  0.008 ( 0.013)    Loss 2.303 (2.303)      Acc@1  10.94 (  9.80)   Acc@5  52.34 ( 50.15)                                                 
Test: [60/79]   Time  0.008 ( 0.013)    Loss 2.303 (2.303)      Acc@1   5.47 (  9.87)   Acc@5  49.22 ( 50.27)                                                 
Test: [70/79]   Time  0.008 ( 0.013)    Loss 2.303 (2.303)      Acc@1   5.47 ( 10.01)   Acc@5  50.00 ( 50.22)                                                 
100%|#########################################################################################################################| 79/79 [00:00<00:00, 81.34it/s]
Test: [79/79]	Time  0.049 ( 0.013)	Loss 2.303 (2.303)	Acc@1   6.25 ( 10.00)	Acc@5  43.75 ( 50.00)
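
For reference, the cubic decay in that traceback matches the standard gradual magnitude pruning schedule; here is a sketch of the intended sparsity ramp, assuming the reporter's suggested substitution of epoch for curr_prune_epoch:

def gmp_sparsity(epoch, init_prune_epoch, final_prune_epoch, final_sparsity):
    # Cubic GMP schedule: sparsity ramps from 0 at init_prune_epoch to
    # final_sparsity at final_prune_epoch, then stays constant.
    if epoch < init_prune_epoch:
        return 0.0
    total_prune_epochs = final_prune_epoch - init_prune_epoch
    progress = min((epoch - init_prune_epoch) / total_prune_epochs, 1.0)
    prune_decay = (1.0 - progress) ** 3
    return final_sparsity * (1.0 - prune_decay)

# Example: 90% final sparsity with pruning between epochs 10 and 60.
for e in (10, 20, 40, 60, 80):
    print(e, round(gmp_sparsity(e, 10, 60, 0.9), 3))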

how to set the sparsity of STR sparse networks

Thank you for your terrific work!
I want to train an STRConv-based ResNet20 with 80%, 90% and 95% sparsity, respectively. How do I set the weight decay or Sinit values so that the resulting network meets the specified sparsity requirement?
Looking forward to your reply.
