Pruning Filters For Efficient ConvNets

Unofficial PyTorch implementation of pruning VGG on the CIFAR-10 dataset

Reference: Pruning Filters for Efficient ConvNets, ICLR 2017

Contact: Minseong Kim ([email protected])

Requirements

  • torch (version: 1.2.0)
  • torchvision (version: 0.4.0)
  • Pillow (version: 6.1.0)
  • matplotlib (version: 3.1.1)
  • numpy (version: 1.16.5)

Usage

Arguments

  • --train-flag: Train VGG on the CIFAR dataset
  • --save-path: Path to save results, ex) trained_models/
  • --load-path: Path to load a checkpoint from, i.e. save_path followed by 'checkpoint.pth', ex) trained_models/checkpoint.pth
  • --resume-flag: Resume training from the checkpoint loaded with --load-path
  • --prune-flag: Prune VGG
  • --prune-layers: List of target convolution layers to prune, ex) conv1 conv2
  • --prune-channels: Number of channels (filters) to prune in each of the --prune-layers, in the same order, ex) 4 14
  • --independent-prune-flag: Prune multiple layers with the independent strategy (the greedy strategy is used otherwise)
  • --retrain-flag: Retrain the pruned network
  • --retrain-epoch: Number of epochs for retraining the pruned network
  • --retrain-lr: Learning rate for retraining the pruned network
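The flags above could be declared with argparse roughly as follows. This is a hypothetical reconstruction, not the parser in main.py: the defaults, the types, and the `--data-set`/`--vgg` entries (which appear only in the example scripts) are assumptions. It mainly illustrates that `--prune-layers` and `--prune-channels` are paired lists.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical CLI mirroring the documented flags (defaults are assumptions)."""
    parser = argparse.ArgumentParser(description="Prune VGG on CIFAR-10 (sketch)")
    # Mode switches: off unless given on the command line.
    parser.add_argument("--train-flag", action="store_true")
    parser.add_argument("--resume-flag", action="store_true")
    parser.add_argument("--prune-flag", action="store_true")
    parser.add_argument("--independent-prune-flag", action="store_true")
    parser.add_argument("--retrain-flag", action="store_true")
    # Assumed string options seen only in the example scripts.
    parser.add_argument("--data-set", type=str, default="CIFAR10")
    parser.add_argument("--vgg", type=str, default="vgg16_bn")
    # Paths.
    parser.add_argument("--save-path", type=str, default="trained_models/")
    parser.add_argument("--load-path", type=str, default=None)
    # Paired lists: the i-th channel count applies to the i-th layer name.
    parser.add_argument("--prune-layers", nargs="+", default=[])
    parser.add_argument("--prune-channels", nargs="+", type=int, default=[])
    # Retraining hyperparameters.
    parser.add_argument("--retrain-epoch", type=int, default=20)
    parser.add_argument("--retrain-lr", type=float, default=0.001)
    return parser


def parse_args(argv=None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # Each pruned layer needs exactly one channel count.
    if len(args.prune_layers) != len(args.prune_channels):
        parser.error("--prune-layers and --prune-channels must have the same length")
    return args
```

With this layout, `--prune-layers conv1 conv2 --prune-channels 4 14` pairs up as `{"conv1": 4, "conv2": 14}` via `zip`.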

Example Scripts

Train VGG on the CIFAR-10 dataset

python main.py --train-flag --data-set CIFAR10 --vgg vgg16_bn --save-path ./trained_models/

Prune VGG with the 'greedy strategy'

python main.py --prune-flag --load-path ./trained_models/check_point.pth --save-path ./trained_models/pruning_results/ --prune-layers conv1 conv2 --prune-channels 1 1
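Pruning a filter means removing one output channel of a convolution, its BatchNorm channel, and the matching input channel of the next convolution. A minimal sketch of that surgery for one `Conv2d`/`BatchNorm2d`/`Conv2d` triple is below, assuming the plain VGG-with-BN layout (groups=1, an affine BatchNorm directly after each conv). The helper names `lowest_l1_filters` and `prune_conv_pair` are hypothetical, and this is not necessarily how the repository rebuilds its layers.

```python
import torch
import torch.nn as nn


def lowest_l1_filters(conv: nn.Conv2d, k: int) -> torch.Tensor:
    """Indices of the k filters with the smallest sum of absolute weights."""
    # weight shape: (out_channels, in_channels, kH, kW) -> one score per filter
    scores = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    return torch.argsort(scores)[:k]


def prune_conv_pair(conv: nn.Conv2d, bn: nn.BatchNorm2d, next_conv: nn.Conv2d, k: int):
    """Return new (conv, bn, next_conv) with the k lowest-sum filters of `conv` removed."""
    assert 0 <= k < conv.out_channels, "must keep at least one filter"
    drop = set(lowest_l1_filters(conv, k).tolist())
    keep = torch.tensor([i for i in range(conv.out_channels) if i not in drop], dtype=torch.long)
    n = len(keep)

    # Same layer with fewer output channels.
    new_conv = nn.Conv2d(conv.in_channels, n, conv.kernel_size, stride=conv.stride,
                         padding=conv.padding, dilation=conv.dilation,
                         bias=conv.bias is not None)
    new_conv.weight.data = conv.weight.data[keep].clone()
    if conv.bias is not None:
        new_conv.bias.data = conv.bias.data[keep].clone()

    # BatchNorm keeps affine params and running stats of the surviving channels.
    new_bn = nn.BatchNorm2d(n, eps=bn.eps, momentum=bn.momentum)
    new_bn.weight.data = bn.weight.data[keep].clone()
    new_bn.bias.data = bn.bias.data[keep].clone()
    new_bn.running_mean.data = bn.running_mean[keep].clone()
    new_bn.running_var.data = bn.running_var[keep].clone()

    # The next layer loses the input channels that read the removed feature maps.
    new_next = nn.Conv2d(n, next_conv.out_channels, next_conv.kernel_size,
                         stride=next_conv.stride, padding=next_conv.padding,
                         dilation=next_conv.dilation, bias=next_conv.bias is not None)
    new_next.weight.data = next_conv.weight.data[:, keep].clone()
    if next_conv.bias is not None:
        new_next.bias.data = next_conv.bias.data.clone()
    return new_conv, new_bn, new_next
```

The returned modules would replace the originals in the network; pruning several layers means applying this repeatedly along the chain.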

Prune VGG with the 'independent strategy'

python main.py --prune-flag --load-path ./trained_models/check_point.pth --save-path ./trained_models/pruning_results/ --prune-layers conv1 conv2 --prune-channels 1 1 --independent-prune-flag
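The two strategies differ only in how the filters of a later layer are scored once an earlier layer has been pruned. Following the reference paper's description, the independent strategy sums absolute weights over all input channels, while the greedy strategy ignores the kernels that read feature maps already removed in the previous layer. The sketch below assumes the convolutions form a simple chain (each consumes the previous one's output channels directly); `filter_scores` and `select_filters` are hypothetical helpers.

```python
import torch
import torch.nn as nn


def filter_scores(conv: nn.Conv2d, kept_inputs=None) -> torch.Tensor:
    """Sum of absolute kernel weights per filter, optionally over a subset of input channels."""
    weight = conv.weight.detach()
    if kept_inputs is not None:
        # Greedy: skip kernels that read feature maps pruned in the previous layer.
        weight = weight[:, kept_inputs]
    return weight.abs().sum(dim=(1, 2, 3))


def select_filters(convs, num_prune, greedy=True):
    """For a chain of convs, return the sorted filter indices to prune in each layer."""
    pruned, kept_inputs = [], None
    for conv, k in zip(convs, num_prune):
        scores = filter_scores(conv, kept_inputs if greedy else None)
        drop = torch.argsort(scores)[:k]
        pruned.append(sorted(drop.tolist()))
        # Surviving output channels of this layer become the next layer's inputs.
        mask = torch.ones(conv.out_channels, dtype=torch.bool)
        mask[drop] = False
        kept_inputs = torch.arange(conv.out_channels)[mask]
    return pruned
```

For instance, if the first layer drops filter 0 and a second-layer filter draws most of its weight from input channel 0, the independent strategy may keep that filter while the greedy strategy, which no longer counts those kernels, may prune it.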

Retrain the pruned network

python main.py --prune-flag --load-path ./trained_models/check_point.pth --save-path ./trained_models/pruning_results/ --prune-layers conv1 --prune-channels 1 --retrain-flag --retrain-epoch 20 --retrain-lr 0.001
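Retraining is ordinary fine-tuning of the smaller network. A minimal sketch of such a loop follows; the optimizer (SGD with momentum 0.9), the cross-entropy loss, and the helper name `retrain` are assumptions, with `epochs` and `lr` playing the roles of `--retrain-epoch` and `--retrain-lr`.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def retrain(model: nn.Module, loader: DataLoader, epochs: int = 20, lr: float = 0.001,
            device: str = "cpu") -> list:
    """Fine-tune a pruned model and return the mean training loss of each epoch."""
    model.to(device).train()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)  # assumed optimizer
    criterion = nn.CrossEntropyLoss()
    history = []
    for _ in range(epochs):
        total_loss, seen = 0.0, 0
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch mean is exact for a ragged last batch.
            total_loss += loss.item() * images.size(0)
            seen += images.size(0)
        history.append(total_loss / seen)
    return history
```

A call such as `retrain(pruned_model, train_loader, epochs=20, lr=0.001)` corresponds to the command above, where `pruned_model` and `train_loader` are placeholders for the pruned VGG and the CIFAR-10 training loader.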

Results

Sum of absolute filter weights for each layer of VGG-16 trained on CIFAR-10

figure1
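The quantity in this figure is, for every filter of a convolution layer, the sum of the absolute values of its kernel weights, sorted within each layer. A sketch of computing it for any PyTorch model follows; `sorted_filter_sums` is a hypothetical helper, and the plotting step is left out.

```python
import torch
import torch.nn as nn


def sorted_filter_sums(model: nn.Module) -> dict:
    """Map each Conv2d's module name to its per-filter absolute weight sums, largest first."""
    sums = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # One value per output filter: sum over in_channels, kH, kW.
            per_filter = module.weight.detach().abs().sum(dim=(1, 2, 3))
            sums[name] = torch.sort(per_filter, descending=True)[0]
    return sums
```

Each tensor can be converted with `.numpy()` and drawn with matplotlib, one curve per layer.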

Test accuracy on CIFAR-10 after pruning the filters with the lowest sums of absolute weights

figure2
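A pruning-sensitivity curve like this one can be sketched by silencing the k filters with the lowest sums in one layer and measuring accuracy for several values of k. To avoid rebuilding layers, the sketch below zeroes those filters together with the matching BatchNorm scale and shift, which makes the channel output exactly zero and therefore equivalent to removing it. This masking shortcut, the layer names, and the helpers `mask_lowest_filters`, `accuracy`, and `sensitivity` are assumptions rather than the repository's procedure.

```python
import copy

import torch
import torch.nn as nn


def mask_lowest_filters(model: nn.Module, conv_name: str, bn_name: str, k: int) -> nn.Module:
    """Copy `model` and silence the k lowest-L1 filters of one conv layer."""
    masked = copy.deepcopy(model)
    modules = dict(masked.named_modules())
    conv, bn = modules[conv_name], modules[bn_name]
    scores = conv.weight.detach().abs().sum(dim=(1, 2, 3))
    drop = torch.argsort(scores)[:k]
    with torch.no_grad():
        conv.weight[drop] = 0.0
        if conv.bias is not None:
            conv.bias[drop] = 0.0
        # gamma = beta = 0 makes the BatchNorm output of these channels exactly 0.
        bn.weight[drop] = 0.0
        bn.bias[drop] = 0.0
    return masked


@torch.no_grad()
def accuracy(model: nn.Module, loader) -> float:
    """Top-1 accuracy of `model` over `loader`."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        correct += (model(images).argmax(dim=1) == labels).sum().item()
        total += labels.numel()
    return correct / total


def sensitivity(model, conv_name, bn_name, ks, loader) -> dict:
    """Accuracy after masking k filters of one layer, for each k in `ks`."""
    return {k: accuracy(mask_lowest_filters(model, conv_name, bn_name, k), loader) for k in ks}
```

Because each k works on a deep copy, the original model is left unchanged between points of the curve.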

Pruning and retraining each single layer of VGG-16 on CIFAR-10

figure3
