qat-acs's Introduction

QAT-ACS: Efficient Quantization-aware Training with Adaptive Coreset Selection (TMLR) [Paper] [OpenReview]

This is the PyTorch implementation of the Transactions on Machine Learning Research (TMLR) paper "Efficient Quantization-aware Training with Adaptive Coreset Selection" [Paper].

Introduction

In this work, we approach the training efficiency of quantization-aware training from a new angle: coreset selection. Our method achieves 68.39% accuracy with a 4-bit quantized ResNet-18 on the ImageNet-1K dataset using only a 10% subset, an absolute gain of 4.24% over the previous SoTA.

Datasets and Models

For datasets, we have evaluated our method on ImageNet-1K and CIFAR-100. The code also includes other popular computer vision datasets: MNIST, QMNIST, FashionMNIST, SVHN, CIFAR10, and TinyImageNet. For models, we have evaluated our method on ResNet-18 and MobileNet-V2. Implementations of other network architectures are also included, such as MLP, LeNet, AlexNet, VGG, Inception-v3, WideResNet, and MobileNet-V3, but quantization-aware training is not currently implemented for these models.

Results

ResNet-18 on ImageNet-1K (4-bit W/A Quantization)

MobileNet-V2 on CIFAR-100 (2-bit W Quantization)

Run

Requirements

pip install -r requirements.txt

Dataset

Download the ImageNet (ILSVRC 2012) dataset following the official PyTorch ImageNet training code.

Pretrained Model

For ResNet-18 experiments, the official PyTorch pretrained ResNet-18 is loaded automatically. For MobileNet-V2, please use --resume ./pretrained_model/CIFAR100_Mobilenetv2_72.56.ckpt.

Getting Started

QAT of 4-bit ResNet-18 with our ACS and training on the ImageNet-1K coreset with fraction 0.1.

CUDA_VISIBLE_DEVICES=0,1 python main.py --fraction 0.1 --dataset ImageNet --data_path /datasets-to-imagenet --num_exp 1 --workers 8 --optimizer Adam -se 5 --selection ACS --adaptive cosine --model QResNet18 --bitwidth 4 --lr 1.25e-3 --batch 512 --teacher resnet101 --epochs 120 --data_update_epochs 10 --log ./logs/logs_4bit_10_cosine_acs_update10.txt
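
To get a feel for what this configuration implies, the arithmetic can be sketched as below. This is a rough estimate rather than the repository's own logic: it assumes that --data_update_epochs 10 means the coreset is re-selected every 10 epochs starting at epoch 0, and that the last partial batch is kept. The function name coreset_schedule is hypothetical. The coreset size uses the same round(n_train * fraction) rule as the CoresetMethod base class.

```python
import math

# Number of training images in ImageNet-1K (ILSVRC 2012).
IMAGENET_TRAIN_SIZE = 1_281_167


def coreset_schedule(n_train, fraction, batch_size, epochs, update_interval):
    """Hypothetical helper: estimate coreset size, steps per epoch, and re-selection epochs.

    Assumes the coreset is re-selected every `update_interval` epochs,
    starting at epoch 0, and that the final partial batch is not dropped.
    """
    coreset_size = round(n_train * fraction)  # same rule as CoresetMethod
    steps_per_epoch = math.ceil(coreset_size / batch_size)
    reselect_epochs = list(range(0, epochs, update_interval))
    return coreset_size, steps_per_epoch, reselect_epochs


size, steps, reselect = coreset_schedule(
    IMAGENET_TRAIN_SIZE, fraction=0.1, batch_size=512, epochs=120, update_interval=10
)
print(size)           # 128117 images per epoch
print(steps)          # 251 optimizer steps per epoch
print(len(reselect))  # 12 selection rounds under the stated assumption
```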

QAT of 2-bit MobileNetV2 with our ACS and training on the CIFAR-100 coreset with fraction 0.1 (repeat 5 times).

CUDA_VISIBLE_DEVICES=0 python main.py --fraction 0.1 --dataset CIFAR100 --model QMobilenetv2 --selection ACS --num_exp 5 --epochs 200 --min_lr 0 --lr 0.01 --weight_decay 5e-4 --batch-size 256 --scheduler LambdaLR --adaptive cosine --resume ./pretrained_model/CIFAR100_Mobilenetv2_72.56.ckpt --bitwidth 2 --log ./logs/lsq_2bit_mobilenetv2_cifar100_ACS10new.log

Additional Dataset and Model

Our code is mainly based on DeepCore, which is highly modular and scalable. It makes it easy to add new architectures, datasets, and selection methods, so that coreset methods can be evaluated in a richer set of scenarios and new methods can be compared more easily. Here is an example for datasets. To add a new dataset, implement a function that takes the data path as input and returns the number of channels, image size, number of classes, class names, mean, std, and the training and testing datasets, both inheriting from torch.utils.data.Dataset.

from torchvision import datasets, transforms


def MNIST(data_path):
    channel = 1
    im_size = (28, 28)
    num_classes = 10
    mean = [0.1307]
    std = [0.3081]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
    dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    class_names = [str(c) for c in range(num_classes)]
    return channel, im_size, num_classes, class_names, mean, std, dst_train, dst_test

Here is an example of implementing a network architecture.

import torch.nn as nn
import torch.nn.functional as F
from torch import set_grad_enabled
from .nets_utils import EmbeddingRecorder


class MLP(nn.Module):
    def __init__(self, channel, num_classes, im_size, record_embedding: bool = False, no_grad: bool = False,
                 pretrained: bool = False):
        if pretrained:
            raise NotImplementedError("torchvision pretrained models not available.")
        super(MLP, self).__init__()
        self.fc_1 = nn.Linear(im_size[0] * im_size[1] * channel, 128)
        self.fc_2 = nn.Linear(128, 128)
        self.fc_3 = nn.Linear(128, num_classes)

        self.embedding_recorder = EmbeddingRecorder(record_embedding)
        self.no_grad = no_grad

    def get_last_layer(self):
        return self.fc_3

    def forward(self, x):
        with set_grad_enabled(not self.no_grad):
            out = x.view(x.size(0), -1)
            out = F.relu(self.fc_1(out))
            out = F.relu(self.fc_2(out))
            out = self.embedding_recorder(out)
            out = self.fc_3(out)
        return out

To implement a new coreset method, inherit from the CoresetMethod class and return the selected indices via the select method.

class CoresetMethod(object):
    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, **kwargs):
        if fraction <= 0.0 or fraction > 1.0:
            raise ValueError("Illegal Coreset Size.")
        self.dst_train = dst_train
        self.num_classes = len(dst_train.classes)
        self.fraction = fraction
        self.random_seed = random_seed
        self.index = []
        self.args = args

        self.n_train = len(dst_train)
        self.coreset_size = round(self.n_train * fraction)

    def select(self, **kwargs):
        return
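
As an illustration of this pattern, here is a minimal sketch of a uniform random baseline. The subclass name UniformRandom and the ToyDataset stand-in are hypothetical, and returning a dict with an "indices" key is an assumption about the expected return format, so check an existing selection method in the repository before relying on it. The base class is repeated only so that the sketch runs on its own.

```python
import random


class CoresetMethod(object):
    """Copy of the base class shown above, repeated so this sketch is self-contained."""

    def __init__(self, dst_train, args, fraction=0.5, random_seed=None, **kwargs):
        if fraction <= 0.0 or fraction > 1.0:
            raise ValueError("Illegal Coreset Size.")
        self.dst_train = dst_train
        self.num_classes = len(dst_train.classes)
        self.fraction = fraction
        self.random_seed = random_seed
        self.index = []
        self.args = args

        self.n_train = len(dst_train)
        self.coreset_size = round(self.n_train * fraction)

    def select(self, **kwargs):
        return


class UniformRandom(CoresetMethod):
    """Hypothetical baseline: pick coreset_size indices uniformly without replacement."""

    def select(self, **kwargs):
        rng = random.Random(self.random_seed)  # seeded for reproducibility
        self.index = sorted(rng.sample(range(self.n_train), self.coreset_size))
        # Assumed return format: a dict holding the selected indices.
        return {"indices": self.index}


class ToyDataset(list):
    """Stand-in for a torch Dataset: only len() and .classes are used here."""
    classes = ["cat", "dog"]


dst_train = ToyDataset(range(1000))
selector = UniformRandom(dst_train, args=None, fraction=0.1, random_seed=0)
result = selector.select()
print(len(result["indices"]))  # 100
```

A real selection method would replace the random sampling with a score-based ranking of the training samples, while keeping the same constructor arguments and select interface.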

Acknowledgement and Citation

The code is mainly based on DeepCore. If you find our code helpful for your research, please cite:

@article{huang2023efficient,
    title={Efficient Quantization-aware Training with Adaptive Coreset Selection},
    author={Huang, Xijie and Liu, Zechun and Liu, Shih-yang and Cheng, Kwang-Ting},
    year={2023},
    archivePrefix={arXiv},
}

If you have any questions, feel free to contact Xijie Huang ([email protected])

qat-acs's Issues

Question about coreset selection

Thank you for sharing your excellent work! I have a question about coreset selection. I noticed that in Algorithm 1, all the samples are re-sorted according to dACS and then reconstituted in the subset. It appears that the coreset selection is dynamic, akin to dropping out some unimportant samples during the training phase (the dropped ones can be reselected). However, some of the comparison methods are static (the dropping is permanent). Is the comparison reasonable?

I'm looking forward to your reply!
