pytorch-quantization-demo's Introduction

pytorch-quantization-demo

A simple from-scratch network quantization demo using PyTorch. This is the code for my tutorial about network quantization, written in Chinese.

Interested readers are also welcome to follow my Zhihu column: 大白话模型量化.

pytorch-quantization-demo's People

Contributors

genggng · jermmy


pytorch-quantization-demo's Issues

FileNotFoundError: [Errno 2] No such file or directory: 'ckpt/mnist_cnnbn.pt'

When I execute python quantization_aware_training.py, it outputs this error:

Traceback (most recent call last):
  File "quantization_aware_training.py", line 81, in <module>
    model.load_state_dict(torch.load('ckpt/mnist_cnnbn.pt', map_location='cpu'))
  File "/home/jyg/anaconda3/envs/mixcompre/lib/python3.8/site-packages/torch/serialization.py", line 791, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/home/jyg/anaconda3/envs/mixcompre/lib/python3.8/site-packages/torch/serialization.py", line 271, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/home/jyg/anaconda3/envs/mixcompre/lib/python3.8/site-packages/torch/serialization.py", line 252, in __init__
    super().__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'ckpt/mnist_cnnbn.pt'
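The checkpoint ckpt/mnist_cnnbn.pt is produced by train.py, so train.py has to run before quantization_aware_training.py. A minimal guard that fails with an actionable message (require_checkpoint is a made-up helper name, not part of the repo):

```python
import os

def require_checkpoint(path, hint="run `python train.py` first to create it"):
    """Return `path` if the checkpoint exists, else fail with a helpful message."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"checkpoint '{path}' not found; {hint}")
    return path

# usage before torch.load:
#   state = torch.load(require_checkpoint('ckpt/mnist_cnnbn.pt'), map_location='cpu')
```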

Large accuracy loss when quantizing a network with BN on Fashion-MNIST

Hi, when I run the Fashion-MNIST dataset, the network with BN reaches 88% in floating point but only 70% after quantization, which is a large accuracy loss. The network without BN quantizes with almost no loss. Have you run into this situation, and what might the cause be? Thanks!

q_model cannot be moved correctly between CPU and GPU

Problem

When the quantized model is moved between CPU and GPU, some member variables do not move with it.

model.cuda()
model.qconv1.M.device    # does not follow the model to the GPU

Cause

In QConv2d and QConvBNReLU, self.M is a plain (non-module) attribute, so it is not moved along with the model.

Fix

Register self.M as a buffer:

M = torch.tensor([], requires_grad=False)  # register M as a buffer
self.register_buffer('M', M)

# in freeze():
self.M.data = (self.qw.scale*self.qi.scale / self.qo.scale).data
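A self-contained sketch of the fix (QConvStub is a reduced stand-in for QConv2d, just to show the buffer mechanics): a tensor registered via register_buffer follows model.to(device) and shows up in the state dict.

```python
import torch
import torch.nn as nn

class QConvStub(nn.Module):
    """Reduced stand-in for QConv2d illustrating the register_buffer fix."""
    def __init__(self):
        super().__init__()
        # register M as a buffer instead of a plain Python attribute
        self.register_buffer('M', torch.tensor([], requires_grad=False))

    def freeze(self, s_w, s_i, s_o):
        # in freeze(): M = S_w * S_i / S_o
        self.M = torch.tensor(s_w * s_i / s_o)

m = QConvStub()
m.freeze(0.1, 0.2, 0.5)
m.to('cpu')  # on a CUDA machine, m.cuda() now moves M along with the model
print('M' in m.state_dict())  # buffers are serialized and device-tracked
```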

Question about 1-bit accuracy

In my own quantized CNN (quantizing the weights and biases on every update), I also reach around 50% accuracy on FashionMNIST under 1-bit quantization. Could there be a problem in the code?

The quantized model file is larger

Running train.py produces mnist_cnnbn.pt at 105.7 KB, while quantization_aware_training.py produces mnist_cnnbn_qat.pt at 121.6 KB. Is this normal?
Quantization is supposed to shrink the model and speed up inference.
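This is plausible behavior: the QAT checkpoint most likely still stores the weights as fp32 and additionally stores the quantization parameters (scales, zero points, M) as buffers, so it comes out slightly larger. The storage win only materializes once weights are actually held as int8, at 1 byte per element instead of 4. A quick sketch of the per-element cost:

```python
import torch

fp32_w = torch.zeros(1000)                    # default dtype: float32
int8_w = torch.zeros(1000, dtype=torch.int8)

print(fp32_w.element_size() * fp32_w.nelement())  # 4000 bytes
print(int8_w.element_size() * int8_w.nelement())  # 1000 bytes
```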

A few questions

Hello, I have two questions:
(the model and MNIST training code are used as-is, without modification)

  1. After QAT, I find that the weights, once divided by the scale, are still not very close to integers, so round() introduces a non-trivial error. Is this a limitation of QAT, or does it not matter whether the scaled weights land near integers?
>>> new_weight.shape
torch.Size([40, 1, 3, 3])
>>> new_weight[0]
tensor([[[ 0.6816, -0.6561, -0.1549],
         [ 0.5218, -1.5822,  0.5984],
         [-0.1848,  0.4262,  0.2918]]], device='cuda:0')
>>> new_weight[0]/qat['qconv1.qw.scale']
tensor([[[  54.2636,  -52.2335,  -12.3310],
         [  41.5433, -125.9577,   47.6389],
         [ -14.7157,   33.9326,   23.2337]]], device='cuda:0')
>>> (new_weight[0]/qat['qconv1.qw.scale']).round()
tensor([[[  54.,  -52.,  -12.],
         [  42., -126.,   48.],
         [ -15.,   34.,   23.]]], device='cuda:0')

2. I added the following code at the end of quantization_aware_training.py. After finishing QAT, I want to compare the results of running the model directly before freeze against quantize_inference() after freeze:

    # test_loader set batch_size = 1
    miss_cnt = 0
    result = torch.zeros(10).cuda()
    correct_freeze = 0
    correct_before_freeze = 0
    for i, (data, target) in enumerate(test_loader, 1):
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output1 = model_freeze.quantize_inference(data)[0]
            output2 = model_before_freeze(data)[0]
        result += (output1 - output2).abs()
        pred1 = output1.argmax().item()
        pred2 = output2.argmax().item()
        miss_cnt += pred1 != pred2
        correct_freeze += pred1 == target.item()
        correct_before_freeze += pred2 == target.item()

    result = result / len(test_loader.dataset)
    print(f'qat model correct {correct_freeze}\nori model correct {correct_before_freeze}')
    print(f'qat miss ori {miss_cnt}')
    print(f'diff {result}')
qat model correct 9881
ori model correct 9880
qat miss ori 7
diff tensor([0.0587, 0.0569, 0.0789, 0.0846, 0.0658, 0.0600, 0.0645, 0.0644, 0.0987, 0.0617], device='cuda:0')

I find the gap between the two outputs is not small. Is this comparison meaningful? Put differently, should I expect the quantized model to closely match the original model's final numerical outputs?
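On question 1: during QAT the forward pass uses fake-quantized weights, round(w / s) * s, and those do divide back to (near-)exact integers; if the checkpoint instead holds weights that received further gradient updates after the last fake-quant step, they will sit off the integer grid, which would explain the rounding error. A pure-Python illustration (fake_quant is a generic sketch, not the repo's exact function):

```python
def fake_quant(w, scale):
    """Symmetric fake quantization: snap w onto the integer grid defined by `scale`."""
    return round(w / scale) * scale

scale = 0.012559                 # e.g. a per-tensor weight scale
w_raw = 0.6816                   # trained weight, off the grid
w_fq = fake_quant(w_raw, scale)

print(w_raw / scale)             # not an integer (~54.27)
print(w_fq / scale)              # an integer up to float round-off
```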

how to set qo for Sigmoid layer

Hi,
first of all, thank you so much for this very comprehensive codebase.
I was wondering: when using the newly added QSigmoid layer, how do I set the qo parameter when freezing?
Thank you so much in advance.
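Not a definitive answer, but one common option: sigmoid's output always lies in (0, 1), so qo can be calibrated from the fixed range [0, 1] instead of observed statistics. A sketch using the same asymmetric formula as calcScaleZeroPoint (fixed_range_qparams is a made-up helper name):

```python
def fixed_range_qparams(rmin, rmax, num_bits=8):
    """Asymmetric quantization parameters for a known, fixed output range."""
    qmin, qmax = 0, 2 ** num_bits - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = round(qmax - rmax / scale)
    return scale, zero_point

# sigmoid output is bounded by (0, 1):
scale, zero_point = fixed_range_qparams(0.0, 1.0)
print(scale, zero_point)  # 1/255 ≈ 0.00392, 0
```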

A strange bug

At line 88 of quantization_aware_training.py, full_inference(model, test_loader): if you duplicate this line so it runs twice, the quantized accuracy drops by 10%. This step only runs inference and does not change any parameters; I debugged all morning without finding the cause. The network with BN has this problem, while the network without BN is fine. Yet after comparing parameters, the BN parameters are unchanged after the two calls.

Loading an already-quantized model

Hello, I found that when I skip the quantize_forward calibration and directly load an already-quantized model (one that was frozen and then saved), the state dicts do not match. Is there a tutorial on loading a quantized model for inference directly, without running training for calibration?

Should the original range (rmin, rmax) be extended when the zero point falls outside [qmin, qmax]?

After 8-bit quantization and dequantization, the vector [1, 2] becomes [1, 1], which is a large precision loss. Modifying the calcScaleZeroPoint function as follows works much better:

def calcScaleZeroPoint(rmin, rmax, num_bits=8):
    qmin = 0
    qmax = 2 ** num_bits - 1
    scale = float((rmax - rmin) / (qmax - qmin))

    zero_point = qmax - rmax / scale

    # when the zero point falls out of range, clamp it and recompute the scale
    if zero_point < qmin:
        zero_point = qmin
        scale = float((rmax - 0) / (qmax - qmin))
    elif zero_point > qmax:
        zero_point = qmax
        scale = float((0 - rmin) / (qmax - qmin))

    zero_point = int(zero_point)

    return scale, zero_point
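Checked on the [1, 2] example (the modified function is repeated so the snippet is self-contained; quantize/dequantize are generic helpers, not necessarily the repo's): the raw zero point 255 - 2/scale = -255 falls below qmin, so the scale is recomputed over [0, rmax], and the round trip then recovers each value to within half a quantization step (about 0.004).

```python
def calcScaleZeroPoint(rmin, rmax, num_bits=8):
    qmin = 0
    qmax = 2 ** num_bits - 1
    scale = float((rmax - rmin) / (qmax - qmin))
    zero_point = qmax - rmax / scale
    # clamp an out-of-range zero point and recompute the scale
    if zero_point < qmin:
        zero_point = qmin
        scale = float((rmax - 0) / (qmax - qmin))
    elif zero_point > qmax:
        zero_point = qmax
        scale = float((0 - rmin) / (qmax - qmin))
    return scale, int(zero_point)

def quantize(x, scale, zero_point, num_bits=8):
    q = round(x / scale + zero_point)
    return max(0, min(2 ** num_bits - 1, q))   # clamp to [qmin, qmax]

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale

scale, zp = calcScaleZeroPoint(1.0, 2.0)       # scale = 2/255, zp clamped to 0
for x in (1.0, 2.0):
    print(x, dequantize(quantize(x, scale, zp), scale, zp))
```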

Quantizing a VGG16 model

Hi, I want to quantize a VGG16 model. Which parts should I modify? Does module.py need many changes?

post_training accuracy

With post-training quantization, a model that reaches 99% accuracy when trained in fp32 drops to 88% under int8 inference, with the code completely unmodified.

QAT cannot find the model file

Hello. When I run QAT, the model file cannot be found:

Test set: Full Model Accuracy: 99%

Quantization bit: 8
Traceback (most recent call last):
  File "/Users/cheungbh/Desktop/pytorch-quantization-demo/quantization_aware_training.py", line 100, in <module>
    model.load_state_dict(torch.load(load_quant_model_file))
  File "/Users/cheungbh/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/serialization.py", line 571, in load
    with _open_file_like(f, 'rb') as opened_file:
  File "/Users/cheungbh/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/serialization.py", line 229, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/Users/cheungbh/opt/anaconda3/envs/py36/lib/python3.6/site-packages/torch/serialization.py", line 210, in __init__
    super(_open_file, self).__init__(open(name, mode))
FileNotFoundError: [Errno 2] No such file or directory: 'ckpt/mnist_cnnbn_qat.pt'

But isn't this file supposed to be saved only after the script finishes running? How should I fix this?
