fsl-dapna's Introduction

Jiechao Guan, Zhiwu Lu, Tao Xiang, Ji-Rong Wen

Abstract

To recognize the unseen classes with only a few samples, few-shot learning (FSL) uses prior knowledge learned from the seen classes. A major challenge for FSL is that the distribution of the unseen classes is different from that of those seen, resulting in poor generalization even when a model is meta-trained on the seen classes. This class-difference-caused distribution shift can be considered as a special case of domain shift. In this paper, for the first time, we propose a domain adaptation prototypical network with attention (DAPNA) to explicitly tackle such a domain shift problem in a meta-learning framework. Specifically, armed with a set transformer based attention module, we construct each episode with two sub-episodes without class overlap on the seen classes to simulate the domain shift between the seen and unseen classes. To align the feature distributions of the two sub-episodes with limited training samples, a feature transfer network is employed together with a margin disparity discrepancy (MDD) loss. Importantly, theoretical analysis is provided to give the learning bound of our DAPNA. Extensive experiments show that our DAPNA outperforms the state-of-the-art FSL alternatives, often by significant margins.
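
The episode construction described in the abstract (two sub-episodes drawn from the seen classes with no class overlap) can be illustrated with a minimal sketch. This is not the sampler used in this repo: the function name `sample_two_subepisodes`, the default values of `n_way`, `n_shot` and `n_query`, and the returned dictionary layout are all assumptions made for illustration.

```python
import random


def sample_two_subepisodes(labels, n_way=5, n_shot=1, n_query=15, seed=None):
    """Sample two sub-episodes whose class sets do not overlap.

    labels: list of integer class labels, one per training sample.
    Returns two dicts with keys 'classes', 'support' and 'query';
    'support' and 'query' hold sample indices into `labels`.
    (Hypothetical helper, not part of the original code base.)
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = {}
    for idx, c in enumerate(labels):
        by_class.setdefault(c, []).append(idx)

    # Only classes with enough samples for support + query are usable.
    per_class = n_shot + n_query
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= per_class]
    if len(eligible) < 2 * n_way:
        raise ValueError("not enough classes for two disjoint sub-episodes")

    # Draw 2 * n_way distinct classes and split them into two halves,
    # so the two sub-episodes share no class.
    chosen = rng.sample(eligible, 2 * n_way)
    halves = (chosen[:n_way], chosen[n_way:])

    episodes = []
    for classes in halves:
        support, query = [], []
        for c in classes:
            picked = rng.sample(by_class[c], per_class)
            support.extend(picked[:n_shot])
            query.extend(picked[n_shot:])
        episodes.append({"classes": classes, "support": support, "query": query})
    return episodes[0], episodes[1]
```

In this sketch one sub-episode would play the role of the source domain and the other the target domain; how the features of the two are then aligned (the feature transfer network and the MDD loss) is not shown here.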

Citation

If you find this work useful, please consider citing it using the following BibTeX entry:

@misc{guan2020fewshot,
  title={Few-Shot Learning as Domain Adaptation: Algorithm and Analysis},
  author={Jiechao Guan and Zhiwu Lu and Tao Xiang and Ji-Rong Wen},
  year={2020},
  eprint={2002.02050},
  archivePrefix={arXiv},
  primaryClass={cs.LG}
}

Environment

  • Python 3.7
  • PyTorch 1.3.1

Get Started

Data Preparation

  1. The folder '\data' should contain the raw images of the three FSL datasets (miniImageNet, tieredImageNet, and CUB). We downloaded the original ImageNet images to construct the miniImageNet and tieredImageNet datasets based on their splitting strategies. For CUB, we downloaded the dataset and use the bounding-box images. You can download the miniImageNet (Enter Code: 4u9g, 2.86 GB) and CUB (Enter Code: kscu, 2.23 GB) datasets processed by us. The tieredImageNet zip exceeds 88 GB, so we are unable to upload it to Baidu Disk. Email me ([email protected]) if you want it.

  2. The folder '\saves' should contain the pretrained WRN-28-10 models for the three FSL datasets. You can pretrain a new model by following the instructions below, or use our pretrained models: miniImageNet (Enter Code: 2p7s), tieredImageNet (Enter Code: zigs), CUB (Enter Code: 1h2y).
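
Before training, it may help to confirm that the two folders above are in place. The following is a minimal sketch, assuming the folders are named `data` and `saves` and sit in the repository root. The helper `check_layout` is hypothetical, and the sketch makes no assumption about the sub-folder or file names inside each folder.

```python
from pathlib import Path


def check_layout(root=".", required=("data", "saves")):
    """Report whether each required folder exists and is non-empty.

    Returns a dict mapping folder name to 'ok', 'empty' or 'missing'.
    (Hypothetical helper; the folder names are taken from the README,
    everything else is an assumption.)
    """
    status = {}
    for name in required:
        folder = Path(root) / name
        if not folder.is_dir():
            status[name] = "missing"
        elif not any(folder.iterdir()):
            status[name] = "empty"
        else:
            status[name] = "ok"
    return status
```

Calling `check_layout()` from the repository root returns one status per folder, which can be printed or asserted on before launching a long training run.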

Model Training and Test

--Standard FSL Setting
1. Pre-train and save a model.
  -- python pretrain.py
2. Train a DAPNA model.
  -- sh train_proto_mdd.sh
3. Evaluate DAPNA's performance.
  -- python eval.py

--Cross-Domain FSL Setting
1. Train a DAPNA model.
  -- sh cross_domain_train_proto_mdd.sh
2. Evaluate DAPNA's performance.
  -- python cross_domain_eval.py
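
The steps above can also be chained from a single script. The sketch below is only a convenience wrapper around the commands listed in this section. The names `PIPELINES` and `run_pipeline` are hypothetical, and no command-line arguments are passed because none are documented here. It assumes it is run from the repository root.

```python
import subprocess

# Commands exactly as listed above, grouped by setting.
PIPELINES = {
    "standard": [
        ["python", "pretrain.py"],
        ["sh", "train_proto_mdd.sh"],
        ["python", "eval.py"],
    ],
    "cross_domain": [
        ["sh", "cross_domain_train_proto_mdd.sh"],
        ["python", "cross_domain_eval.py"],
    ],
}


def run_pipeline(setting, dry_run=True):
    """Run (or, with dry_run=True, only print) the commands for one setting.

    Stops at the first failing command because check=True raises
    subprocess.CalledProcessError on a non-zero exit code.
    """
    commands = PIPELINES[setting]
    for cmd in commands:
        if dry_run:
            print("would run:", " ".join(cmd))
        else:
            subprocess.run(cmd, check=True)
    return commands
```

With the default `dry_run=True` nothing is executed, so `run_pipeline("standard")` only prints the three commands in order. Pass `dry_run=False` to actually launch them.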

Reference

We thank the following repos for providing helpful components/functions used in our work.

  1. A Closer Look at Few-shot Classification https://github.com/wyharveychen/CloserLookFewShot
  2. Learning Embedding Adaptation for Few-Shot Learning https://github.com/Sha-Lab/FEAT
  3. Bridging Theory and Algorithm for Domain Adaptation https://github.com/thuml/MDD
