
Introduction

Implementation of Data-Augmented Domain Adaptation for Deep Metric Learning (AAAI 2024)

Ren, Li and Chen, Chen and Wang, Liqiang and Hua, Kien

Abstract

Deep Metric Learning (DML) plays an important role in modern computer vision research, where we learn a distance metric for a set of image representations. Recent DML techniques utilize the proxy to interact with the corresponding image samples in the embedding space. However, existing proxy-based DML methods focus on learning individual proxy-to-sample distance while the overall distribution of samples and proxies lacks attention. In this paper, we present a novel proxy-based DML framework that focuses on aligning the sample and proxy distributions to improve the efficiency of proxy-based DML losses. Specifically, we propose the Data-Augmented Domain Adaptation (DADA) method to adapt the domain gap between the group of samples and proxies. To the best of our knowledge, we are the first to leverage domain adaptation to boost the performance of proxy-based DML. We show that our method can be easily plugged into existing proxy-based DML losses. Our experiments on benchmarks, including the popular CUB-200-2011, CARS196, Stanford Online Products, and In-Shop Clothes Retrieval, show that our learning algorithm significantly improves the existing proxy losses and achieves superior results compared to the existing methods.
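To make the idea concrete, here is a minimal NumPy sketch of a proxy-based loss combined with a distribution-alignment penalty between the sample and proxy groups. This is an illustrative example, not the paper's exact DADA objective: the proxy term is a Proxy-NCA-style cross-entropy, and the alignment term is a simple linear-kernel MMD; all helper names are hypothetical.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-12):
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def proxy_nca_loss(embeddings, labels, proxies):
    """Proxy-NCA-style loss: softmax cross-entropy of sample-to-proxy
    cosine similarities against the true-class proxy."""
    z = l2_normalize(embeddings)
    p = l2_normalize(proxies)
    sims = z @ p.T
    logits = sims - sims.max(axis=1, keepdims=True)  # numerical stability
    log_prob = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -log_prob[np.arange(len(labels)), labels].mean()

def mmd_alignment(embeddings, proxies):
    """Linear-kernel MMD between the sample and proxy distributions:
    squared distance between their embedding means."""
    diff = l2_normalize(embeddings).mean(axis=0) - l2_normalize(proxies).mean(axis=0)
    return float(diff @ diff)

# toy batch: 8 samples, 4 classes, 16-dim embeddings
rng = np.random.default_rng(0)
emb = rng.normal(size=(8, 16))
proxies = rng.normal(size=(4, 16))
labels = rng.integers(0, 4, size=8)
total = proxy_nca_loss(emb, labels, proxies) + 0.5 * mmd_alignment(emb, proxies)
```

Minimizing the alignment term pulls the overall sample and proxy distributions together, which is the intuition the abstract describes; the paper's method realizes this via domain adaptation with data augmentation.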

Demo Figure

(figure: demo of the DADA framework)

Citation

If you find this repo useful, please consider citing:

@inproceedings{ren2024towards,
  title = {Towards improved proxy-based deep metric learning via data-augmented domain adaptation},
  author = {Ren, Li and Chen, Chen and Wang, Liqiang and Hua, Kien},
  booktitle = {Proceedings of the 38th AAAI Conference on Artificial Intelligence},
  year = {2024}
}

Installation

  1. Install Anaconda or Miniconda.
  2. Prepare a Conda environment with the following script:
conda init
source ~/.bashrc
conda create --name dml python=3.8 -y
conda activate dml
  3. Install CUDA, PyTorch, and the remaining dependencies with the following script:
conda update -n base -c defaults conda -y
conda install -y scipy pandas termcolor
conda install -yc conda-forge
conda install -y pytorch=1.12.1=py3.8_cuda11.3_cudnn8.3.2_0 torchvision faiss-gpu cudatoolkit=11.3 -c pytorch
pip install timm tqdm pretrainedmodels
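After installation, a quick stdlib-only check (a hypothetical helper, not part of the repo) confirms that the key packages are importable before you start training. Note that faiss-gpu imports as faiss:

```python
import importlib.util

# packages installed by the scripts above
required = ["torch", "torchvision", "faiss", "timm", "tqdm", "pretrainedmodels"]
missing = [name for name in required if importlib.util.find_spec(name) is None]
print("missing packages:", ", ".join(missing) if missing else "none")
```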

Download Data

CUB200 can be downloaded from here
CARS can be downloaded from here
Stanford_Online_Products can be downloaded from the official website
In-Shop Clothes Retrieval can be downloaded from the official website

The datasets should be unpacked into a single root folder, organized as follows:

$HOME/data/
├── cars196
│   └── images
│       ├── Acura Integra Type R 2001
│       ...
├── cub200
│   └── images
│       ├── 001.Black_footed_Albatross
│       ...
├── inshop
│   ├── img
│   │   ├── ...
│   └── list_eval_partition.txt
└── online_products
    ├── images
    │   ├── bicycle_final
    │   ...
    └── Info_Files

Training

Start the training procedure with the following command:

python main.py --source_path ${data_path} \
--save_path ${save_path} \
--save_name ${save_name} \
--config ${config_path} \
--gpu ${gpu_id}

data_path is the root of your dataset.
save_path is the path to save essential data and checkpoints.
save_name is the name of this run.
config_path is the path of the configuration file that stores the specific hyper-parameters.
gpu_id is the index of your GPU, starting from 0.

For example, to train a model on CUB200 using GPU 0:

# Train DADA on CUB200
python main.py --source_path ../dml_data \
--save_path ../Results \
--config ./configs/cub200.yaml \
--gpu 0

You can also create your own training config by editing or adding files under ./configs/{your_setting}.yaml.

Evaluation

To evaluate a pre-trained model, run the following script:

python eval.py --source_path ${data_path} \
--dataset ${data_name} \
--test_path ${checkpoint_path} \
--evaluation_metrics ${metrics}

data_path is the root of your dataset.
data_name is the name of the dataset (cub200, cars196, online_products, inshop).
checkpoint_path is the path of your pre-trained checkpoint.
metrics is the list of evaluation metrics.

For example,

# evaluate one checkpoint on CUB200
python eval.py --source_path ../dml_data \
--dataset cub200 \
--test_path ../Results/cub200/[email protected] \
--evaluation_metrics ['e_recall@1', 'e_recall@2', 'e_recall@4', 'e_recall@10', 'f1', 'mAP_R']
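The e_recall@K metrics measure, for each query image, whether at least one of its K nearest neighbors in the embedding space belongs to the same class. As a minimal NumPy sketch (a hypothetical helper, not the repo's implementation):

```python
import numpy as np

def recall_at_k(embeddings, labels, k):
    """Fraction of queries whose k nearest neighbors (excluding the query
    itself) contain at least one sample of the same class."""
    # cosine similarity via L2-normalized dot products
    z = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sims = z @ z.T
    np.fill_diagonal(sims, -np.inf)  # a query must not retrieve itself
    topk = np.argsort(-sims, axis=1)[:, :k]
    hits = (labels[topk] == labels[:, None]).any(axis=1)
    return float(hits.mean())
```

The repo's eval.py computes these metrics over the full test set (typically with faiss for nearest-neighbor search); this sketch only illustrates the definition.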

The pre-trained model to demo the evaluation on CUB200 can be downloaded here

Please note that reproduced results may differ from those reported in the paper due to differences in environments and GPU architectures. Our original experiments were done on an RTX 3090 with CUDA 11.3 and PyTorch 1.12.

Acknowledgement

This implementation is partially based on:

https://github.com/Confusezius/Deep-Metric-Learning-Baselines
https://github.com/Confusezius/Revisiting_Deep_Metric_Learning_PyTorch
