mixvit's Introduction

Mix-ViT: Mixing Attentive Vision Transformer for Ultra-Fine-Grained Visual Categorization

Official PyTorch implementation of "Mix-ViT: Mixing Attentive Vision Transformer for Ultra-Fine-Grained Visual Categorization", accepted by Pattern Recognition.

If you use the code in this repo for your work, please cite the following BibTeX entry:

@article{yu2023mix,
  title={Mix-ViT: Mixing attentive vision transformer for ultra-fine-grained visual categorization},
  author={Yu, Xiaohan and Wang, Jun and Zhao, Yang and Gao, Yongsheng},
  journal={Pattern Recognition},
  volume={135},
  pages={109131},
  year={2023},
  publisher={Elsevier}
}

Abstract

Ultra-fine-grained visual categorization (ultra-FGVC) moves down the taxonomy level to classify sub-granularity categories of fine-grained objects. This inevitably poses a challenge, i.e., classifying highly similar objects with limited samples, which impedes the performance of recent advanced vision transformer methods. To that end, this paper introduces Mix-ViT, a novel mixing attentive vision transformer that addresses this challenge for improved ultra-FGVC. The core design is a self-supervised module that mixes the high-level sample tokens and, after attentively substituting tokens, learns to predict whether a token has been substituted. This drives the model to understand the contextual discriminative details among inter-class samples. By incorporating such a self-supervised module, the network gains more knowledge from the intrinsic structure of the input data and thus improves its generalization capability with limited training samples. The proposed Mix-ViT achieves competitive performance on seven publicly available datasets, demonstrating for the first time the potential of vision transformers compared to CNNs in addressing the challenging ultra-FGVC tasks.

Prerequisites

The following packages are required to run the scripts:

  • Python >= 3.6
  • PyTorch = 1.8
  • Torchvision
  • Apex
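
As a quick sanity check before training, the prerequisites above can be probed from Python. This is a minimal sketch, not part of the repo: the import names `torch`, `torchvision` and `apex` are assumed to be the usual ones for these packages, the helper name `missing_prerequisites` is hypothetical, and the PyTorch version itself is not checked.

```python
import importlib.util
import sys

# Assumed import names for PyTorch, Torchvision and Apex.
REQUIRED_MODULES = ("torch", "torchvision", "apex")


def missing_prerequisites(modules=REQUIRED_MODULES, min_python=(3, 6)):
    """Return a list of problems found; an empty list means every check passed."""
    problems = []
    if sys.version_info < min_python:
        problems.append("Python >= {}.{} is required".format(*min_python))
    for name in modules:
        # find_spec returns None when a top-level module cannot be located.
        if importlib.util.find_spec(name) is None:
            problems.append("module '{}' is not installed".format(name))
    return problems


if __name__ == "__main__":
    for problem in missing_prerequisites():
        print(problem)
```

If the script prints nothing, all listed modules were found; the installed PyTorch version still has to be compared against 1.8 by hand.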

Download Google pre-trained ViT models

wget https://storage.googleapis.com/vit_models/imagenet21k/{MODEL_NAME}.npz
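
The same download can be scripted in Python. The sketch below only fills in the `{MODEL_NAME}` placeholder of the URL above; the helper names are hypothetical, and `ViT-B_16` is used purely as an example name, since this README does not state which checkpoint the training scripts expect.

```python
import os
import urllib.request

# Base URL taken from the wget command above.
BASE_URL = "https://storage.googleapis.com/vit_models/imagenet21k"


def checkpoint_url(model_name):
    """Build the download URL for a pre-trained ViT checkpoint."""
    return "{}/{}.npz".format(BASE_URL, model_name)


def download_checkpoint(model_name, dest_dir="."):
    """Download the checkpoint into dest_dir and return the local file path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest_path = os.path.join(dest_dir, model_name + ".npz")
    urllib.request.urlretrieve(checkpoint_url(model_name), dest_path)
    return dest_path


# Example (assumed model name; performs a network download when uncommented):
# download_checkpoint("ViT-B_16", dest_dir="pretrained")
```

Only `checkpoint_url` is free of side effects; `download_checkpoint` writes a file and needs network access.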

Dataset

You can download the datasets from the links below:

Run the experiments

Use the scripts in the scripts directory to train the model, e.g., to train on the SoybeanGene dataset:

$ sh scripts/train_soybean_gene.sh

Download Trained Models

Trained model: Google Drive

Acknowledgment

Our project references code from the following repos. Thanks to the authors for their work and for sharing it.

mixvit's Issues

Error when reproducing on the CUB dataset

Running MixViT-main/scripts/train_cub.sh to reproduce Mix-ViT fails with the following error:
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ml_collections/config_dict/config_dict.py", line 883, in __getitem__
    field = self._fields[key]
KeyError: 'resnet'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/ml_collections/config_dict/config_dict.py", line 807, in __getattr__
    return self[attribute]
  File "/usr/local/lib/python3.6/dist-packages/ml_collections/config_dict/config_dict.py", line 889, in __getitem__
    raise KeyError(self._generate_did_you_mean_message(key, str(e)))
KeyError: "'resnet'"

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/opt/data/private/MixViT-main/train.py", line 458, in <module>
    main()
  File "/opt/data/private/MixViT-main/train.py", line 451, in main
    args, model = setup(args)
  File "/opt/data/private/MixViT-main/train.py", line 102, in setup
    model = make_model(config, args, zero_head=True, num_classes=num_classes, vis=True)
  File "/opt/data/private/MixViT-main/models/make_model.py", line 94, in make_model
    model = build_transformer(config, args, zero_head=True, num_classes=num_classes, vis=True)
  File "/opt/data/private/MixViT-main/models/make_model.py", line 43, in __init__
    self.inter_arch = ResNetV2(block_units=config.resnet.num_layers,
  File "/usr/local/lib/python3.6/dist-packages/ml_collections/config_dict/config_dict.py", line 809, in __getattr__
    raise AttributeError(e)
AttributeError: "'resnet'"
What could be causing this?
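
One possible reading of the traceback: make_model.py reads `config.resnet.num_layers`, but the config object selected for this run has no `resnet` section, so the attribute lookup fails. Whether this comes from the model type chosen in train_cub.sh is an assumption, not something the traceback confirms. The sketch below shows a check that could be run before building the model. The helper name and the config values are hypothetical, and plain dicts stand in for the real config so the example runs without ml_collections.

```python
def missing_config_sections(config, required=("resnet",)):
    """Return the names in `required` that are absent from `config`.

    `config` may be any object that supports the `in` operator on keys,
    such as a dict; the real ml_collections config is assumed to behave
    the same way.
    """
    return [name for name in required if name not in config]


# Stand-in configs with hypothetical values, for illustration only.
vit_only_config = {"hidden_size": 768, "patches": {"size": (16, 16)}}
hybrid_config = dict(vit_only_config, resnet={"num_layers": (3, 4, 9)})

print(missing_config_sections(vit_only_config))  # the "resnet" section is reported missing
print(missing_config_sections(hybrid_config))    # nothing is reported missing
```

If the check reports `resnet` as missing for the config used by the script, selecting a config that defines that section, or adding the section to the current one, would be the first thing to try.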
