Code Monkey home page Code Monkey logo

torch-model-compression's Introduction

pytorch自动化模型压缩工具库

介绍

pytorch自动化模型压缩工具库是针对于pytorch模型的基于ONNX静态图结构分析的自动化模型压缩工具库,用户无需理解模型的结构,即可以直接对模型完成剪枝和查找替换等操作,所有的剪枝参数分析和修改以及模块结构搜索均由该工具完成。
工具库包含有两个部分,第一个部分为torchpruner模型分析与修改工具库,通过该工具库可以自动对模型完成模块修改和通道修改等常规操作,而无需关注模型的结构细节。
第二个部分为torchslim模型压缩算法库,包含了模型重参数化、剪枝、感知量化训练等多种模型压缩算法,用户仅需给出需要被压缩的模型并定义好用于训练的hook函数,即可以对模型进行自动压缩,并输出被压缩模型产物。

requirement

  • onnx>=1.6
  • onnxruntime>=1.5
  • pytorch>=1.7
  • tensorboardX>=1.8
  • scikit-learn

安装

python setup.py install

总揽

torchpruner

torchpruner为pytorch模型分析与修改工具库,包含了以下功能:
1)自动分析模型结构和剪枝
2)特定模型结构查找与替换

torchslim

torchslim内包含了模型压缩的特定算法:
1)ACNet、CnC、ACBCorner等一系列重参数化方法 2)ResRep模型剪枝方法
3)CSGD模型剪枝方法
4)QAT量化感知训练,并将pytorch模型导出为tenosrrt模型

examples

examples文件夹主要包含了多种支持的模型的测试列表support_model,torchpruner工具库的使用示例以及torchslim工具库的使用示例。
1)support_model:支持的若干种模型
2)torchpruner:使用torchpruner剪枝和模块修改的简单示例
3)torchslim:使用torchslim 在分类模型上的简单示例

torchpruner模型修改

import torch
import torchpruner
import torchvision
#加载模型
model=torchvision.models.resnet50()

#创建ONNXGraph对象,绑定需要被剪枝的模型
graph=torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(1,3,224,224),))

#获取conv1模块对应的module
conv1_module=graph.modules['self.conv1']
#剪枝分析
result=conv1_module.cut_analysis(attribute_name='weight',index=[0,1,2,3],dim=0)
#执行剪枝操作
model,context=torchpruner.set_cut(model,result)
#对卷积模块进行剪枝操作

torchslim模型压缩

import torchslim

#predict_function的第一个参数为model,第二个参数为一个batch的data,data已经被放置到了GPU上
config['task_name']='resnet56_prune'
config['epoch']=90
config['lr']=0.1
config['prune_rate']=0.5
config['save_path']="model/save/path"
config['dataset_generator']=dataset_generator
config['predict_function']=predict_function
config['evaluate_function']=evaluate_function
config['calculate_loss_function']=calculate_loss_function

model=torch.load("model/path")

#创建solver
solver=torchslim.pruning.resrep.ResRepSolver(model,config)
#执行压缩
solver.run()

使用说明

常规用法见
examples
详细使用说明见各自文件夹README.md
torchpruner
torchslim

支持模型结构

该工具理论上支持所有复杂结构模型的剪枝操作,然而由于精力有限,仅有部分的模型和结构被测试,其他模型和结构不代表不支持,但未测试。

已测试常用模型

  • AlexNet
  • VGGNet系列
  • ResNet系列
  • MobileNet系列
  • ShuffleNet系列
  • Inception系列
  • MNASNet系列
  • Unet系列
  • FCN
  • DeepLab V3
  • ResNet/Unet QAT感知量化训练模型(QDQ节点)

已测试常用结构和操作

  • Conv/Group Conv/TransposeConv/FC
  • Pooling/Upsampling
  • BatchNorm
  • Relu/Sigmoid
  • concat/transpose/view
  • 残差结构/倒置残差结构/Inception结构/Unet结构
  • quantize_per_tensor/dequantize_per_tensor

确定暂不支持模型

  • FasterRCNN/MaskRCNN

未来重点测试和支持的模型和结构

  • RNN/LSTM/GRU
  • Transformer
  • FasterRCNN

torch-model-compression's People

Contributors

gdh1995 avatar thumig avatar zizhoujia avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

torch-model-compression's Issues

TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Hello, I met a issue as described in title, the complete error message is:

