
Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks

The PyTorch implementation for our IJCAI 2018 paper. This implementation is based on ResNeXt-DenseNet.

Requirements

  • Python 3.6
  • PyTorch 0.3.1
  • TorchVision 0.3.0

Models and log files

The trained models with log files can be found in Google Drive.

The pruned model without zeros: Release page.

Training ImageNet

Usage of Pruning Training

By default, each model is trained from scratch. To start from a pre-trained model instead, use the options --use_pretrain --lr 0.01.

Run pruning training for ResNet (depth 152, 101, 50, 34, 18) on ImageNet. layer_begin and layer_end are the indices of the first and last conv layers, and layer_inter is set so that only the conv layers, not the BN layers, are selected:

python pruning_train.py -a resnet152 --save_dir ./snapshots/resnet152-rate-0.7 --rate 0.7 --layer_begin 0 --layer_end 462 --layer_inter 3  /path/to/Imagenet2012

python pruning_train.py -a resnet101 --save_dir ./snapshots/resnet101-rate-0.7 --rate 0.7 --layer_begin 0 --layer_end 309 --layer_inter 3  /path/to/Imagenet2012

python pruning_train.py -a resnet50  --save_dir ./snapshots/resnet50-rate-0.7 --rate 0.7 --layer_begin 0 --layer_end 156 --layer_inter 3  /path/to/Imagenet2012

python pruning_train.py -a resnet34  --save_dir ./snapshots/resnet34-rate-0.7 --rate 0.7 --layer_begin 0 --layer_end 105 --layer_inter 3  /path/to/Imagenet2012

python pruning_train.py -a resnet18  --save_dir ./snapshots/resnet18-rate-0.7 --rate 0.7 --layer_begin 0 --layer_end 57 --layer_inter 3  /path/to/Imagenet2012
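
The layer_end values in the commands above are consistent with enumerating the model parameters in order, assuming every conv layer has no bias and is followed by one BN layer (weight and bias), so that conv weights sit at every third index. The sketch below reproduces those values from the number of conv layers. The conv counts and the helper names are assumptions made for illustration and are not taken from the repository.

```python
# Assumed number of conv layers (stem + block convs + downsample convs)
# in the standard ResNet architectures; not taken from this repository.
CONV_COUNTS = {
    "resnet18": 20,
    "resnet34": 36,
    "resnet50": 53,
    "resnet101": 104,
    "resnet152": 155,
}


def last_conv_index(num_convs, layer_inter=3):
    """Index of the last conv weight when parameters are enumerated in order.

    Assumes every conv layer is bias-free and followed by one BN layer
    (weight + bias), so consecutive conv weights are `layer_inter` apart.
    """
    return (num_convs - 1) * layer_inter


def selected_indices(layer_begin, layer_end, layer_inter):
    """Parameter indices covered by the three options (hypothetical helper)."""
    return list(range(layer_begin, layer_end + 1, layer_inter))


for arch, num_convs in CONV_COUNTS.items():
    end = last_conv_index(num_convs)
    print(arch, "layer_end =", end, "conv layers =", len(selected_indices(0, end, 3)))
```

If a model variant adds or removes conv layers, layer_end would need to change accordingly under the same assumption.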

Usage of Initializing with a Pruned Model

By default, the unpruned model is used as the initial model. To initialize with a pruned model instead, use the options --use_sparse --sparse path_to_pruned_model.

Usage of Normal Training

Run ResNet (100 epochs):

python original_train.py -a resnet50 --save_dir ./snapshots/resnet50-baseline  /path/to/Imagenet2012 --workers 36

Inference with the pruned model with zeros

sh scripts/inference_resnet.sh

Inference with the pruned model without zeros

sh scripts/infer_pruned.sh

The pruned model without zeros can be downloaded from the Release page.
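
The difference between the two inference modes can be illustrated with a toy example: a model "with zeros" keeps the original layer shapes but has its pruned filters set to zero, while a model "without zeros" has those filters physically removed. The sketch below uses nested Python lists in place of tensors. The function names are hypothetical, and this is not the repository's conversion script.

```python
def nonzero_filter_ids(weight):
    """Return indices of output filters that contain any non-zero value.

    `weight` is a nested list shaped [out_channels][in_channels][k*k].
    """
    return [
        i for i, filt in enumerate(weight)
        if any(v != 0 for channel in filt for v in channel)
    ]


def drop_zero_filters(weight, next_weight):
    """Remove all-zero filters from `weight` and the matching input
    channels from `next_weight`, the layer that consumes its output."""
    keep = nonzero_filter_ids(weight)
    small = [weight[i] for i in keep]
    small_next = [[filt[i] for i in keep] for filt in next_weight]
    return small, small_next


# Toy layer with 3 filters, 1 input channel, 1x1 kernels; filter 1 is pruned.
conv_a = [[[0.5]], [[0.0]], [[-1.2]]]
# Next layer: 2 filters, each with 3 input channels.
conv_b = [[[1.0], [2.0], [3.0]], [[4.0], [5.0], [6.0]]]

small_a, small_b = drop_zero_filters(conv_a, conv_b)
print(len(conv_a), "->", len(small_a))        # filters in the first layer
print(len(conv_b[0]), "->", len(small_b[0]))  # input channels in the next layer
```

A real conversion also has to handle the BN parameters and the residual connections, which this toy example ignores.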

Scripts to reproduce the results in our paper

To train the ImageNet model with or without pruning, see the scripts directory (we use 8 GPUs for training).

Training Cifar-10

sh scripts/cifar10_resnet.sh

Take care to set the hyper-parameter layer_end correctly for each ResNet depth.

Notes

Torchvision Version

We use torchvision 0.3.0. If your torchvision version is 0.2.0, replace transforms.RandomResizedCrop with transforms.RandomSizedCrop and transforms.Resize with transforms.Scale.
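
One way to avoid editing the code by hand is to look up each transform by name and fall back to the older name when the newer one is missing. The following is a minimal sketch of that idea. resolve_transform is a hypothetical helper, not part of torchvision or this repository, and the demonstration uses a stand-in object so that it runs without torchvision installed.

```python
from types import SimpleNamespace


def resolve_transform(transforms_module, new_name, old_name):
    """Return the transform under its newer name if present,
    otherwise fall back to the older name."""
    if hasattr(transforms_module, new_name):
        return getattr(transforms_module, new_name)
    return getattr(transforms_module, old_name)


# Stand-in for an older transforms module, for demonstration only.
old_transforms = SimpleNamespace(RandomSizedCrop="RandomSizedCrop", Scale="Scale")

print(resolve_transform(old_transforms, "RandomResizedCrop", "RandomSizedCrop"))
print(resolve_transform(old_transforms, "Resize", "Scale"))
```

With torchvision installed, the same helper could be called with torchvision.transforms as the first argument and the returned class instantiated as usual.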

Why use 100 epochs for training

Training for 100 epochs can improve the accuracy slightly.

Process of ImageNet dataset

We follow Facebook's procedure for processing ImageNet. The dataset directory "/path/to/Imagenet2012" contains two subfolders, "train" and "val". The corresponding code is here.
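
Before launching a long training run, it can help to confirm that the dataset directory has the expected layout. The sketch below checks only for the two subfolders named above. check_imagenet_layout is a hypothetical helper, and the demonstration runs on a temporary directory rather than a real dataset.

```python
import os
import tempfile


def check_imagenet_layout(root):
    """Return the list of required subfolders missing under `root`."""
    return [
        name for name in ("train", "val")
        if not os.path.isdir(os.path.join(root, name))
    ]


# Demonstration on a temporary directory instead of a real dataset.
with tempfile.TemporaryDirectory() as root:
    os.makedirs(os.path.join(root, "train"))
    missing_before = check_imagenet_layout(root)
    os.makedirs(os.path.join(root, "val"))
    missing_after = check_imagenet_layout(root)

print(missing_before)
print(missing_after)
```

An empty list means both subfolders are present.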

FLOPs Calculation

Refer to the file.

Citation

@inproceedings{he2018soft,
  title     = {Soft Filter Pruning for Accelerating Deep Convolutional Neural Networks},
  author    = {He, Yang and Kang, Guoliang and Dong, Xuanyi and Fu, Yanwei and Yang, Yi},
  booktitle = {International Joint Conference on Artificial Intelligence (IJCAI)},
  pages     = {2234--2240},
  year      = {2018}
}
