CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation

Introduction

Unsupervised domain adaptation (UDA) aims to transfer knowledge learned from a labeled source domain to a different, unlabeled target domain. Most existing UDA methods focus on learning domain-invariant feature representations, at either the domain level or the category level, using convolutional neural network (CNN)-based frameworks. Motivated by the success of Transformers in various tasks, we observe that cross-attention in the Transformer is robust to noisy input pairs and therefore gives better feature alignment, so in this paper we adopt the Transformer for the challenging UDA task. Specifically, to generate accurate input pairs, we design a two-way center-aware labeling algorithm that produces pseudo labels for target samples. Along with the pseudo labels, we propose a weight-sharing triple-branch transformer framework that applies self-attention for source/target feature learning and cross-attention for source-target domain alignment. This design explicitly enforces the framework to learn discriminative domain-specific and domain-invariant representations simultaneously. The proposed method, dubbed CDTrans (cross-domain transformer), is one of the first attempts to solve UDA tasks with a pure transformer solution. Extensive experiments show that it achieves the best performance on the public UDA datasets Office-Home, Office-31, VisDA-2017, and DomainNet.
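The two-way center-aware labeling step can be sketched in NumPy as below. This is a minimal illustration under stated assumptions, not the repository's implementation: features are assumed to come from the source-pretrained model, target class centers are initialized as probability-weighted means of target features (SHOT-style) and refined once with hard assignments, and the helpers `two_way_pairs`, `center_aware_pseudo_labels` and `filter_pairs` are hypothetical names. Two-way matching pairs each source sample with its most similar target sample and each target sample with its most similar source sample; center-aware filtering then keeps only the pairs whose source label agrees with the target pseudo label.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit length so dot products become cosine similarities."""
    return x / np.maximum(np.linalg.norm(x, axis=1, keepdims=True), 1e-12)


def two_way_pairs(feat_s: np.ndarray, feat_t: np.ndarray) -> list:
    """Union of source->nearest-target and target->nearest-source matches as (src, tgt) pairs."""
    sim = l2_normalize(feat_s) @ l2_normalize(feat_t).T  # shape (n_source, n_target)
    s_to_t = {(i, int(j)) for i, j in enumerate(sim.argmax(axis=1))}
    t_to_s = {(int(i), j) for j, i in enumerate(sim.argmax(axis=0))}
    return sorted(s_to_t | t_to_s)


def center_aware_pseudo_labels(feat_t: np.ndarray, probs_t: np.ndarray, n_iter: int = 1) -> np.ndarray:
    """Label each target sample by its nearest class center (assumed SHOT-style centers)."""
    f = l2_normalize(feat_t)
    # Initial centers: classifier-probability-weighted means of the target features.
    centers = probs_t.T @ f / np.maximum(probs_t.sum(axis=0)[:, None], 1e-12)
    labels = (f @ l2_normalize(centers).T).argmax(axis=1)
    for _ in range(n_iter):
        # Refine the centers from the current hard assignments.
        onehot = np.eye(probs_t.shape[1])[labels]
        centers = onehot.T @ f / np.maximum(onehot.sum(axis=0)[:, None], 1e-12)
        labels = (f @ l2_normalize(centers).T).argmax(axis=1)
    return labels


def filter_pairs(pairs: list, labels_s, pseudo_t) -> list:
    """Keep only pairs whose source label matches the target pseudo label."""
    return [(i, j) for i, j in pairs if labels_s[i] == pseudo_t[j]]
```

The filtered pairs are what the cross-attention branch would consume; pairs whose labels disagree are treated as noisy and dropped.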

[Figure: overview of the CDTrans framework]
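The difference between the three branches can be illustrated with single-head attention in NumPy. In this hedged sketch, one set of projection weights is shared by all branches; the source and target branches apply self-attention to their own tokens, while the source-target branch takes queries from the source tokens and keys/values from the target tokens. The class `SharedAttention`, the single-head design, and this query/key assignment are illustrative assumptions; the actual blocks (multi-head attention, MLPs, normalization) are defined in the repository.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable softmax."""
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Scaled dot-product attention for token matrices of shape (n_tokens, dim)."""
    return softmax(q @ k.T / np.sqrt(q.shape[-1])) @ v


class SharedAttention:
    """Single-head attention whose projections are shared by all three branches."""

    def __init__(self, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.wq, self.wk, self.wv = (rng.standard_normal((dim, dim)) / np.sqrt(dim) for _ in range(3))

    def self_attn(self, x: np.ndarray) -> np.ndarray:
        return attention(x @ self.wq, x @ self.wk, x @ self.wv)

    def cross_attn(self, x_q: np.ndarray, x_kv: np.ndarray) -> np.ndarray:
        # Queries from one domain, keys and values from the other.
        return attention(x_q @ self.wq, x_kv @ self.wk, x_kv @ self.wv)


def triple_branch(block: SharedAttention, tokens_s: np.ndarray, tokens_t: np.ndarray):
    out_s = block.self_attn(tokens_s)              # source branch
    out_t = block.self_attn(tokens_t)              # target branch
    out_st = block.cross_attn(tokens_s, tokens_t)  # source-target branch
    return out_s, out_t, out_st
```

Because the weights are shared, passing identical inputs to `cross_attn` reproduces `self_attn`; the branches differ only in where the keys and values come from.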

Results

Table 1 [UDA results on Office-31]

| Methods | Avg. | A->D | A->W | D->A | D->W | W->A | W->D |
|---|---|---|---|---|---|---|---|
| Baseline(DeiT-S) | 86.7 | 87.6 | 86.9 | 74.9 | 97.7 | 73.5 | 99.6 |
| CDTrans(DeiT-S) | 90.4 | 94.6 | 93.5 | 78.4 | 98.2 | 78.0 | 99.6 |
| Baseline(DeiT-B) | 88.8 | 90.8 | 90.4 | 76.8 | 98.2 | 76.4 | 100.0 |
| CDTrans(DeiT-B) | 92.6 | 97.0 | 96.7 | 81.1 | 99.0 | 81.9 | 100.0 |

Table 2 [UDA results on Office-Home]

| Methods | Avg. | Ar->Cl | Ar->Pr | Ar->Re | Cl->Ar | Cl->Pr | Cl->Re | Pr->Ar | Pr->Cl | Pr->Re | Re->Ar | Re->Cl | Re->Pr |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Baseline(DeiT-S) | 69.8 | 55.6 | 73.0 | 79.4 | 70.6 | 72.9 | 76.3 | 67.5 | 51.0 | 81.0 | 74.5 | 53.2 | 82.7 |
| CDTrans(DeiT-S) | 74.7 | 60.6 | 79.5 | 82.4 | 75.6 | 81.0 | 82.3 | 72.5 | 56.7 | 84.4 | 77.0 | 59.1 | 85.5 |
| Baseline(DeiT-B) | 74.8 | 61.8 | 79.5 | 84.3 | 75.4 | 78.8 | 81.2 | 72.8 | 55.7 | 84.4 | 78.3 | 59.3 | 86.0 |
| CDTrans(DeiT-B) | 80.5 | 68.8 | 85.0 | 86.9 | 81.5 | 87.1 | 87.3 | 79.6 | 63.3 | 88.2 | 82.0 | 66.0 | 90.6 |

Table 3 [UDA results on VisDA-2017]

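| Methods | Per-class avg. | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Baseline(DeiT-B) | 67.3 | 98.1 | 48.1 | 84.6 | 65.2 | 76.3 | 59.4 | 94.5 | 11.8 | 89.5 | 52.2 | 94.5 | 34.1 |
| CDTrans(DeiT-B) | 88.4 | 97.7 | 86.39 | 86.87 | 83.33 | 97.76 | 97.16 | 95.93 | 84.08 | 97.93 | 83.47 | 94.59 | 55.3 |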

Table 4 [UDA results on DomainNet]

In each sub-table, the Avg. column is the mean of the five entries in its row and the Avg. row is the mean of the five entries in its column.

| Base-S | clp | info | pnt | qdr | rel | skt | Avg. |
|---|---|---|---|---|---|---|---|
| clp | - | 21.2 | 44.2 | 15.3 | 59.9 | 46.0 | 37.3 |
| info | 36.8 | - | 39.4 | 5.4 | 52.1 | 32.6 | 33.3 |
| pnt | 47.1 | 21.7 | - | 5.7 | 60.2 | 39.9 | 34.9 |
| qdr | 25.0 | 3.3 | 10.4 | - | 18.8 | 14.0 | 14.3 |
| rel | 54.8 | 23.9 | 52.6 | 7.4 | - | 40.1 | 35.8 |
| skt | 55.6 | 18.6 | 42.7 | 14.9 | 55.7 | - | 37.5 |
| Avg. | 43.9 | 17.7 | 37.9 | 9.7 | 49.3 | 34.5 | 32.2 |

| CDTrans-S | clp | info | pnt | qdr | rel | skt | Avg. |
|---|---|---|---|---|---|---|---|
| clp | - | 25.3 | 52.5 | 23.2 | 68.3 | 53.2 | 44.5 |
| info | 47.6 | - | 48.3 | 9.9 | 62.8 | 41.1 | 41.9 |
| pnt | 55.4 | 24.5 | - | 11.7 | 67.4 | 48.0 | 41.4 |
| qdr | 36.6 | 5.3 | 19.3 | - | 33.8 | 22.7 | 23.5 |
| rel | 61.5 | 28.1 | 56.8 | 12.8 | - | 47.2 | 41.3 |
| skt | 64.3 | 26.1 | 53.2 | 23.9 | 66.2 | - | 46.7 |
| Avg. | 53.08 | 21.86 | 46.02 | 16.3 | 59.7 | 42.44 | 39.9 |

| Base-B | clp | info | pnt | qdr | rel | skt | Avg. |
|---|---|---|---|---|---|---|---|
| clp | - | 24.2 | 48.9 | 15.5 | 63.9 | 50.7 | 40.6 |
| info | 43.5 | - | 44.9 | 6.5 | 58.8 | 37.6 | 38.3 |
| pnt | 52.8 | 23.3 | - | 6.6 | 64.6 | 44.5 | 38.4 |
| qdr | 31.8 | 6.1 | 15.6 | - | 23.4 | 18.9 | 19.2 |
| rel | 58.9 | 26.3 | 56.7 | 9.1 | - | 45.0 | 39.2 |
| skt | 60.0 | 21.1 | 48.4 | 16.6 | 61.7 | - | 41.6 |
| Avg. | 49.4 | 20.2 | 42.9 | 10.9 | 54.5 | 39.3 | 36.2 |

| CDTrans-B | clp | info | pnt | qdr | rel | skt | Avg. |
|---|---|---|---|---|---|---|---|
| clp | - | 29.4 | 57.2 | 26.0 | 72.6 | 58.1 | 48.7 |
| info | 57.0 | - | 54.4 | 12.8 | 69.5 | 48.4 | 48.4 |
| pnt | 62.9 | 27.4 | - | 15.8 | 72.1 | 53.9 | 46.4 |
| qdr | 44.6 | 8.9 | 29.0 | - | 42.6 | 28.5 | 30.7 |
| rel | 66.2 | 31.0 | 61.5 | 16.2 | - | 52.9 | 45.6 |
| skt | 69.0 | 29.6 | 59.0 | 27.2 | 72.5 | - | 51.5 |
| Avg. | 59.9 | 25.3 | 52.2 | 19.6 | 65.9 | 48.4 | 45.2 |
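As a worked check of how the DomainNet averages are formed, the sketch below recomputes the Base-S summary from its cells. It assumes each Avg. entry is a plain arithmetic mean over the five non-diagonal cells and that the bottom-right value is the mean over all 30 cells; under that assumption it reproduces the reported Base-S averages to within rounding. `summarize` is a hypothetical helper.

```python
# Base-S accuracies from the DomainNet table; None marks the empty diagonal.
DOMAINS = ["clp", "info", "pnt", "qdr", "rel", "skt"]
BASE_S = [
    [None, 21.2, 44.2, 15.3, 59.9, 46.0],
    [36.8, None, 39.4, 5.4, 52.1, 32.6],
    [47.1, 21.7, None, 5.7, 60.2, 39.9],
    [25.0, 3.3, 10.4, None, 18.8, 14.0],
    [54.8, 23.9, 52.6, 7.4, None, 40.1],
    [55.6, 18.6, 42.7, 14.9, 55.7, None],
]


def mean(values):
    """Arithmetic mean that ignores the None placeholders."""
    vals = [v for v in values if v is not None]
    return sum(vals) / len(vals)


def summarize(matrix):
    """Return (row averages, column averages, overall average)."""
    row_avg = [mean(row) for row in matrix]
    col_avg = [mean(col) for col in zip(*matrix)]
    overall = mean([v for row in matrix for v in row])
    return row_avg, col_avg, overall


if __name__ == "__main__":
    rows, cols, overall = summarize(BASE_S)
    for name, r, c in zip(DOMAINS, rows, cols):
        print(f"{name}: row avg {r:.1f}, column avg {c:.1f}")
    print(f"overall {overall:.1f}")
```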

Requirements

Installation

pip install -r requirements.txt
(Environment used: Python 3.7, an NVIDIA V100 GPU, CUDA 10.1 and cudatoolkit 10.1.)

Prepare Datasets

Download the UDA datasets: Office-31, Office-Home, VisDA-2017, and DomainNet.

Then unzip them and arrange them under the data directory as follows. Note that each dataset folder must contain the .txt files listing the path and label of each image; these files are already provided under data/<dataset> in this project.

data
├── OfficeHomeDataset
│   │── class_name
│   │   └── images
│   └── *.txt
├── domainnet
│   │── class_name
│   │   └── images
│   └── *.txt
├── office31
│   │── class_name
│   │   └── images
│   └── *.txt
├── visda
│   │── train
│   │   │── class_name
│   │   │   └── images
│   │   └── *.txt 
│   └── validation
│       │── class_name
│       │   └── images
│       └── *.txt 
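Before training, it can help to confirm that every entry in a list file resolves to an image on disk. The sketch below assumes each line of a *.txt list holds an image path followed by an integer label, separated by whitespace, and that the paths are relative to a `root` folder you pass in; `read_image_list` and `missing_files` are hypothetical helpers, not part of the repository.

```python
from pathlib import Path


def read_image_list(list_file, root):
    """Parse lines assumed to look like '<relative/path.jpg> <label>'; blank lines are skipped."""
    samples = []
    with open(list_file, encoding="utf-8") as fh:
        for line_no, line in enumerate(fh, start=1):
            line = line.strip()
            if not line:
                continue
            # Split off only the last field, so paths that contain spaces still parse.
            parts = line.rsplit(None, 1)
            if len(parts) != 2:
                raise ValueError(f"{list_file}:{line_no}: expected '<path> <label>'")
            samples.append((Path(root) / parts[0], int(parts[1])))
    return samples


def missing_files(samples):
    """Return the image paths that do not exist on disk."""
    return [path for path, _ in samples if not path.is_file()]
```

For example, `missing_files(read_image_list(list_path, root))` returns an empty list once the images are unpacked where the list expects them.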

Prepare DeiT-trained Models

For a fair comparison in terms of pre-training data, we initialize our ViT-based model with DeiT parameters. Download the ImageNet-pretrained transformer models DeiT-Small and DeiT-Base, and move them to the ./data/pretrainModel directory.
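A quick way to confirm the checkpoints are in place is sketched below. The file names are an assumption (they follow the naming of the official DeiT distilled releases) and should be checked against the paths referenced by the training scripts; `missing_checkpoints` is a hypothetical helper.

```python
from pathlib import Path

# Hypothetical file names: verify them against the paths used in the training scripts.
EXPECTED_CHECKPOINTS = {
    "deit_small": "deit_small_distilled_patch16_224-649709d9.pth",
    "deit_base": "deit_base_distilled_patch16_224-df68dfff.pth",
}


def missing_checkpoints(model_dir="./data/pretrainModel", backbones=("deit_small", "deit_base")):
    """Return the backbones whose checkpoint file is not present in model_dir."""
    root = Path(model_dir)
    return [name for name in backbones if not (root / EXPECTED_CHECKPOINTS[name]).is_file()]


if __name__ == "__main__":
    missing = missing_checkpoints()
    print("all checkpoints found" if not missing else f"missing: {', '.join(missing)}")
```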

Training

We use 1 GPU for pre-training and 2 GPUs for UDA, each with 16 GB of memory.

Scripts

Command pattern

bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]

For example

DeiT-Base scripts

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base

# Office-Home   Source: Art      ->  Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base

# VisDA-2017    Source: train    ->  Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base

# DomainNet     Source: Clipart  ->  Target: Painting, Quickdraw, Real, Sketch, Infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base

DeiT-Small scripts

Replace deit_base with deit_small to reproduce the DeiT-Small results. For example, on Office-31:

# Office-31     Source: Amazon   ->  Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small
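The two-step pattern (source-only pre-training, then UDA with the same script name and backbone) can also be driven from Python, for instance to loop over several source domains. This is a minimal sketch assuming the scripts are run from the repository root; `stage_commands` and `run_pipeline` are hypothetical helpers, and `dry_run=True` only prints the commands.

```python
import subprocess

DATASETS = ("office31", "officehome", "visda", "domainnet")
BACKBONES = ("deit_base", "deit_small")


def stage_commands(dataset, script, backbone):
    """Build the pretrain and uda commands for one source domain, in run order."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if backbone not in BACKBONES:
        raise ValueError(f"unknown backbone: {backbone}")
    return [["bash", f"scripts/{stage}/{dataset}/{script}", backbone] for stage in ("pretrain", "uda")]


def run_pipeline(dataset, script, backbone, dry_run=True):
    """Print each command; execute it (stopping at the first failure) unless dry_run is set."""
    for cmd in stage_commands(dataset, script, backbone):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline("office31", "run_office_amazon.sh", "deit_base")
```

`check=True` makes a failed pre-training run raise before the UDA stage starts.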

Evaluation

# For example, VisDA-2017
python test.py --config_file 'configs/uda.yml' \
    MODEL.DEVICE_ID "('0')" \
    TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" \
    DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' \
    OUTPUT_DIR '../logs/uda/vit_base/visda/' \
    DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'
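The evaluation command passes `--config_file` followed by KEY VALUE pairs that override config entries (TransReID, which this code builds on, merges such trailing options into its config; the same is assumed here). The sketch below builds that command as an argument list so the nested shell quoting disappears; `build_test_argv` is a hypothetical helper and the defaults simply mirror the VisDA-2017 example.

```python
import shlex


def build_test_argv(weight, dataset, train_list, test_list, output_dir,
                    device_id="0", config="configs/uda.yml", train_list2=None):
    """Assemble a test.py invocation as an argv list (no shell quoting needed)."""
    overrides = [
        ("MODEL.DEVICE_ID", f"('{device_id}')"),
        ("TEST.WEIGHT", f"('{weight}')"),
        ("DATASETS.NAMES", dataset),
        ("DATASETS.NAMES2", dataset),
        ("OUTPUT_DIR", output_dir),
        ("DATASETS.ROOT_TRAIN_DIR", train_list),
        # The README example reuses the first training list here.
        ("DATASETS.ROOT_TRAIN_DIR2", train_list2 or train_list),
        ("DATASETS.ROOT_TEST_DIR", test_list),
    ]
    argv = ["python", "test.py", "--config_file", config]
    for key, value in overrides:
        argv += [key, value]
    return argv


if __name__ == "__main__":
    argv = build_test_argv(
        weight="../logs/uda/vit_base/visda/transformer_best_model.pth",
        dataset="VisDA",
        train_list="./data/visda/train/train_image_list.txt",
        test_list="./data/visda/validation/valid_image_list.txt",
        output_dir="../logs/uda/vit_base/visda/",
    )
    # shlex.quote works on Python 3.7 (shlex.join needs 3.8+).
    print(" ".join(shlex.quote(arg) for arg in argv))
```

Passing the list to `subprocess.run(argv, check=True)` runs the same evaluation without a shell.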

Acknowledgement

Our codebase is built on TransReID.