Traceback (most recent call last):
File "tasks/prune_helmet.py", line 154, in
solver.run()
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchslim/slim_solver.py", line 345, in run
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchslim/slim_solver.py", line 220, in run_hook
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchslim/pruning/resrep.py", line 304, in after_iteration_hook
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchslim/pruning/resrep.py", line 195, in prune_model
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchpruner/model_pruner.py", line 71, in set_cut
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torchpruner-0.0.1-py3.7.egg/torchpruner/module_pruner/pruners.py", line 109, in set_cut
File "/home/bengui/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/nn/modules/module.py", line 801, in setattr
.format(torch.typename(value), name))
TypeError: cannot assign 'torch.cuda.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

Can you help me fix this bug? The version of torch is 1.7.0, and cuda is 10.2. Thx!

Can the QAT-quantized model calculate inference speed?

前辈您好,首先感谢您提供的剪枝(Resrep)和性能补偿(Acnet)方法,目前我已成功实践对Transformer和Conformer的剪枝,这是对未测试模型应用的补充,想请教下您,针对QAT量化后的TensorRT格式(.trt)的模型还可以测试其算力吗(例如推理速度或吞吐量)?我尝试了几种方法但都无法成功,请问您之前做过对量化后模型的算力计算吗?

resnet50剪枝报错

