🚀 FKD: A Fast Knowledge Distillation Framework for Visual Recognition

Official PyTorch implementation of the paper "A Fast Knowledge Distillation Framework for Visual Recognition" (ECCV 2022; ECCV paper, arXiv) by Zhiqiang Shen and Eric Xing.

Abstract

Knowledge Distillation (KD) has been recognized as a useful tool in many visual tasks, such as supervised classification and self-supervised representation learning. The main drawback of a vanilla KD framework, however, is that most of its computational overhead is spent on forward passes through the giant teacher network, which makes the whole learning procedure inefficient and costly.

🚀 Fast Knowledge Distillation (FKD) is a novel framework that addresses this inefficiency. It simulates the distillation training phase and generates soft labels following the multi-crop KD procedure, while training faster than other methods. FKD is even more efficient than the conventional classification framework when multiple crops are taken from the same image during data loading. It achieves 80.1% (SGD) and 80.5% (AdamW) top-1 accuracy with ResNet-50 on ImageNet-1K under plain training settings. This work also demonstrates the efficiency advantage of FKD on the self-supervised learning task.

Citation

@article{shen2021afast,
      title={A Fast Knowledge Distillation Framework for Visual Recognition}, 
      author={Zhiqiang Shen and Eric Xing},
      year={2021},
      journal={arXiv preprint arXiv:2112.01528}
}

What's New

Supervised Training

Preparation

FKD Training on CNNs

To train a model, run train_FKD.py with the desired model architecture and the paths to the soft labels and the ImageNet dataset:

python train_FKD.py -a resnet50 --lr 0.1 --num_crops 4 -b 1024 --cos --softlabel_path [soft label path] [imagenet-folder with train and val folders]

For --softlabel_path, use a path of the form ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet.
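The directory name in this example path suggests how the soft labels are stored: 500 crops per image, with the teacher's top-5 predictions kept under "marginal smoothing". As described in the paper, marginal smoothing with top-K keeps the K largest teacher probabilities and spreads the leftover probability mass evenly over the remaining classes. The sketch below illustrates that idea in plain Python. The function names and the input layout are hypothetical; this is not the repository's loader or file format.

```python
import math


def marginal_smoothing_topk(topk_indices, topk_probs, num_classes):
    """Expand stored top-k teacher probabilities into a dense soft label.

    Hypothetical helper: the k kept classes retain their probabilities and
    the remaining mass is shared evenly by all other classes.
    """
    k = len(topk_indices)
    if k != len(topk_probs):
        raise ValueError("indices and probabilities must have the same length")
    if not 0 < k < num_classes:
        raise ValueError("k must satisfy 0 < k < num_classes")
    residual = max(0.0, 1.0 - sum(topk_probs))
    fill = residual / (num_classes - k)
    dense = [fill] * num_classes
    for idx, prob in zip(topk_indices, topk_probs):
        dense[idx] = prob
    return dense


def soft_cross_entropy(logits, soft_label):
    """Cross-entropy between a dense soft label and softmax(logits)."""
    peak = max(logits)  # subtract the max for numerical stability
    log_norm = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return -sum(q * (z - log_norm) for q, z in zip(soft_label, logits))


if __name__ == "__main__":
    # Toy example: 5 classes, top-2 predictions stored for one crop.
    label = marginal_smoothing_topk([2, 0], [0.6, 0.2], num_classes=5)
    print(label)
    print(soft_cross_entropy([1.0, 0.0, 2.0, 0.0, 0.0], label))
```

Storing only the top-K entries per crop keeps the label files small, and the student is then trained against the reconstructed dense distribution rather than a one-hot target.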

Multi-processing distributed training on a single node with multiple GPUs:

python train_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a resnet50 --lr 0.1 --num_crops 4 -b 1024 --cos -j 32 \
--save_checkpoint_path ./FKD_nc_4_res50_plain \
--softlabel_path [soft label path, e.g., ./FKD_soft_label_500_crops_marginal_smoothing_k_5/imagenet] \
[imagenet-folder with train and val folders]

For multi-node, multi-process distributed training, please refer to the official PyTorch ImageNet training code for details.
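In the official PyTorch ImageNet example, when --multiprocessing-distributed is set, --world-size and --rank count nodes rather than processes: one process is spawned per GPU, and the global world size, per-process rank, per-GPU batch size, and per-GPU worker count are derived from them. The sketch below reproduces that arithmetic. Whether train_FKD.py follows the same logic exactly is an assumption, and plan_processes is a hypothetical helper, not part of either code base.

```python
def plan_processes(num_nodes, node_rank, gpus_per_node, batch_size, workers):
    """Derive per-process settings from node-level arguments.

    Mirrors the arithmetic of the official PyTorch ImageNet example
    (assumed, not verified, to match train_FKD.py).
    """
    if not 0 <= node_rank < num_nodes:
        raise ValueError("node_rank must be in [0, num_nodes)")
    if gpus_per_node <= 0:
        raise ValueError("gpus_per_node must be positive")
    return {
        # Total number of processes across all nodes.
        "world_size": num_nodes * gpus_per_node,
        # Global rank of each process launched on this node.
        "ranks": [node_rank * gpus_per_node + gpu for gpu in range(gpus_per_node)],
        # The -b value is the total per node; each GPU gets an equal share.
        "per_gpu_batch": batch_size // gpus_per_node,
        # The -j value is split across GPUs, rounded up.
        "per_gpu_workers": (workers + gpus_per_node - 1) // gpus_per_node,
    }


if __name__ == "__main__":
    # Single node with 8 GPUs, -b 1024 -j 32 as in the command above.
    print(plan_processes(1, 0, 8, 1024, 32))
    # Second of two nodes: same flags except --world-size 2 --rank 1.
    print(plan_processes(2, 1, 8, 1024, 32))
```

Under these assumptions, a two-node run would change only --world-size, --rank, and --dist-url (pointing at the first node) between machines.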

Evaluation

python train_FKD.py -a resnet50 -e --resume [model path] [imagenet-folder with train and val folders]

Training Speed Comparison

The per-epoch training speed was measured on the HPC/CIAI cluster at MBZUAI with 8 NVIDIA V100 GPUs. The batch size is 1024 for all three methods: (i) the regular/vanilla classification framework, (ii) ReLabel, and (iii) FKD. For Vanilla and ReLabel, we report the average over 10 epochs after the speed has stabilized. For FKD, we use num_crops = 4 and average over (4 $\times$ 10) epochs; note that num_crops = 8 gives even faster training. All other settings are identical across the comparison.

| Method | Network | Training time per epoch |
|---|---|---|
| Vanilla | ResNet-50 | 579.36 sec/epoch |
| ReLabel | ResNet-50 | 762.11 sec/epoch |
| FKD (Ours) | ResNet-50 | 486.77 sec/epoch |
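To put the timings above in relative terms, the short script below computes the speedup of FKD over the other two methods directly from the reported numbers. It is only arithmetic on the table values; the variable and function names are illustrative.

```python
# Per-epoch training time in seconds, copied from the table above.
EPOCH_SECONDS = {"Vanilla": 579.36, "ReLabel": 762.11, "FKD": 486.77}


def speedup(baseline, method, times=EPOCH_SECONDS):
    """How many times faster `method` is than `baseline` per epoch."""
    return times[baseline] / times[method]


def time_saved(baseline, method, times=EPOCH_SECONDS):
    """Fraction of the baseline's per-epoch time that `method` saves."""
    return 1.0 - times[method] / times[baseline]


if __name__ == "__main__":
    for baseline in ("Vanilla", "ReLabel"):
        print(
            f"FKD vs {baseline}: {speedup(baseline, 'FKD'):.2f}x faster, "
            f"{time_saved(baseline, 'FKD'):.1%} less time per epoch"
        )
```

On these numbers FKD is roughly 1.19x faster per epoch than the vanilla framework and roughly 1.57x faster than ReLabel.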

Trained Models

| Method | Network | Accuracy (Top-1) | Weights | Configurations |
|---|---|---|---|---|
| ReLabel | ResNet-50 | 78.9 | -- | -- |
| FKD | ResNet-50 | 80.1 (+1.2%) | link | same as ReLabel, with initial lr = $0.1 \times \frac{\text{batch size}}{512}$ |
| FKD (Plain) | ResNet-50 | 79.8 | link | Table 12 in paper (w/o warmup & colorJ) |
| FKD (AdamW) | ResNet-50 | 80.5 | link | Table 13 in paper (same as our settings on ViT and SReT) |
| ReLabel | ResNet-101 | 80.7 | -- | -- |
| FKD | ResNet-101 | 81.9 (+1.2%) | link | Table 12 in paper |
| FKD (Plain) | ResNet-101 | 81.7 | link | Table 12 in paper (w/o warmup & colorJ) |
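The first FKD configuration above uses an initial learning rate that scales linearly with the batch size, lr = 0.1 × batch size / 512. A small worked example of that formula follows. The function name is hypothetical, and whether train_FKD.py applies this scaling itself or expects the scaled value to be passed via --lr is not stated here.

```python
def scaled_initial_lr(batch_size, base_lr=0.1, base_batch=512):
    """Linear learning-rate scaling: lr = base_lr * batch_size / base_batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    return base_lr * batch_size / base_batch


if __name__ == "__main__":
    for bs in (256, 512, 1024, 2048):
        print(f"batch size {bs:>4}: initial lr = {scaled_initial_lr(bs):.3f}")
```

For a batch size of 1024 the formula gives an initial learning rate of 0.2, and for 512 it gives the base value of 0.1.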

Mobile-level Efficient Networks

| Method | Network | FLOPs | Accuracy (Top-1) | Weights |
|---|---|---|---|---|
| FBNet | FBNet-c100 | 375M | 75.12% | -- |
| FKD | FBNet-c100 | 375M | 77.13% (+2.01%) | link |
| EfficientNetv2 | EfficientNetv2-B0 | 700M | 78.35% | -- |
| FKD | EfficientNetv2-B0 | 700M | 79.94% (+1.59%) | link |

The training protocol is the same as we used for ViT/SReT:

# Use the same settings as on ViT and SReT
cd train_ViT
# Train the model
python -u train_ViT_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a tf_efficientnetv2_b0 \
--lr 0.002 --wd 0.05 \
--epochs 300 --cos -j 32 \
--num_classes 1000 \
-b 1024 --num_crops 4 \
--save_checkpoint_path ./FKD_nc_4_224_efficientnetv2_b0 \
--soft_label_type marginal_smoothing_k5  \
--softlabel_path [soft label path] \
[imagenet-folder with train and val folders]

FKD Training on ViT/DeiT and SReT

To train a ViT model, run train_ViT_FKD.py with the desired model architecture and the paths to the soft labels and the ImageNet dataset:

cd train_ViT
python train_ViT_FKD.py \
--dist-url 'tcp://127.0.0.1:10001' \
--dist-backend 'nccl' \
--multiprocessing-distributed --world-size 1 --rank 0 \
-a SReT_LT --lr 0.002 --wd 0.05 --num_crops 4 -b 1024 --cos \
--softlabel_path [soft label path] \
[imagenet-folder with train and val folders]

For instructions on the SReT_LT model, please refer to SReT.

Evaluation

python train_ViT_FKD.py -a SReT_LT -e --resume [model path] [imagenet-folder with train and val folders]

Trained Models

| Model | FLOPs | #Params | Accuracy (Top-1) | Weights | Configurations |
|---|---|---|---|---|---|
| DeiT-T-distill | 1.3B | 5.7M | 74.5 | -- | -- |
| FKD ViT/DeiT-T | 1.3B | 5.7M | 75.2 | link | Table 13 in paper |
| SReT-LT-distill | 1.2B | 5.0M | 77.7 | -- | -- |
| FKD SReT-LT | 1.2B | 5.0M | 78.7 | link | Table 13 in paper |

Fast MEAL V2

Please see MEAL V2 for the instructions to run FKD with MEAL V2.

Self-supervised Representation Learning Using FKD

Please see FKD-SSL for instructions on running FKD for the SSL task.

Contact

Zhiqiang Shen (zhiqiangshen0214 at gmail.com or zhiqians at andrew.cmu.edu)
