domainbed's Introduction

MixStyle setup

Quick start

Download the PACS dataset:

python3 -m domainbed.scripts.download \
       --data_dir=./domainbed/data

Train a model:

CUDA_VISIBLE_DEVICES=1 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=MixStyle --dataset=PACS --test_env 0

Evaluate a model:

CUDA_VISIBLE_DEVICES=7 python3 -m domainbed.scripts.eval_model --data_dir=./domainbed/data --algorithm=CLIP_MixStyle --dataset=PACS --test_env 0 --load ./train_output/model_clip_mixstyle_2.pkl

Train a MixStyle model with test environment 2:

python3 -m domainbed.scripts.train\
       --data_dir=./domainbed/data/\
       --algorithm MixStyle\
       --dataset PACS\
       --test_env 2

New additions

  • domainbed/tb_reduce.py
  • domainbed/scripts/eval_model.py # copy of train.py

Jan 18 runs

Training

   - ERM
   - MixStyle
   - CLIP_ERM (mixup=False)
   - CLIP_MixStyle (mixup=False)

Examples

CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=ERM --dataset=PACS --test_env 0 # omit --save_tb to skip saving TensorBoard logs
CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=ERM --dataset=PACS --test_env 0 --save_tb # pass --save_tb to save TensorBoard logs

Training commands

CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=ERM --dataset=PACS --test_env 0 --save_tb 
CUDA_VISIBLE_DEVICES=4 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=MixStyle --dataset=PACS --test_env 0 --save_tb 

CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=CLIP_ERM --dataset=PACS --test_env 0 --save_tb 
CUDA_VISIBLE_DEVICES=4 python3 -m domainbed.scripts.train --data_dir=./domainbed/data --algorithm=CLIP_MixStyle --dataset=PACS --test_env 0 --save_tb 


Generating t-SNE plots

CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.eval_model --data_dir=./domainbed/data --algorithm=ERM --dataset=PACS --test_env 0 --load ./train_output/model_ERM.pkl
CUDA_VISIBLE_DEVICES=4 python3 -m domainbed.scripts.eval_model --data_dir=./domainbed/data --algorithm=MixStyle --dataset=PACS --test_env 0 --load ./train_output/model_MixStyle.pkl

CUDA_VISIBLE_DEVICES=5 python3 -m domainbed.scripts.eval_model --data_dir=./domainbed/data --algorithm=CLIP_ERM --dataset=PACS --test_env 0 --load ./train_output/model_CLIP_ERM.pkl
CUDA_VISIBLE_DEVICES=4 python3 -m domainbed.scripts.eval_model --data_dir=./domainbed/data --algorithm=CLIP_MixStyle --dataset=PACS --test_env 0 --load ./train_output/model_CLIP_MixStyle.pkl
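
The train and eval commands above differ only in the algorithm name and the checkpoint path, so they can be generated rather than typed by hand. The following is a minimal sketch, not part of the repository: `build_commands` is a hypothetical helper, and it assumes checkpoints follow the `model_<algorithm>.pkl` pattern seen in the commands above.

```python
import shlex

# Algorithm names and flags are copied from the commands listed above.
ALGORITHMS = ["ERM", "MixStyle", "CLIP_ERM", "CLIP_MixStyle"]


def build_commands(algorithm, data_dir="./domainbed/data", dataset="PACS", test_env=0):
    """Return (train_cmd, eval_cmd) command strings for one algorithm."""
    common = [
        f"--data_dir={data_dir}",
        f"--algorithm={algorithm}",
        f"--dataset={dataset}",
        "--test_env", str(test_env),
    ]
    train = ["python3", "-m", "domainbed.scripts.train", *common, "--save_tb"]
    # Assumption: checkpoints are named model_<algorithm>.pkl under ./train_output.
    checkpoint = f"./train_output/model_{algorithm}.pkl"
    evaluate = ["python3", "-m", "domainbed.scripts.eval_model", *common, "--load", checkpoint]
    return shlex.join(train), shlex.join(evaluate)


if __name__ == "__main__":
    for name in ALGORITHMS:
        train_cmd, eval_cmd = build_commands(name)
        print(train_cmd)
        print(eval_cmd)
```

Prefix each printed command with `CUDA_VISIBLE_DEVICES=<id>` to pin it to a GPU, as in the commands above.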

DomainBed (adapted for OoD-Bench)

This is a fork of the DomainBed test suite. To benchmark the algorithms in OoD-Bench, it adds the following:

  • six new datasets: ColoredMNIST_IRM, CelebA_Blond, NICO_Mixed, ImageNet_A, ImageNet_R, ImageNet_V2;
  • a new algorithm: Domain Generalization via Entropy Regularization (ERDG, Zhao et al., 2020);
  • a new model selection method: OODValidationSelectionMethod, which can be triggered by setting the arguments --fixed_val_envs and --fixed_test_envs of domainbed.scripts.sweep;
  • a new network architecture: MNIST_MLP for ColoredMNIST_IRM, adapted from the IRM implementation;
  • a new data augmentation scheme (slightly different from DomainBed's default), adapted from JigenDG, which can be activated by adding "data_augmentation_scheme": "jigen" to --hparams;
  • an option to unfreeze the batch normalization of ResNets, which can be activated by adding "freeze_bn": false to --hparams.
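
The two --hparams entries mentioned in the list above can be combined into a single argument. The sketch below assumes that --hparams accepts a JSON-serialized dictionary as a single string; it only builds and prints the argument and does not call DomainBed.

```python
import json
import shlex

# Keys and values are the ones quoted in the list above.
hparams = {
    "data_augmentation_scheme": "jigen",
    "freeze_bn": False,  # serialized as JSON `false`
}

hparams_json = json.dumps(hparams)
# shlex.quote protects the braces and double quotes from the shell.
hparams_arg = "--hparams " + shlex.quote(hparams_json)

if __name__ == "__main__":
    print(hparams_arg)
```

Building the string with `json.dumps` avoids hand-written quoting mistakes, such as writing Python's `False` where JSON expects `false`.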

Benchmarking

The launching scripts of all the benchmarking experiments are provided here. Example usage:

dataset="ColoredMNIST_IRM"
command="launch"  # or "delete_incomplete"
launcher="local"  # or "multi_gpu"
data_dir="/path/to/data"
sh sweep/$dataset/run.sh $command $launcher $data_dir

To launch your own experiments and for more usages, please refer to the DomainBed documentation below.

Welcome to DomainBed

DomainBed is a PyTorch suite containing benchmark datasets and algorithms for domain generalization, as introduced in In Search of Lost Domain Generalization.

Available algorithms

The currently available algorithms are:

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks (He et al., 2015) and the hyper-parameter grids described here.

Available datasets

The currently available datasets are:

Send us a PR to add your dataset! Any custom image dataset with folder structure dataset/domain/class/image.xyz is readily usable. While we include some datasets from the WILDS project, please use their official code if you wish to participate in their leaderboard.
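
Before pointing DomainBed at a custom dataset, it can help to check that the folder follows the dataset/domain/class/image.xyz layout. The following is a minimal sketch using only the standard library; `summarize_dataset` and the domain and class names are hypothetical and are not part of DomainBed.

```python
import tempfile
from pathlib import Path


def summarize_dataset(root):
    """Map each domain folder name to the sorted list of its class folder names."""
    root = Path(root)
    summary = {}
    for domain in sorted(p for p in root.iterdir() if p.is_dir()):
        summary[domain.name] = sorted(c.name for c in domain.iterdir() if c.is_dir())
    return summary


if __name__ == "__main__":
    # Build a tiny placeholder tree (hypothetical names) to exercise the function.
    with tempfile.TemporaryDirectory() as tmp:
        for domain, cls in [("photo", "dog"), ("photo", "cat"), ("sketch", "dog")]:
            class_dir = Path(tmp) / "my_dataset" / domain / cls
            class_dir.mkdir(parents=True)
            (class_dir / "image.xyz").touch()
        print(summarize_dataset(Path(tmp) / "my_dataset"))
```

A domain whose class list differs from the others usually points to a misplaced folder.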

Available model selection criteria

Model selection criteria differ in what data is used to choose the best hyper-parameters for a given model:

  • IIDAccuracySelectionMethod: A random subset from the data of the training domains.
  • LeaveOneOutSelectionMethod: A random subset from the data of a held-out (not training, not testing) domain.
  • OracleSelectionMethod: A random subset from the data of the test domain.
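
The difference between the three criteria comes down to which validation accuracy ranks the candidate runs. The sketch below illustrates that idea only: the record layout, the field names, and the numbers are made-up placeholders, and this is not DomainBed's implementation of these classes.

```python
# Hypothetical records: one per hyper-parameter choice, with placeholder accuracies
# measured on each kind of validation subset described above.
runs = [
    {"name": "run_a", "train_val": 0.95, "held_out_val": 0.70, "test_val": 0.68},
    {"name": "run_b", "train_val": 0.90, "held_out_val": 0.78, "test_val": 0.72},
    {"name": "run_c", "train_val": 0.88, "held_out_val": 0.74, "test_val": 0.80},
]

# Which validation subset each criterion looks at (field names are assumptions).
CRITERIA = {
    "IIDAccuracySelectionMethod": "train_val",      # subset of the training domains
    "LeaveOneOutSelectionMethod": "held_out_val",   # subset of a held-out domain
    "OracleSelectionMethod": "test_val",            # subset of the test domain
}


def select_run(candidates, method):
    """Return the run with the highest accuracy on the subset the method uses."""
    key = CRITERIA[method]
    return max(candidates, key=lambda run: run[key])


if __name__ == "__main__":
    for method in CRITERIA:
        print(method, "->", select_run(runs, method)["name"])
```

With these placeholder numbers each criterion picks a different run, which is why reported results depend on the selection method.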

Quick start

Download the datasets:

python3 -m domainbed.scripts.download \
       --data_dir=./domainbed/data

Train a model:

python3 -m domainbed.scripts.train\
       --data_dir=./domainbed/data/MNIST/\
       --algorithm IGA\
       --dataset ColoredMNIST\
       --test_env 2

Launch a sweep:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/datasets/path\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher

Here, MyLauncher is your cluster's command launcher, as implemented in command_launchers.py. At the time of writing, the entire sweep trains tens of thousands of models (all algorithms x all datasets x 3 independent trials x 20 random hyper-parameter choices). You can pass arguments to make the sweep smaller:

python -m domainbed.scripts.sweep launch\
       --data_dir=/my/datasets/path\
       --output_dir=/my/sweep/output/path\
       --command_launcher MyLauncher\
       --algorithms ERM DANN\
       --datasets RotatedMNIST VLCS\
       --n_hparams 5\
       --n_trials 1
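
To see how much the arguments above shrink the sweep, the product quoted earlier (algorithms x datasets x trials x hyper-parameter choices) can be evaluated directly. This is a rough sketch based only on that formula; the actual number of jobs may differ, for instance if the script launches one job per test environment.

```python
def sweep_size(n_algorithms, n_datasets, n_trials, n_hparams):
    """Number of models according to the product described in the text."""
    return n_algorithms * n_datasets * n_trials * n_hparams


# Reduced sweep from the command above: ERM and DANN on RotatedMNIST and VLCS,
# with 5 hyper-parameter choices and 1 trial.
reduced = sweep_size(n_algorithms=2, n_datasets=2, n_trials=1, n_hparams=5)

# Per (algorithm, dataset) pair: default 3 trials x 20 choices vs. 1 trial x 5 choices.
default_per_pair = sweep_size(1, 1, n_trials=3, n_hparams=20)
reduced_per_pair = sweep_size(1, 1, n_trials=1, n_hparams=5)

if __name__ == "__main__":
    print(reduced, default_per_pair, reduced_per_pair)
```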

After all jobs have either succeeded or failed, you can delete the data from failed jobs with python -m domainbed.scripts.sweep delete_incomplete and then re-launch them by running python -m domainbed.scripts.sweep launch again. Specify the same command-line arguments in all calls to sweep as you did the first time; this is how the sweep script knows which jobs were launched originally.

To view the results of your sweep:

python -m domainbed.scripts.collect_results\
       --input_dir=/my/sweep/output/path

Running unit tests

DomainBed includes some unit tests and end-to-end tests. While not exhaustive, they are a good sanity check. To run the tests:

python -m unittest discover

By default, this only runs tests which don't depend on a dataset directory. To run those tests as well:

DATA_DIR=/my/datasets/path python -m unittest discover
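
One common way to gate tests on an environment variable such as DATA_DIR is `unittest.skipIf`. The sketch below shows that general pattern with a hypothetical test case; it is an assumption about how such gating can be written, not a copy of DomainBed's test code.

```python
import os
import unittest

# Read once at import time; None means no dataset directory was provided.
DATA_DIR = os.environ.get("DATA_DIR")


class ExampleDatasetTest(unittest.TestCase):
    """Hypothetical test case illustrating DATA_DIR gating."""

    def test_without_data(self):
        # Runs unconditionally because it needs no files on disk.
        self.assertEqual(len("PACS"), 4)

    @unittest.skipIf(DATA_DIR is None, "needs DATA_DIR environment variable set")
    def test_with_data(self):
        # Only runs when DATA_DIR is set.
        self.assertTrue(os.path.isdir(DATA_DIR))


def run_example_tests():
    """Run the example test case and return the unittest result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleDatasetTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)


if __name__ == "__main__":
    run_example_tests()
```

Without DATA_DIR, the second test is reported as skipped rather than failed, so the default run stays green on machines without the datasets.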

License

This source code is released under the MIT license, included here.

Contributors

lopezpaz, igul222, jc-audet, sahilkhose, shahtalebi, sirrob1997, yugeten, m-just, zdhnarsil, alexrame, dnap512, ranliu98, ashok-arjun, daysm, jungwon-choi, accumulated, ryoungj