你好,我在使用resnet18为主干网的retinanet时,自己使用稀疏训练后的模型剪枝会报错,我的做法是:

  1. 首先将训练好的模型计算bn的阈值得到每个bn层应该要剪枝的索引,并保存到一个dict里。
  2. 然后循环1中的dict使用torchprunner去剪枝,会遇到前面的某些层如果剪了过多通道,后面层再剪时会出现索引越界。
  3. 下面是我的部分代码。
        import torchpruner 
        # 创建ONNXGraph对象,绑定需要被剪枝的模型
        self.model.eval()
        graph = torchpruner.ONNXGraph(self.model.cpu())
        ##build ONNX静态图结构,需要指定输入的张量
        graph.build_graph(inputs=(torch.zeros(1, 3, 640, 640),))
        for i, (k, v) in enumerate(mask_dict_for_pruner.items()):
        # 获取conv1模块对应的module
            conv1_module = graph.modules[k]

            # 对前四个通道进行剪枝分析,指定对weight权重进行剪枝,剪枝前四个通道
            # weight权重out_channels对应的通道维度为0
            result = conv1_module.cut_analysis(attribute_name="weight", index=v, dim=0)

            # 剪枝执行模块执行剪枝操作,对模型完成剪枝过程.context变量提供了用于剪枝恢复的上下文
            self.model, context = torchpruner.set_cut(self.model, result)
        # 新的model即为剪枝后的模型
        print(self.model)```

请问是我的用法不对吗还是说这种先计算剪枝的索引再调用torchpruner的方法不对呢

yolov5 resrep剪枝

使用resrep对yolov5l, yolov5m剪枝,网络无法训练,一开始就就不收敛直接变nan。

About parameter Settings during training

前辈您好,首先感谢您提供的方法,我在回归模型上复现您的方法时,出现了较大的性能损失,可以向您请教几个问题吗?
(1)在加载预训练好的模型后,算法在边训练边压缩,这个过程中的训练/验证性能和loss如下图所示,在正常的模型训练中属于异常情况,请问下压缩阶段属于异常情况吗?
image
(2)在您提供的样例中,压缩达到目标压缩率后在设定的epoch之前仍将继续训练,这会在压缩的基础上有性能恢复效果吗?
十分感谢您的阅读和解答,祝您在工作和生活中一切顺利!

剪枝时前面的层数正常,最后几层在对齐masks same size时报错

您好,我在剪枝自己写的模型时前面大部分层的剪枝都是正常的,但是最后某一层出现了如下问题:
Cutting layer is: self.channel_down.compactor
AAAAAAAAAAAAAA [32, 128, 1, 400] # 以下两行是我打印出来的 mask_dict[name], return_masks[name]
BBBBBBBBBBBBBB [400, 1, 128, 32]
Traceback (most recent call last):
File "/.../torch-model-compression-main/examples/torchslim/pytorch_cifar/prune.py", line 208, in
solver.run()
File "/.../torch-model-compression-main/torchslim/slim_solver.py", line 792, in run
self.run_hook(self.iteration_end_hook)
File "/.../torch-model-compression-main/torchslim/slim_solver.py", line 662, in run_hook
function(self)
File "/.../torch-model-compression-main/torchslim/pruning/resrep.py", line 838, in after_iteration_hook
self.config["min_channels"],
File "/.../torch-model-compression-main/torchslim/pruning/resrep.py", line 722, in prune_model
"conv.weight", index=min_index, dim=0
File "/.../torch-model-compression-main/torchpruner/graph.py", line 342, in cut_analysis
return current_module.terminal_node.cut_analysis(index, dim)
File "/.../torch-model-compression-main/torchpruner/graph.py", line 246, in cut_analysis
return self.cut_analysis_with_mask(mask)
File "/.../torch-model-compression-main/torchpruner/graph.py", line 292, in cut_analysis_with_mask
[mask_dict[name], return_masks[name]]
File "/.../torch-model-compression-main/torchpruner/mask_utils.py", line 70, in combine_mask
raise RuntimeError("The input mask size should be same")
RuntimeError: The input mask size should be same

Process finished with exit code 1
我屏蔽了一些name,发现该层中很多name都是报相同的错误,可以向您请教下该如何解决吗

torchslim中在cifar10上的示例代码输出为64维,而不是10维

问题如题,想请教下为什么最后一层使用conv而不是linear,改成conv会导致最终输出维度不定

self.linear=nn.Conv2d(self.base_channel*4*block.expansion,self.base_channel*4*block.expansion,1,1,0)
# self.linear = nn.Linear(self.base_channel*4*block.expansion, num_classes)

demo中剪枝后预测结果差距很大?

我在prune_by_class.py程序中测试了一下剪枝前后模型预测结果的差别,发现结果差距很大,这种现象正常吗,还是说我的理解有问题。代码如下

import sys

sys.path.append("..")
import torch
import torchpruner
import torchvision
import numpy as np

# 以下代码示例了对每一个BN层去除其weight系数绝对值前20%小的层
inputs_sample = torch.ones(1, 3, 224, 224).to('cpu')
# 加载模型
model = torchvision.models.vgg11_bn()
result_source = model(inputs_sample)
# 创建ONNXGraph对象,绑定需要被剪枝的模型
graph = torchpruner.ONNXGraph(model)
##build ONNX静态图结构,需要指定输入的张量
graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# 遍历所有的Module
for key in list(graph.modules):
    module = graph.modules[key]
    # 如果该module对应了BN层
    if isinstance(module.nn_object, torch.nn.BatchNorm2d):
        # 获取该对象
        nn_object = module.nn_object
        # 排序,取前20%小的权重值对应的index
        weight = nn_object.weight.detach().cpu().numpy()
        index = np.argsort(np.abs(weight))[: int(weight.shape[0] * 0.02)]
        result = module.cut_analysis("weight", index=index, dim=0)
        model, context = torchpruner.set_cut(model, result)
        if context:
            # graph 存放了各层参数和输出张量的 numpy.ndarray 版本,需要更新
            graph = torchpruner.ONNXGraph(model)  # 也可以不重新创建 graph
            graph.build_graph(inputs=(torch.zeros(1, 3, 224, 224),))

# 新的model即为剪枝后的模型
print(model)

result_prune = model(inputs_sample)
print(f"剪枝前结果:{result_source.sum()}")
print(f"剪枝后结果:{result_prune.sum()}")
print(f"数据差距{(abs(result_source-result_prune)).sum()}")

关于resnet和自建模型prune时遇到的相同报错

前辈您好,最近学习了您ACNet、ResRep、RepVGG几篇文章,您真的有非凡的创造力!但我使用您在examples中给出的prune.py示例对resnet56和自己模型的压缩中遇到了相同的报错,可以向您请教下吗?
(1)对resnet56压缩时的报错如下:
Traceback (most recent call last):
File "/home3/xxx/torch-model-compression-main/examples/torchslim/pytorch_cifar/prune.py", line 294, in
solver.run()
File "/home3/xxx/torch-model-compression-main/torchslim/slim_solver.py", line 790, in run
self.run_hook(self.iteration_end_hook)
File "/home3/xxx/torch-model-compression-main/torchslim/slim_solver.py", line 660, in run_hook
function(self)
File "/home3/xxx/torch-model-compression-main/torchslim/pruning/resrep.py", line 827, in after_iteration_hook
current_graph.build_graph(graph_inputs)
File "/home3/xxx/torch-model-compression-main/torchpruner/graph.py", line 622, in build_graph
operator_dict[operator].fill_shape()
File "/home3/xxx/torch-model-compression-main/torchpruner/operator/operator.py", line 191, in fill_shape
+ "'"
RuntimeError: Fail to predict the shape on operator name: 'self.bn1.BatchNormalization', type: 'BatchNormalization'
(2)对自己构建模型的报错如下:
Traceback (most recent call last):
File "/home3/xxx/.pycharm_helpers/pydev/pydevd.py", line 1434, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "/home3/xxx/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home3/gaoya/torch-model-compression-main/examples/torchslim/pytorch_cifar/prune.py", line 203, in
solver.run()
File "/home3/xxx/torch-model-compression-main/torchslim/slim_solver.py", line 790, in run
self.run_hook(self.iteration_end_hook)
File "/home3/xxx/torch-model-compression-main/torchslim/slim_solver.py", line 660, in run_hook
function(self)
File "/home3/xxx/torch-model-compression-main/torchslim/pruning/resrep.py", line 827, in after_iteration_hook
current_graph.build_graph(graph_inputs)
File "/home3/xxx/torch-model-compression-main/torchpruner/graph.py", line 622, in build_graph
operator_dict[operator].fill_shape()
File "/home3/xxx/torch-model-compression-main/torchpruner/operator/operator.py", line 191, in fill_shape
+ "'" RuntimeError: Fail to predict the shape on operator name: 'self.net_encoder.encoder_inp_lnorm.BatchNormalization', type: 'BatchNormalization'
我的理解是在构建静态图时无法正确预测BN层中节点的形状,我的系统是Linux,torch==1.13.1,torchvision==0.14.1,可以向您请教下关于这个问题吗?

examples中prune.py运行报错

运行到第六个epoch时报错如下:
Traceback (most recent call last):
File "D:\Pycharm\PyCharm 2021.2.3\plugins\python\helpers\pydev\pydevd.py", line 1483, in _exec
pydev_imports.execfile(file, globals, locals) # execute the script
File "D:\Pycharm\PyCharm 2021.2.3\plugins\python\helpers\pydev_pydev_imps_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "E:/python project/torch-model-compression-main/torch-model-compression-main/examples/torchslim/pytorch-cifar/prune.py", line 94, in
solver.run()
File "D:\Anoconda\lib\site-packages\torchpruner-0.1.0-py3.8.egg\torchslim\slim_solver.py", line 354, in run
File "D:\Anoconda\lib\site-packages\torchpruner-0.1.0-py3.8.egg\torchslim\slim_solver.py", line 224, in run_hook
File "D:\Anoconda\lib\site-packages\torchpruner-0.1.0-py3.8.egg\torchslim\pruning\resrep.py", line 297, in after_iteration_hook
File "D:\Anoconda\lib\site-packages\torchpruner-0.1.0-py3.8.egg\torchpruner\graph.py", line 622, in build_graph
File "D:\Anoconda\lib\site-packages\torchpruner-0.1.0-py3.8.egg\torchpruner\operator\operator.py", line 186, in fill_shape
"""
RuntimeError: Fail to predict the shape on operator name: 'self.bn1.BatchNormalization', type: 'BatchNormalization'
是什么原因呢?

About how to import self-built models for compression

前辈您好,我是模型压缩的初学者,您的方法相当有创造性,目前我在尝试用您的算法剪枝量化自己构建的模型,但与直接调用torchvision.models.resnet50()不同是,报错显示无法找到对应的层,具体报错如下:AttributeError: type object 'Net' has no attribute 'enc_conv1' ;RuntimeError: Can not find the enc_conv1 in model
我目前的具体操作是:
(1)from model.sesnet import Net来调用自建模型中的类(其中sesnet是构建模型的py文件,Net是模型的类)
(2)直接用model=Net来构建模型
(3)根据名称获取nn.Module对象:conv1 = tools.get_object(model, "self.enc_conv1")
debug中显示model={type}<class 'model.sesnet.Net'>,而不是像例子中获得构建好的resnet,关于这个问题可以向您请教下吗

剪枝分割网络报错-bisenetv2

File "examples/torchpruner/prune_by_class_bisenetv2.py", line 39, in <module>
    model, context = torchpruner.set_cut(model, result)
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/model_pruner.py", line 71, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 188, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/pruners.py", line 60, in set_cut
  File "site-packages/torchpruner-0.0.1-py3.8.egg/torchpruner/module_pruner/prune_function.py", line 42, in set_cut_tensor
IndexError: index 78 is out of bounds for dimension 0 with size 78

是否有相关的文章支持

前辈您好,我是模型压缩相关工作的初学者,您的工作十分有创造性,请问本代码是否有相关的paper?

关于ResRep模型性能对比

你好,最近刚好用到ResRep剪枝,我看本框架和原始ResRep论文的实现方式稍有差异, 本框架直接移除选中的卷积通道层但原论文是对选中通道施加惩罚因子使其逐渐趋向0,或者说反向传播过程中对保留和移除卷积通道层施加不同的梯度更新策略。
if isinstance(nn_object, Compactor): lasso_grad = value.data * ((value.data ** 2).sum(dim=(1, 2, 3), keepdim=True) ** (-0.5)) value.grad.data.add_(self.config["lasso_decay"], lasso_grad)
请问实际测试中有比对两种方案的性能差异么~

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.