REFILLED

This is the code for the CVPR 2020 oral paper "Distilling Cross-Task Knowledge via Relationship Matching". If you use any content of this repo in your work, please cite the following BibTeX entry:

@inproceedings{ye2020refilled,
  author    = {Han-Jia Ye and
               Su Lu and
               De-Chuan Zhan},
  title     = {Cross-Task Knowledge Distillation via Relationship Matching},
  booktitle = {Computer Vision and Pattern Recognition (CVPR)},
  year      = {2020}
}

Cross-Task Knowledge Distillation

It is intuitive to take advantage of the learning experience of related pre-trained models to facilitate model training on the current task. Unlike fine-tuning or parameter regularization, knowledge distillation (or knowledge reuse) extracts dark knowledge or privileged information from a fixed, strong model (a.k.a. the "teacher") and enriches the training of the target model (a.k.a. the "student") with additional supervision signals. However, because a classifier is tightly bound to the classes it was trained on, it is difficult to reuse the classification knowledge of a teacher model trained on a different task.

Two-Stage Solution - REFILLED

We propose RElationship FacIlitated Local cLassifiEr Distillation (REFILLED), which decomposes the knowledge distillation flow into two parts: one for the embedding and one for the top-layer classifier. Accordingly, REFILLED contains two stages. In the first stage, the discriminative ability of the features is emphasized: for hard triplets determined by the student's embedding, the teacher's comparison between them is used as soft supervision. In this way, the teacher enhances the student's discriminative embedding by specifying, for each object, the degree to which a dissimilar impostor should be farther away from it than a target nearest neighbor. In the second stage, the teacher constructs soft supervision for each instance by measuring its similarity to a local class center. By matching these "instance-label" predictions across models, the cross-task teacher improves the learning efficacy of the student.
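The stage-1 idea can be illustrated with a minimal pure-Python sketch. It assumes that the teacher's comparison of a triplet is a two-way softmax over negative squared distances with temperature `tau1` (a stochastic-triplet-embedding-style probability), that hard triplets pair each anchor with its nearest same-class neighbor and nearest different-class impostor in the student's embedding, and that the student matches the teacher through a binary KL divergence. All function names are hypothetical; this is an illustration, not the repository's implementation.

```python
import math


def sq_dist(u, v):
    # Squared Euclidean distance between two equal-length vectors.
    return sum((a - b) ** 2 for a, b in zip(u, v))


def triplet_prob(anchor, neighbor, impostor, tau):
    # Probability that `neighbor` is closer to `anchor` than `impostor`:
    # a two-way softmax over negative squared distances with temperature tau.
    s_pos = -sq_dist(anchor, neighbor) / tau
    s_neg = -sq_dist(anchor, impostor) / tau
    m = max(s_pos, s_neg)  # shift by the max for numerical stability
    e_pos, e_neg = math.exp(s_pos - m), math.exp(s_neg - m)
    return e_pos / (e_pos + e_neg)


def hard_triplets(emb, labels):
    # For each anchor, choose the nearest same-class neighbor and the nearest
    # different-class impostor, measured in the given (student) embedding.
    n = len(emb)
    triplets = []
    for i in range(n):
        pos = [j for j in range(n) if j != i and labels[j] == labels[i]]
        neg = [j for j in range(n) if labels[j] != labels[i]]
        if pos and neg:
            j = min(pos, key=lambda t: sq_dist(emb[i], emb[t]))
            k = min(neg, key=lambda t: sq_dist(emb[i], emb[t]))
            triplets.append((i, j, k))
    return triplets


def stage1_loss(student_emb, teacher_emb, labels, tau1):
    # Average binary KL(teacher || student) between triplet probabilities,
    # over the hard triplets selected by the student's embedding.
    eps = 1e-12
    trips = hard_triplets(student_emb, labels)
    total = 0.0
    for i, j, k in trips:
        p_t = triplet_prob(teacher_emb[i], teacher_emb[j], teacher_emb[k], tau1)
        p_s = triplet_prob(student_emb[i], student_emb[j], student_emb[k], tau1)
        total += p_t * math.log((p_t + eps) / (p_s + eps))
        total += (1 - p_t) * math.log((1 - p_t + eps) / (1 - p_s + eps))
    return total / max(len(trips), 1)
```

When the student's embedding already reproduces the teacher's triplet comparisons, the loss is zero; it grows when the student ranks a hard impostor differently from the teacher.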

Important Improvements to REFILLED

We further improve the proposed method by extending the dimension of the matched tuple probabilities in stage 1 and by replacing local class centers with global class centers in stage 2.
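A similar sketch for stage 2, again in plain Python with hypothetical names: the teacher embeds target-task instances, class centers are averaged per target class (over a mini-batch for local centers, or over the whole training set for the global centers described above), and each instance's soft label is a softmax over negative squared distances to the centers with temperature `tau2`. The student is assumed to minimize cross-entropy plus `lambd` times a KL term against these soft labels; the distance measure and the exact form of the loss are assumptions, not the repository's code.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [e / z for e in exps]


def class_centers(emb, labels, n_classes):
    # Mean teacher embedding of each target class. Pass a mini-batch for
    # "local" centers or the whole training set for "global" centers.
    # Assumes every class in range(n_classes) occurs at least once.
    dim = len(emb[0])
    sums = [[0.0] * dim for _ in range(n_classes)]
    counts = [0] * n_classes
    for x, y in zip(emb, labels):
        counts[y] += 1
        for d in range(dim):
            sums[y][d] += x[d]
    return [[s / counts[c] for s in sums[c]] for c in range(n_classes)]


def center_soft_labels(emb, centers, tau2):
    # Teacher soft label per instance: softmax over classes of the negative
    # squared distance to each class center, scaled by temperature tau2.
    return [
        softmax([-sum((a - b) ** 2 for a, b in zip(x, c)) / tau2 for c in centers])
        for x in emb
    ]


def stage2_loss(student_logits, labels, soft_labels, tau2, lambd):
    # Cross-entropy on the true labels plus lambd times
    # KL(teacher soft label || tempered student prediction), batch-averaged.
    eps = 1e-12
    total = 0.0
    for z, y, q in zip(student_logits, labels, soft_labels):
        ce = -math.log(softmax(z)[y] + eps)
        p_tau = softmax([v / tau2 for v in z])
        kl = sum(qk * math.log((qk + eps) / (pk + eps)) for qk, pk in zip(q, p_tau))
        total += ce + lambd * kl
    return total / len(labels)
```

In this sketch the centers are indexed by target-task classes, so the teacher's own classifier is never used; only its embedding is, which is what allows a teacher from a different task to provide soft labels.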

Experiment Results

REFILLED can be used in several applications, e.g., standard knowledge distillation, cross-task knowledge distillation, and middle-shot learning. Since standard knowledge distillation is widely used, we show results under this setting below. Results for cross-task knowledge distillation and middle-shot learning can be found in the paper.

CIFAR-100 with wide_resnet

WRN (depth, width)               (40,2)  (16,2)  (40,1)  (16,1)
Teacher                          76.04   -       -       -
Student                          76.04   70.15   71.53   66.30
Paper Results                    77.49   74.01   72.72   67.56
REFILLED after stage 1 (paper)   55.47   50.14   45.04   38.06
REFILLED after stage 1 (new)     62.12   53.86   52.71   44.33

Results after stage 1 are accuracies of the nearest class mean (NCM) classifier, not the NMI of clustering.
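The NCM evaluation used for the stage-1 rows can be sketched as follows. Squared Euclidean distance is an assumption here (the repository may use a different distance or normalized features), and `ncm_accuracy` is a hypothetical helper.

```python
def ncm_accuracy(train_emb, train_labels, test_emb, test_labels):
    # Nearest class mean (NCM) classifier: average the training embeddings of
    # each class, assign each test embedding to the closest class mean
    # (squared Euclidean distance), and return the fraction classified correctly.
    sums, counts = {}, {}
    for x, y in zip(train_emb, train_labels):
        if y not in sums:
            sums[y], counts[y] = [0.0] * len(x), 0
        counts[y] += 1
        sums[y] = [s + v for s, v in zip(sums[y], x)]
    means = {y: [s / counts[y] for s in sums[y]] for y in sums}

    def predict(x):
        return min(means, key=lambda y: sum((a - b) ** 2 for a, b in zip(x, means[y])))

    correct = sum(predict(x) == y for x, y in zip(test_emb, test_labels))
    return correct / len(test_labels)
```

Because NCM has no trained classifier layer, it measures how well the stage-1 embedding alone separates the target classes.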

CUB-200 with mobile_net

width multiplier                 1.00    0.75    0.50    0.25
Teacher                          76.19   -       -       -
Student                          76.19   74.49   72.68   68.80
Paper Results                    78.95   78.01   76.11   73.42
REFILLED after stage 1 (paper)   36.56   33.00   29.60   19.10
REFILLED after stage 1 (new)     38.47   36.95   33.71   25.34

Results after stage 1 are accuracies of the nearest class mean (NCM) classifier, not the NMI of clustering.

Code and Arguments

This code implements REFILLED in the setting where a source task and a target task are given. main.py is the main file, and the arguments it takes are listed below.

Task Arguments

  • data_name: name of dataset
  • teacher_network_name: architecture of teacher model
  • student_network_name: architecture of student model

Experiment Environment Arguments

  • devices: list of GPU IDs
  • flag_gpu: whether to use a GPU
  • flag_no_bar: whether to use a progress bar
  • n_workers: number of workers in the data loader
  • flag_tuning: whether to tune hyperparameters on the validation set or to train on the whole training set

Optimizer Arguments

  • lr1: initial learning rate in stage 1
  • lr2: initial learning rate in stage 2
  • point: when to decrease the learning rate
  • gamma: the extent of learning rate decrease
  • wd: weight decay
  • mo: momentum
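The roles of `point` and `gamma` can be illustrated with a step-decay sketch. It assumes that `point` holds a list of milestone epochs and that `gamma` is a multiplicative factor; `step_lr` is hypothetical, and the values in the loop are illustrative, not the repository's defaults. PyTorch's `torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma)` implements the same kind of schedule.

```python
def step_lr(base_lr, epoch, points, gamma):
    # Multiply base_lr by gamma once for every milestone epoch already reached.
    n_passed = sum(1 for p in points if epoch >= p)
    return base_lr * gamma ** n_passed


# Illustrative values only, not the repository's defaults.
for epoch in (0, 59, 60, 120):
    print(epoch, step_lr(0.1, epoch, [60, 120], 0.2))
```

The same function would serve both stages, with `lr1` or `lr2` passed as `base_lr`.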

Network Arguments

  • depth: depth of resnet and wide_resnet
  • width: width of wide_resnet
  • ca: channel coefficient of mobile_net
  • dropout_rate: dropout rate of the network

Training Procedure Arguments

  • n_training_epochs1: number of training epochs in stage 1
  • n_training_epochs2: number of training epochs in stage 2
  • batch_size: batch size in training
  • tau1: temperature for stochastic triplet embedding in stage 1
  • tau2: temperature for local distillation in stage 2
  • lambd: weight of teaching loss in stage 2
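One way the arguments above could be declared with Python's `argparse` is sketched below. The `--name` flag spelling, the types, the `nargs` choices, and treating the `flag_*` options as `store_true` switches are all assumptions; no defaults from the repository are reproduced, and the example values (including the dataset string) are hypothetical.

```python
import argparse


def build_parser():
    # Hypothetical declaration of the listed arguments; types and switches
    # are assumptions, and unspecified options default to None/False.
    p = argparse.ArgumentParser(description="REFILLED (sketch)")
    # Task arguments
    p.add_argument("--data_name", type=str)
    p.add_argument("--teacher_network_name", type=str)
    p.add_argument("--student_network_name", type=str)
    # Experiment environment arguments
    p.add_argument("--devices", type=int, nargs="+")
    p.add_argument("--flag_gpu", action="store_true")
    p.add_argument("--flag_no_bar", action="store_true")
    p.add_argument("--n_workers", type=int)
    p.add_argument("--flag_tuning", action="store_true")
    # Optimizer arguments
    for name in ("lr1", "lr2", "gamma", "wd", "mo"):
        p.add_argument("--" + name, type=float)
    p.add_argument("--point", type=int, nargs="+")
    # Network arguments
    p.add_argument("--depth", type=int)
    p.add_argument("--width", type=int)
    p.add_argument("--ca", type=float)
    p.add_argument("--dropout_rate", type=float)
    # Training procedure arguments
    for name in ("n_training_epochs1", "n_training_epochs2", "batch_size"):
        p.add_argument("--" + name, type=int)
    for name in ("tau1", "tau2", "lambd"):
        p.add_argument("--" + name, type=float)
    return p


# Example with hypothetical values.
example = build_parser().parse_args(
    ["--data_name", "CIFAR-100", "--depth", "16", "--width", "2",
     "--point", "60", "120", "--flag_gpu"]
)
```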
