Dassl

Dassl is a PyTorch toolbox designed for researching Domain Adaptation and Semi-Supervised Learning (and hence the name Dassl). It has a modular design and unified interfaces, allowing fast prototyping and experimentation of new DA/SSL methods. With Dassl, a new method can be implemented with only a few lines of code.

You can use Dassl as a library for research on the following topics:

  • Domain adaptation
  • Domain generalization
  • Semi-supervised learning

What's new

  • [Jul 2021] v0.3.4: Adds a new function generate_fewshot_dataset() to the base dataset class for generating a few-shot learning setting. To customize a few-shot dataset, specify _C.DATASET.NUM_SHOTS and pass it to generate_fewshot_dataset().
  • [Jul 2021] v0.3.2: Adds _C.INPUT.INTERPOLATION (default: bilinear). Available interpolation modes are bilinear, nearest, and bicubic.
  • [Jul 2021] v0.3.1: Now you can use *.register(force=True) to replace previously registered modules.
  • [Jul 2021] v0.3.0: Allows deploying the model with the best validation performance for the final test (for the purpose of model selection). Specifically, a new config variable named _C.TEST.FINAL_MODEL is introduced, which takes either "last_step" (default) or "best_val". When set to "best_val", the model is evaluated on the val set after each epoch, and the one with the best validation performance is saved and used for the final test.
  • [Jul 2021] v0.2.7: Adds attribute classnames to the base dataset class. Now you can get a list of class names ordered by numeric labels by calling trainer.dm.dataset.classnames.
  • [Jun 2021] v0.2.6: Merges MixStyle2 into MixStyle. A new variable self.mix is used to switch between random mixing and cross-domain mixing. See dassl/modeling/ops/mixstyle.py for more details on the new features.
  • [Jun 2021] v0.2.5: Fixes a bug in the calculation of per-class recognition accuracy.
  • [Jun 2021] v0.2.4: Adds extend_cfg(cfg) to train.py. This function is particularly useful when you build your own methods on top of Dassl.pytorch and need to define some custom variables. Please see the repository mixstyle-release or ssdg-benchmark for examples.
  • [Jun 2021] New benchmarks for semi-supervised domain generalization at https://github.com/KaiyangZhou/ssdg-benchmark.
  • [Apr 2021] Did you know that you can use tools/parse_test_res.py to read the log files and automatically calculate and print out the results, including mean and standard deviation? Check the instructions in tools/parse_test_res.py for more details.
  • [Apr 2021] v0.2.3: A MixStyle layer can now be deactivated or activated by using model.apply(deactivate_mixstyle) or model.apply(activate_mixstyle) without modifying the source code. See dassl/modeling/ops/mixstyle.py for the details.
  • [Apr 2021] v0.2.2: Adds RandomClassSampler, which samples from a certain number of classes a certain number of images to form a minibatch (the code is modified from Torchreid).
  • [Apr 2021] v0.2.1: Slightly adjusts the ordering in setup_cfg() (see tools/train.py).
  • [Apr 2021] v0.2.0: Adds _C.DATASET.ALL_AS_UNLABELED (for the SSL setting) to the config variable list. When this variable is set to True, all labeled data will be included in the unlabeled data set.
  • [Apr 2021] v0.1.9: Adds VLCS to the benchmark datasets (see dassl/data/datasets/dg/vlcs.py).
  • [Mar 2021] v0.1.8: Allows optim and sched to be None in register_model().
  • [Mar 2021] v0.1.7: Adds MixStyle models to dassl/modeling/backbone/resnet.py. The training configs in configs/trainers/dg/vanilla can be used to train MixStyle models.
  • [Mar 2021] v0.1.6: Adds CIFAR-10/100-C to the benchmark datasets for evaluating a model's robustness to image corruptions.
  • [Mar 2021] We have just released a survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes ten years of development in this topic, covering the history, related problems, datasets, methodologies, potential directions, and so on.
  • [Jan 2021] Our recent work, MixStyle (mixing instance-level feature statistics of samples of different domains for improving domain generalization), is accepted to ICLR'21. The code is available at https://github.com/KaiyangZhou/mixstyle-release where the cross-domain image classification part is based on Dassl.pytorch.
  • [May 2020] v0.1.3: Adds the Digit-Single dataset for benchmarking single-source DG methods. The corresponding CNN model is dassl/modeling/backbone/cnn_digitsingle.py and the dataset config file is configs/datasets/dg/digit_single.yaml. See Volpi et al. NIPS'18 for how to do evaluation.
  • [May 2020] v0.1.2: 1) Adds EfficientNet models (B0-B7) (credit to https://github.com/lukemelas/EfficientNet-PyTorch). To use EfficientNet, set MODEL.BACKBONE.NAME to efficientnet_b{N} where N={0, ..., 7}. 2) dassl/modeling/models is renamed to dassl/modeling/network (build_model() to build_network() and MODEL_REGISTRY to NETWORK_REGISTRY).
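For intuition about the few-shot setting from the v0.3.4 entry, the sketch below keeps at most NUM_SHOTS randomly chosen items per class. It is not Dassl's generate_fewshot_dataset(): the (impath, label) item format and the helper name sample_few_shot are hypothetical, and the real method may differ, for example in how it treats classes with fewer than NUM_SHOTS items.

```python
import random
from collections import defaultdict


def sample_few_shot(items, num_shots, seed=0):
    """Keep at most `num_shots` randomly chosen items per class.

    `items` is a list of (impath, label) pairs, a hypothetical stand-in
    for Dassl's dataset items.
    """
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for impath, label in items:
        by_label[label].append((impath, label))

    subset = []
    for label in sorted(by_label):
        group = by_label[label]
        if len(group) <= num_shots:
            # Not enough samples: keep the whole class
            subset.extend(group)
        else:
            subset.extend(rng.sample(group, num_shots))
    return subset


items = [(f"img_{i}.jpg", i % 3) for i in range(30)]  # 3 classes, 10 images each
few = sample_few_shot(items, num_shots=4)
print(len(few))  # 12
```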

Overview

Dassl has implemented the following methods:

Feel free to make a PR to add your methods here to make it easier for others to benchmark!

Dassl supports the following datasets:

Get started

Installation

Make sure conda is installed properly.

# Clone this repo
git clone https://github.com/KaiyangZhou/Dassl.pytorch.git
cd Dassl.pytorch/

# Create a conda environment
conda create -n dassl python=3.7

# Activate the environment
conda activate dassl

# Install dependencies
pip install -r requirements.txt

# Install torch (version >= 1.7.1) and torchvision
conda install pytorch torchvision cudatoolkit=10.1 -c pytorch

# Install this library (no need to re-build if the source code is modified)
python setup.py develop
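The installation steps above ask for torch >= 1.7.1. A quick way to check an existing environment is sketched below; torch_version_ok is a hypothetical helper, and its parsing only handles the common X.Y[.Z] and X.Y.Z+local forms (e.g. 1.8.1+cu101). Versions are compared as integer tuples, so 1.10.0 correctly counts as newer than 1.7.1.

```python
import re


def torch_version_ok(version, minimum=(1, 7, 1)):
    """Return True if a version string like '1.8.1+cu101' is >= minimum."""
    match = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
    if match is None:
        raise ValueError(f"Unrecognized version string: {version}")
    # A missing patch number is treated as 0, e.g. "2.0" -> (2, 0, 0)
    parts = tuple(int(p) if p is not None else 0 for p in match.groups())
    return parts >= minimum


try:
    import torch
    print(torch.__version__, torch_version_ok(torch.__version__))
except ImportError:
    print("PyTorch is not installed")
```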

Follow the instructions in DATASETS.md to preprocess the datasets.

Training

The main interface is implemented in tools/train.py, which does the following:

  1. initialize the config with cfg = setup_cfg(args) where args contains the command-line input (see tools/train.py for the list of input arguments);
  2. instantiate a trainer with build_trainer(cfg) which loads the dataset and builds a deep neural network model;
  3. call trainer.train() for training and evaluating the model.

Below is an example of training a source-only baseline on the popular domain adaptation dataset Office-31:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31

$DATA denotes the location where the datasets are installed. --dataset-config-file loads the common settings for the dataset (Office-31 in this case), such as image size and model architecture. --config-file loads the algorithm-specific settings, such as hyper-parameters and optimization parameters.
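The two config files are layered on top of the defaults. Assuming the dataset config is merged first and the algorithm config second (the ordering suggested by setup_cfg() in tools/train.py; check your version), a key set in both ends up with the algorithm config's value. Dassl itself uses yacs config nodes; the sketch below mimics the layering with plain dicts, and the helper deep_merge and all values shown are illustrative only.

```python
import copy


def deep_merge(base, override):
    """Return a copy of `base` with `override` merged in recursively.

    Values in `override` win; nested dicts are merged key by key.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# Illustrative values, not copied from the real config files
defaults = {"INPUT": {"SIZE": (224, 224)}, "OPTIM": {"LR": 0.0003},
            "MODEL": {"BACKBONE": {"NAME": ""}}}
dataset_cfg = {"MODEL": {"BACKBONE": {"NAME": "resnet50"}}}
trainer_cfg = {"OPTIM": {"LR": 0.002}}

cfg = deep_merge(deep_merge(defaults, dataset_cfg), trainer_cfg)
print(cfg["MODEL"]["BACKBONE"]["NAME"], cfg["OPTIM"]["LR"])  # resnet50 0.002
```

Unlike this sketch, yacs rejects keys that have no default value, which catches typos in config files early.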

To use multiple sources (i.e. the multi-source domain adaptation task), simply add more domains to --source-domains. For instance, to train a source-only baseline on miniDomainNet:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains clipart painting real \
--target-domains sketch \
--dataset-config-file configs/datasets/da/mini_domainnet.yaml \
--config-file configs/trainers/da/source_only/mini_domainnet.yaml \
--output-dir output/source_only_minidn
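When running several multi-source experiments, it can help to generate the commands from a script. The sketch below assumes a leave-one-domain-out protocol over the four miniDomainNet domains used above; build_train_cmd is a hypothetical helper that only uses the flags and paths from the example command.

```python
def build_train_cmd(source_domains, target_domain, output_dir, root="$DATA"):
    """Argument list for a SourceOnly run on miniDomainNet (hypothetical helper)."""
    return [
        "python", "tools/train.py",
        "--root", root,
        "--trainer", "SourceOnly",
        "--source-domains", *source_domains,
        "--target-domains", target_domain,
        "--dataset-config-file", "configs/datasets/da/mini_domainnet.yaml",
        "--config-file", "configs/trainers/da/source_only/mini_domainnet.yaml",
        "--output-dir", output_dir,
    ]


DOMAINS = ["clipart", "painting", "real", "sketch"]

# Leave-one-domain-out: each domain serves once as the target
commands = []
for target in DOMAINS:
    sources = [d for d in DOMAINS if d != target]
    commands.append(build_train_cmd(sources, target, f"output/source_only_minidn_{target}"))

for cmd in commands:
    # Joined with spaces so that $DATA is expanded when pasted into a shell
    print(" ".join(cmd))
```

To launch the runs from Python instead, pass each list to subprocess.run() with the real data path as root, since no shell expands $DATA in that case.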

After the training finishes, the model weights will be saved under the specified output directory, along with a log file and a tensorboard file for visualization.

To print out the results saved in the log files (so you do not need to go through every log file and calculate the mean/std yourself), you can use tools/parse_test_res.py. Usage instructions are given in the script itself.
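tools/parse_test_res.py is the supported way to do this. For intuition only, the sketch below shows the kind of aggregation involved: take the last accuracy reported in each run's log and compute the mean and standard deviation. The log line format "* accuracy: 85.3%" and the helper names are assumptions for illustration; the real script's parsing, and whether it reports sample or population std, may differ.

```python
import re
import statistics

# Assumed log line format, e.g. "* accuracy: 85.3%"
ACC_PATTERN = re.compile(r"\* accuracy: ([\d.]+)%")


def last_accuracy(log_text):
    """Return the last accuracy value found in a log, or None."""
    matches = ACC_PATTERN.findall(log_text)
    return float(matches[-1]) if matches else None


def summarize(log_texts):
    """Mean and sample standard deviation over runs that report a result."""
    accs = [a for a in map(last_accuracy, log_texts) if a is not None]
    mean = statistics.mean(accs)
    std = statistics.stdev(accs) if len(accs) > 1 else 0.0
    return mean, std


logs = [
    "epoch 1 ...\n* accuracy: 80.0%\n",
    "* accuracy: 82.0%\n",
    "* accuracy: 84.0%\n",
]
mean, std = summarize(logs)
print(f"{mean:.2f} +- {std:.2f}")  # 82.00 +- 2.00
```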

For other trainers such as MCD, you can set --trainer MCD while keeping the config file unchanged, i.e. using the same training parameters as SourceOnly (in the simplest case). To modify hyper-parameters of MCD, such as N_STEP_F (the number of steps for updating the feature extractor), append TRAINER.MCD.N_STEP_F 4 to the existing input arguments (otherwise the default value is used). Alternatively, you can create a new .yaml config file to store your custom settings. The complete list of algorithm-specific hyper-parameters is defined in dassl/config/defaults.py.
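Trailing arguments such as TRAINER.MCD.N_STEP_F 4 are read as KEY VALUE pairs and written into the config after the config files are loaded. Dassl relies on yacs for this (CfgNode.merge_from_list); the sketch below mimics the idea with a nested dict so it runs without dependencies. The helper apply_opts and the default values in the example are illustrative, not taken from Dassl.

```python
import ast


def apply_opts(cfg, opts):
    """Apply a flat ["A.B.C", "value", ...] list to a nested dict in place.

    Values are parsed with ast.literal_eval when possible (so "4" -> 4)
    and kept as strings otherwise. Unknown keys raise KeyError, since
    overriding a key that has no default is usually a typo.
    """
    if len(opts) % 2 != 0:
        raise ValueError("opts must contain KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # KeyError if a parent group does not exist
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {key}")
        try:
            value = ast.literal_eval(raw)
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as "resnet50"
        node[leaf] = value
    return cfg


cfg = {"TRAINER": {"MCD": {"N_STEP_F": 2}}, "OPTIM": {"LR": 0.002}}
apply_opts(cfg, ["TRAINER.MCD.N_STEP_F", "4"])
print(cfg["TRAINER"]["MCD"]["N_STEP_F"])  # 4
```

yacs additionally checks that the new value has the same type as the default, which this sketch omits.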

Test

Model testing can be done with --eval-only, which makes the code run trainer.test(). You also need to provide the trained model and specify which model file (i.e. the epoch at which it was saved) to use. For example, to use model.pth.tar-20 saved at output/source_only_office31/model:

CUDA_VISIBLE_DEVICES=0 python tools/train.py \
--root $DATA \
--trainer SourceOnly \
--source-domains amazon \
--target-domains webcam \
--dataset-config-file configs/datasets/da/office31.yaml \
--config-file configs/trainers/da/source_only/office31.yaml \
--output-dir output/source_only_office31_test \
--eval-only \
--model-dir output/source_only_office31 \
--load-epoch 20

Note that --model-dir takes as input the directory path which was specified in --output-dir in the training stage.
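To pick a value for --load-epoch, it helps to see which checkpoints exist. The sketch below assumes checkpoints follow the model.pth.tar-<epoch> naming shown above and sit in the model subfolder of the training output directory; available_epochs is a hypothetical helper, and the demo builds a throwaway directory rather than reading real results.

```python
import os
import re
import tempfile

CKPT_PATTERN = re.compile(r"^model\.pth\.tar-(\d+)$")


def available_epochs(model_dir, subdir="model"):
    """Return the sorted epochs of checkpoints saved under model_dir/subdir."""
    folder = os.path.join(model_dir, subdir)
    if not os.path.isdir(folder):
        return []
    epochs = []
    for fname in os.listdir(folder):
        match = CKPT_PATTERN.match(fname)
        if match:
            epochs.append(int(match.group(1)))
    return sorted(epochs)


# Demo on a temporary directory that mimics the layout described above
with tempfile.TemporaryDirectory() as out_dir:
    os.makedirs(os.path.join(out_dir, "model"))
    for name in ["model.pth.tar-10", "model.pth.tar-20", "checkpoint"]:
        open(os.path.join(out_dir, "model", name), "w").close()
    demo_epochs = available_epochs(out_dir)

print(demo_epochs)  # [10, 20]
# In practice: available_epochs("output/source_only_office31")
```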

Write a new trainer

A good practice is to go through dassl/engine/trainer.py to get familiar with the base trainer classes, which provide generic functions and training loops. To write a trainer class for domain adaptation or semi-supervised learning, the new class can subclass TrainerXU. For domain generalization, the new class can subclass TrainerX. TrainerXU and TrainerX differ mainly in whether a data loader for unlabeled data is used. With the base classes, a new trainer may only need to implement the forward_backward() method, which performs loss computation and model update. See dassl/engine/da/source_only.py for an example.
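The division of labor can be pictured with a dependency-free toy: the base class owns the loop, and a subclass only fills in forward_backward(). The class names ToyTrainerX, ToyTrainerXU, and MyDATrainer are hypothetical and far simpler than Dassl's TrainerX/TrainerXU, which also manage models, optimizers, schedulers, logging, and evaluation.

```python
class ToyTrainerX:
    """Toy base trainer: owns the epoch loop over labeled batches."""

    def __init__(self, loader_x, max_epoch):
        self.loader_x = loader_x
        self.max_epoch = max_epoch
        self.history = []  # one loss summary per iteration

    def run_epoch(self):
        for batch_x in self.loader_x:
            self.history.append(self.forward_backward(batch_x))

    def train(self):
        for epoch in range(self.max_epoch):
            self.epoch = epoch
            self.run_epoch()

    def forward_backward(self, batch_x):
        # Subclasses compute the loss and update the model here
        raise NotImplementedError


class ToyTrainerXU(ToyTrainerX):
    """Toy variant that also iterates over an unlabeled loader."""

    def __init__(self, loader_x, loader_u, max_epoch):
        super().__init__(loader_x, max_epoch)
        self.loader_u = loader_u

    def run_epoch(self):
        # zip() stops at the shorter loader; other iteration policies are possible
        for batch_x, batch_u in zip(self.loader_x, self.loader_u):
            self.history.append(self.forward_backward(batch_x, batch_u))

    def forward_backward(self, batch_x, batch_u):
        raise NotImplementedError


class MyDATrainer(ToyTrainerXU):
    """A new method only has to provide forward_backward()."""

    def forward_backward(self, batch_x, batch_u):
        # Placeholder "loss"; a real trainer would run the model,
        # compute the loss, and step the optimizer
        loss = sum(batch_x) + 0.1 * sum(batch_u)
        return {"loss": loss}


trainer = MyDATrainer(loader_x=[[1, 2], [3, 4]], loader_u=[[5], [6]], max_epoch=2)
trainer.train()
print(len(trainer.history))  # 4
```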

Add a new backbone/head/network

backbone corresponds to a convolutional neural network model that performs feature extraction. head (an optional module) is mounted on top of backbone for further processing; it can be, for example, an MLP. backbone and head are the basic building blocks for constructing a SimpleNet() (see dassl/engine/trainer.py), which serves as the primary model for a task. network contains custom neural network models, such as an image generator.

To add a new module, namely a backbone/head/network, you first need to register it with the corresponding registry, i.e. BACKBONE_REGISTRY for backbone, HEAD_REGISTRY for head, and NETWORK_REGISTRY for network. Note that a new backbone must subclass Backbone as defined in dassl/modeling/backbone/backbone.py and set the self._out_features attribute.

We provide an example below for how to add a new backbone.

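import torch.nn as nn

from dassl.modeling import Backbone, BACKBONE_REGISTRY


class MyBackbone(Backbone):

    def __init__(self):
        super().__init__()
        # Create layers (placeholder architecture for illustration)
        self.conv = nn.Conv2d(3, 2048, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)

        # Dimension of the feature vector returned by forward()
        self._out_features = 2048

    def forward(self, x):
        # Extract and return features: (N, 3, H, W) -> (N, 2048)
        x = self.conv(x)
        x = self.pool(x)
        return x.flatten(1)


@BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return MyBackbone()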

Then, you can set MODEL.BACKBONE.NAME to my_backbone to use your own architecture. For more details, please refer to the source code in dassl/modeling.
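Conceptually, the string in MODEL.BACKBONE.NAME is used as a key into the registry to find the function that builds the model. The following is a from-scratch sketch of that idea, not Dassl's own Registry class; the names Registry, TOY_BACKBONE_REGISTRY, and get() are illustrative. It also mirrors the register(force=True) option from the v0.3.1 changelog entry, which replaces a previously registered module instead of raising an error.

```python
class Registry:
    """Map names to callables so they can be looked up from a config string."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj, force=False):
        if name in self._obj_map and not force:
            raise KeyError(f'"{name}" is already registered in "{self._name}"')
        self._obj_map[name] = obj

    def register(self, obj=None, force=False):
        if obj is None:
            # Decorator form: @REGISTRY.register() or @REGISTRY.register(force=True)
            def deco(func_or_class):
                self._do_register(func_or_class.__name__, func_or_class, force=force)
                return func_or_class
            return deco
        # Function-call form: REGISTRY.register(obj)
        self._do_register(obj.__name__, obj, force=force)
        return obj

    def get(self, name):
        if name not in self._obj_map:
            raise KeyError(f'"{name}" not found in "{self._name}"')
        return self._obj_map[name]


TOY_BACKBONE_REGISTRY = Registry("BACKBONE")


@TOY_BACKBONE_REGISTRY.register()
def my_backbone(**kwargs):
    return "v1"


@TOY_BACKBONE_REGISTRY.register(force=True)  # replaces the earlier entry
def my_backbone(**kwargs):  # noqa: F811
    return "v2"


build_fn = TOY_BACKBONE_REGISTRY.get("my_backbone")  # name as it would come from the config
print(build_fn())  # v2
```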

Citation

If you find this code useful to your research, please give credit to the following paper

@article{zhou2020domain,
  title={Domain Adaptive Ensemble Learning},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  journal={arXiv preprint arXiv:2003.07325},
  year={2020}
}
