
gate-decorator-pruning's Introduction

Gate Decorator (NeurIPS 2019)


This repo contains the scripts required to reproduce the results from the paper:

Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks


Requirements

Python 3.6+ and PyTorch 1.0+

Installation

  1. clone the code
  2. pip install --upgrade git+https://github.com/youzhonghui/pytorch-OpCounter.git
  3. pip install tqdm
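
As an optional sanity check (not part of the repo itself), you can confirm that the patched thop fork, tqdm, and a sufficiently recent PyTorch are importable:

    import torch
    import thop
    import tqdm

    print(torch.__version__)  # expect 1.0 or newer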

How to use

(1). Notebook (ResNet-56)

In the run/resnet-56 folder, we provide an example that reduces the FLOPs of ResNet-56 by 70% while maintaining 93.15% accuracy on CIFAR-10:

  1. The run/resnet-56/resnet56_prune.ipynb prunes the network with Tick-Tock framework.
  2. The run/resnet-56/finetune.ipynb shows how to finetune the pruned network to get better results.

To run the demo code, you may need to install Jupyter Notebook.
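
For orientation, below is a conceptual sketch of the Tick-Tock pipeline the notebook implements. All helper names and numbers here are illustrative stubs, not the repo's API: each Tick is a cheap pass that updates the gates and accumulates filter importance scores, after which the globally least important filters are pruned; each Tock is a longer fine-tuning pass that lets the network recover.

    # Conceptual sketch only; tick/prune/tock are illustrative stubs.
    def tick(net):
        """One cheap epoch with gates enabled; returns filter importance scores."""
        ...

    def prune(net, scores, ratio=0.002):
        """Remove the globally least important filters by score."""
        ...

    def tock(net, epochs):
        """Longer fine-tuning pass with a sparsity constraint on the gates."""
        ...

    def tick_tock(net, rounds, ticks_per_round=10):
        for _ in range(rounds):
            for _ in range(ticks_per_round):
                net = prune(net, tick(net))
            net = tock(net, epochs=10)
        return net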

(2). Command line (VGG-16)

In the run/vgg16 folder, we provide an example executed from the command line, which reduces the FLOPs of VGG-16 by 90% (and its parameters by 98%) while keeping 92.07% accuracy on CIFAR-10.

The instructions can be found here.

(3). Save and load the pruned model

In the run/load_pruned_model/ folder, we provide an example that shows how to save and load a pruned model (VGG-16 with only 0.3M floating-point parameters).
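
Because pruning changes the layer shapes, a bare state_dict no longer fits the stock model definition. A minimal sketch of one common approach (not necessarily the repo's exact code) is to pickle the whole module together with its weights:

    import torch
    import torchvision

    # a stand-in for a pruned network; any nn.Module works the same way
    net = torchvision.models.vgg16()

    torch.save(net, 'pruned_vgg16.pth')  # saves structure and weights together
    restored = torch.load('pruned_vgg16.pth', map_location='cpu')
    restored.eval()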

Todo

  • Basic running example.
  • PyTorch 1.2 compatibility test.
  • The command-line execution demo.
  • Save and load the pruned model.
  • ResNet-50 pruned model.

Citation

If you use this code for your research, please cite our paper:

@inproceedings{zhonghui2019gate,
  title={Gate Decorator: Global Filter Pruning Method for Accelerating Deep Convolutional Neural Networks},
  author={Zhonghui You and
          Kun Yan and
          Jinmian Ye and
          Meng Ma and
          Ping Wang},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2019}
}


gate-decorator-pruning's Issues

The last layer is a convolutional layer: how should I handle it? Thanks

Hi, the last layer of my network is a convolutional layer. Is there any problem with handling the last layer in the following way?
class FinalConvLayerObserver(Meltable):
    def __init__(self, conv2d):
        super(FinalConvLayerObserver, self).__init__()
        assert isinstance(conv2d, nn.Conv2d)
        self.conv2d = conv2d
        # accumulate per-input-channel activation magnitude (as in Conv2dObserver)
        self.in_mask = torch.zeros(conv2d.in_channels).to('cpu')
        self.f_hook = conv2d.register_forward_hook(self._forward_hook)

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        # inputs are 4D, so also reduce the spatial dimensions
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def forward(self, x):
        return self.conv2d(x)

    def melt(self):
        with torch.no_grad():
            # rebuild the layer with only the surviving input channels;
            # the output channels are kept untouched
            replacer = nn.Conv2d(
                in_channels=int((self.in_mask != 0).sum()),
                out_channels=self.conv2d.out_channels,
                kernel_size=self.conv2d.kernel_size,
                stride=self.conv2d.stride,
                padding=self.conv2d.padding,
                dilation=self.conv2d.dilation,
                bias=(self.conv2d.bias is not None)
            ).to(self.conv2d.weight.device)
            replacer.weight.set_(self.conv2d.weight[:, self.in_mask != 0])
            if self.conv2d.bias is not None:
                replacer.bias.set_(self.conv2d.bias)
        return replacer

Can you provide the pruned ResNet-50 model?

Dear author,
GBN elegantly relieves the constraint that the shortcut connections impose in ResNet, and it gets good results on ResNet-56. I really like your work. But I encountered problems while experimenting on ResNet-50:
First, the BASE_FLOPS I calculate is 4135.790M instead of your 4142.714M.
In addition, the test accuracy of the first Tock is lower than yours.
I would appreciate it if you could provide your pruned model or give some suggestions.

How to correctly load the pruned .pth model?

When I load the .pth file saved after pruning, the weight names have all changed, as have the corresponding channel counts. How should I interpret this, and how do I correctly load the pruned .pth model?

Unexpected key(s) in state_dict: "bn1.g", "bn1.area", "bn1.score", "bn1.bn_mask", "bn1.bn.weight", "bn1.bn.bias", "bn1.bn.running_mean", "bn1.bn.running_var", "bn1.bn.num_batches_tracked", "layer1.0.bn1.g", "layer1.0.bn1.area", "layer1.0.bn1.score", "layer1.0.bn1.bn_mask", "layer1.0.bn1.bn.weight", "layer1.0.bn1.bn.bias", "layer1.0.bn1.bn.running_mean", "layer1.0.bn1.bn.running_var", "layer1.0.bn1.bn.num_batches_tracked", "layer1.0.bn2.g", "layer1.0.bn2.area", "layer1.0.bn2.score", "layer1.0.bn2.bn_mask", "layer1.0.bn2.bn.weight", "layer1.0.bn2.bn.bias", "layer1.0.bn2.bn.running_mean", "layer1.0.bn2.bn.running_var", "layer1.0.bn2.bn.num_batches_tracked", "layer1.0.shortcut.1.g", "layer1.0.shortcut.1.area", "layer1.0.shortcut.1.score", "layer1.0.shortcut.1.bn_mask", "layer1.0.shortcut.1.bn.weight", "layer1.0.shortcut.1.bn.bias", "layer1.0.shortcut.1.bn.running_mean", "layer1.0.shortcut.1.bn.running_var", "layer1.0.shortcut.1.bn.num_batches_tracked", "layer1.1.bn1.g", "layer1.1.bn1.area", "layer1.1.bn1.score", "layer1.1.bn1.bn_mask", "layer1.1.bn1.bn.weight",

cloned.module.linear = FinalLinearObserver(cloned.module.linear)

Hi,
The last layer of my own model is not linear but a custom module with a different name. How should I replace it?
If I simply comment out this line of code, I get the following error:
IndexError: index 0 is out of bounds for dimension 0 with size 0
Thanks

MACs vs. Flops

Hello, I have a question about the way you counted FLOPs.
According to your code prune/utils.py, you used thop.
However, the code was recently corrected to clarify that it actually measures MACs rather than FLOPs.
(Please refer to Lyken17/pytorch-OpCounter#37
and Lyken17/pytorch-OpCounter@41fd65c)

In this case, are the FLOPs values in your paper still valid?

Thank you.
Jung
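
For context, a minimal sketch of the MACs/FLOPs distinction with upstream thop (the fork used by this repo may differ): profile returns multiply-accumulate counts, and since one MAC is one multiply plus one add, reported FLOPs are often about twice that number.

    import torch
    from thop import profile
    from torchvision.models import vgg16

    net = vgg16()
    macs, params = profile(net, inputs=(torch.randn(1, 3, 224, 224),))
    # one MAC = one multiply + one add, hence the factor of 2
    print('MACs: %.0fM, FLOPs (~2x MACs): %.0fM' % (macs / 1e6, 2 * macs / 1e6))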

My model has multiple outputs from a forward function I defined myself; how should I modify the code?

def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.deconv_layers(x)  # a module I defined myself
    # print("x:", x)

    ret = {}

    ret['hm'] = self.hm(x)    # output 1
    ret['reg'] = self.reg(x)  # output 2
    ret['wh'] = self.wh(x)    # output 3
    return [ret]

As shown above, how should I modify the code so that only the input size of the last layers changes? Thanks.
Outputs 1, 2, and 3 are the three output branches of my model:
output = {'hm': 13, 'wh': 2, 'reg': 2}

ResNet-50 for ImageNet

Hello,

I am getting an error about a mismatched batch-normalization layer when pruning ResNet-50 on the ImageNet dataset.

Could you please provide the ResNet-50 model definition and explain how it is used for ImageNet pruning?

AttributeError: 'Resnet50' object has no attribute 'classifier'

Hi, I tried to modify your code to use ImageNet's ResNet-50 on CIFAR-10. I followed the VGG-16 example and changed it to ResNet-50. However, I got the following error when running the pruning with CUDA_VISIBLE_DEVICES=0 python ./run/resnet-50/resnet50_prune_demo.py --config ./run/resnet-50/prune.json:

File "/home/ccng/anaconda3/envs/gatedeco/lib/python3.7/site-packages/torch/nn/modules/module.py", line 591, in getattr
type(self).name, name))
AttributeError: 'Resnet50' object has no attribute 'classifier'

May I know how to resolve this error? Thanks.

question

When I use my own ResNet-56 model, it gets very poor performance at the first Tock step.
I train the model with SGD, with the learning rate, momentum, and weight decay set to 0.1, 0.9, and 1e-4, and the learning rate divided by 10 at epochs 80 and 120. The training setting is the same as the one you mention in your paper.
(screenshots of the training logs omitted)
During training, when the learning rate decays to 0.01, the model reaches at least 92.8% test accuracy. But your first Tock step shows accuracy lower than 91% at a learning rate of 0.01, which seems impossible to me.

convolution layer with cardinality

Dear author,

Thanks for this impressive piece of work.
How can I implement a convolution layer with cardinality (i.e., grouped convolution)?

In universal.py we have:

def melt(self):
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        replacer = nn.Conv2d(
            in_channels = int((self.in_mask != 0).sum()),
            out_channels = int((self.out_mask != 0).sum()),
            kernel_size = self.conv.kernel_size,
            stride = self.conv.stride,
            padding = self.conv.padding,
            dilation = self.conv.dilation,
            groups = groups,
            bias = (self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])
            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])
        return replacer

If the convolution layer has cardinality greater than 1 (as in several modern models), we hit the assert False.
How can cardinality be implemented for convolution layers?
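
One possible direction, sketched below under a strong simplifying assumption: if the pruning policy keeps the same within-group input channels in every group (and a group-divisible number of output channels), melt can be extended to grouped convolutions. The helper name melt_grouped and this policy are hypothetical, not part of the repo.

    import torch
    import torch.nn as nn

    # Hypothetical sketch for 1 < groups < out_channels. Assumes every group
    # keeps the same within-group input channels, so the per-group weight
    # slice is identical across groups.
    def melt_grouped(conv, in_mask, out_mask):
        kept_in = int((in_mask != 0).sum())
        kept_out = int((out_mask != 0).sum())
        assert kept_in % conv.groups == 0 and kept_out % conv.groups == 0
        replacer = nn.Conv2d(kept_in, kept_out, conv.kernel_size,
                             stride=conv.stride, padding=conv.padding,
                             dilation=conv.dilation, groups=conv.groups,
                             bias=(conv.bias is not None)).to(conv.weight.device)
        with torch.no_grad():
            # weight layout is (out, in // groups, kH, kW): select output
            # channels on dim 0, then within-group input channels on dim 1
            w = conv.weight[out_mask != 0]
            keep_in = (in_mask.view(conv.groups, -1)[0] != 0)
            replacer.weight.set_(w[:, keep_in])
            if conv.bias is not None:
                replacer.bias.set_(conv.bias[out_mask != 0])
        return replacer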

TypeError: Object of type 'Config' is not JSON serializable

Hi,

Thank you for your excellent work.

When running your "Command line (VGG-16)" demo, I get a JSON encoder error in logger.py at line 63, under the save_network function: "TypeError: Object of type 'Config' is not JSON serializable". Can you please suggest a fix for this issue?

PS: I am using Python 3.6.10 and PyTorch 1.0.1.post2.

Best,
Niam
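
A generic workaround (hypothetical, not the repo's official fix) is to give json.dump a default fallback that serializes unknown objects, such as the Config object, via their __dict__:

    import json

    def dump_config(cfg, path):
        # fall back to an object's attribute dict for non-serializable types
        with open(path, 'w') as f:
            json.dump(cfg, f, default=lambda o: getattr(o, '__dict__', str(o)))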

If the last layer of my network is a convolutional layer, how should I modify the code in Conv2dObserver? Thanks

class Conv2dObserver(Meltable):
    def __init__(self, conv):
        super(Conv2dObserver, self).__init__()
        assert isinstance(conv, nn.Conv2d)
        self.conv = conv
        self.in_mask = torch.zeros(conv.in_channels).to('cpu')
        self.out_mask = torch.zeros(conv.out_channels).to('cpu')
        self.f_hook = conv.register_forward_hook(self._forward_hook)

    def extra_repr(self):
        return '(%d, %d) -> (%d, %d)' % (self.conv.in_channels, self.conv.out_channels, int((self.in_mask != 0).sum()), int((self.out_mask != 0).sum()))

    def _forward_hook(self, m, _in, _out):
        x = _in[0]
        self.in_mask += x.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)

    def _backward_hook(self, grad):
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

    def forward(self, x):
        output = self.conv(x)
        noise = torch.zeros_like(output).normal_()
        output = output + noise
        if self.training:
            output.register_hook(self._backward_hook)
        return output

    def melt(self):
        if self.conv.groups == 1:
            groups = 1
        elif self.conv.groups == self.conv.out_channels:
            groups = int((self.out_mask != 0).sum())
        else:
            assert False

        print("in_channels:", int((self.in_mask != 0).sum()))
        print("out_channels:", int((self.out_mask != 0).sum()))
        print("kernel_size:", self.conv.kernel_size)
        print("stride:", self.conv.stride)
        print("padding:", self.conv.padding)
        print("dilation:", self.conv.dilation)
        print("groups:", groups)
        print("bias:", (self.conv.bias is not None))

        replacer = nn.Conv2d(
            in_channels=int((self.in_mask != 0).sum()),
            out_channels=int((self.out_mask != 0).sum()),
            kernel_size=self.conv.kernel_size,
            stride=self.conv.stride,
            padding=self.conv.padding,
            dilation=self.conv.dilation,
            groups=groups,
            bias=(self.conv.bias is not None)
        ).to(self.conv.weight.device)

        with torch.no_grad():
            if self.conv.groups == 1:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0][:, self.in_mask != 0])
            else:
                replacer.weight.set_(self.conv.weight[self.out_mask != 0])
            if self.conv.bias is not None:
                replacer.bias.set_(self.conv.bias[self.out_mask != 0])
        return replacer

    @classmethod
    def transform(cls, net):
        r = []
        def _inject(modules):
            keys = modules.keys()
            for k in keys:
                if len(modules[k]._modules) > 0:
                    _inject(modules[k]._modules)
                if isinstance(modules[k], nn.Conv2d):
                    modules[k] = Conv2dObserver(modules[k])
                    r.append(modules[k])
        _inject(net._modules)
        return r

Why is the pruning not working on multi-GPU?

Hello,

I am using 2 V100 GPUs and trying to apply the pruning technique to a pre-trained ResNet-50 model on the ImageNet dataset.

The pruning runs, but it uses only one of the two GPUs. I am running the Python script as follows (I wrote resnet50_prune.py myself without changing any of the pruning code):
CUDA_VISIBLE_DEVICES=0,1 python ./run/resnet-50/resnet50_prune.py
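
For reference, the generic PyTorch pattern (not specific to this repo): a model only runs on multiple GPUs if it is explicitly wrapped, for example with DataParallel, before training.

    import torch
    import torch.nn as nn

    net = nn.Linear(8, 2)  # stand-in for the real network
    if torch.cuda.device_count() > 1:
        # splits each batch across the visible GPUs
        net = nn.DataParallel(net, device_ids=[0, 1]).cuda()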

Question about backward hook

Thank you for your work and for sharing this code.

I was wondering about the meaning of new_grad = torch.ones_like(grad) in the following code from universal.py.


    def _backward_hook(self, grad):
        self.out_mask += grad.data.abs().sum(2, keepdim=True).sum(3, keepdim=True).cpu().sum(0, keepdim=True).view(-1)
        new_grad = torch.ones_like(grad)
        return new_grad

How to save the pruned model?

I ran the examples:

    for point in [i for i in list(flops_save_points)]:
        if flops_ratio <= point:
            torch.save(pack.net.module.state_dict(), './logs/fashion_ticktock/%s.pth' % str(point))

But the saved model files never get any smaller.

A question about out_mask in Conv2dObserver

Thanks youzhonghui for providing such a good pruning method. In Conv2dObserver, out_mask is obtained by accumulating gradients in _backward_hook, and the subsequent channel selection keeps the channels where self.out_mask != 0. Since gradients are floats, how can the accumulated sum ever be exactly equal to 0? Thanks
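
An illustrative demo of why exact zeros can occur (my reading, not an official answer): a channel that is multiplied by a hard-zero gate receives an exactly-zero gradient, so accumulating .abs() of float gradients stays exactly 0.0 for gated-off channels.

    import torch

    x = torch.randn(4, 3, requires_grad=True)
    gate = torch.tensor([1.0, 0.0, 1.0])  # channel 1 is gated off
    (x * gate).sum().backward()
    print(x.grad.abs().sum(0))  # middle entry is exactly 0.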
