AdaMatch

Code for the paper: "AdaMatch: A Unified Approach to Semi-Supervised Learning and Domain Adaptation" by David Berthelot, Rebecca Roelofs, Kihyuk Sohn, Nicholas Carlini, and Alex Kurakin.

This is not an officially supported Google product.

[Figure: AdaMatch diagram]

Setup

sudo apt install python3-dev python3-virtualenv python3-tk imagemagick
virtualenv -p python3 --system-site-packages ~/jax3
. ~/jax3/bin/activate

# Install dependencies (replace with your installed CUDA version)
CUDA_VERSION=11.2
pip install --upgrade -r requirements.txt
pip install -f https://storage.googleapis.com/jax-releases/jax_releases.html jaxlib==`python3 -c 'import jaxlib; print(jaxlib.__version__)'`+cuda`echo $CUDA_VERSION | sed s:\\\.::g`
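The last command pins jaxlib to whichever version is already installed and appends a CUDA suffix derived from CUDA_VERSION with the dot removed. As a minimal sketch of that string construction in Python (the helper name `jaxlib_requirement` is hypothetical, and the jaxlib version used below is a placeholder, not a recommendation):

```python
def jaxlib_requirement(jaxlib_version: str, cuda_version: str) -> str:
    """Build the pip requirement string that the setup command produces.

    Hypothetical helper for illustration only. It mirrors the sed step
    in the shell command, which strips the dots from CUDA_VERSION.
    """
    cuda_suffix = cuda_version.replace(".", "")
    return f"jaxlib=={jaxlib_version}+cuda{cuda_suffix}"


if __name__ == "__main__":
    # Placeholder jaxlib version; CUDA 11.2 becomes the suffix "cuda112".
    print(jaxlib_requirement("0.1.64", "11.2"))
```

If the suffix does not match a wheel published for your CUDA version, pip will fail to find the package, so it is worth checking the printed string against the releases page first.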

Required environment variables

export ML_DATA="path to where you want the datasets saved"
export PYTHONPATH=$PYTHONPATH:.

Potentially useful environment variables

# Use this setting if you don't want JAX to preallocate the whole GPU memory
export XLA_PYTHON_CLIENT_PREALLOCATE=false
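The same setting can also be applied from inside a Python script, as long as it happens before JAX starts up. The sketch below assumes a hypothetical helper name (`disable_gpu_preallocation`) and deliberately does not import JAX, so it runs on any machine:

```python
import os


def disable_gpu_preallocation() -> None:
    """Ask XLA not to preallocate GPU memory (hypothetical helper name)."""
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


# Call this before `import jax` so the setting is in place when JAX starts up.
disable_gpu_preallocation()
print(os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"])
```

Exporting the variable in the shell, as shown above, remains the simpler option when you launch the training scripts directly.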

Data preparation

# Download datasets and save them as TF records
CUDA_VISIBLE_DEVICES= python scripts/create_datasets.py

# Alternatively, if your machine has many cores, this parallel version runs much faster.
bash ./runs/create_datasets.sh

Training

Options

  • --logdir directory where the results are saved.
  • --uratio ratio of unlabeled data to labeled data.
  • --augment usually takes a pair (weak,strong) of augmentations, optionally preceded by the prefix CTA (Control Theory Augment, introduced in ReMixMatch). Examples:
    • --augment=\(sm,smc\) weak augmentation is shift and mirror; strong augmentation is shift, mirror and cutout.
    • --augment=CTA\(sm,sm\) weak augmentation is shift and mirror; strong augmentation is CTA on top of shift and mirror.
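To make the shape of the --augment value concrete, here is a minimal Python sketch that splits it into prefix, weak and strong parts. It is not the repository's parser: `AugmentSpec`, `parse_augment` and the regular expression are hypothetical names introduced for illustration. Note that the backslashes in the examples above are only shell escaping, so the script receives `CTA(sm,sm)`.

```python
import re
from typing import NamedTuple


class AugmentSpec(NamedTuple):
    prefix: str  # "CTA", or "" when no prefix is given
    weak: str    # e.g. "sm" (shift and mirror)
    strong: str  # e.g. "smc" (shift, mirror and cutout)


# Assumed grammar: an optional alphabetic prefix followed by "(weak,strong)".
_AUGMENT_RE = re.compile(r"^(?P<prefix>[A-Za-z]*)\((?P<weak>\w+),(?P<strong>\w+)\)$")


def parse_augment(value: str) -> AugmentSpec:
    """Split an --augment value into its three parts (illustrative only)."""
    match = _AUGMENT_RE.match(value)
    if match is None:
        raise ValueError(f"Unrecognized --augment value: {value!r}")
    return AugmentSpec(match["prefix"], match["weak"], match["strong"])


if __name__ == "__main__":
    print(parse_augment("(sm,smc)"))
    print(parse_augment("CTA(sm,sm)"))
```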

Specific to Domain Adaptation (DA) and Semi-Supervised Domain Adaptation (SSDA)

  • --dataset dataset to use with desired size, example: domainnet32 (32x32 images)
  • --source source subset (for example clipart)
  • --target target subset to which to adapt the source (for example quickdraw)

Note: for SSDA, the target takes a different form.

  • --target target subset to which to adapt the source, together with the number of labels per class and the random seed to use (for example, quickdraw(3,seed=2) uses 3 labels per class, picked at random with seed 2)
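The two forms of --target can be told apart mechanically. The following sketch assumes the grammar implied by the examples above (`quickdraw` for DA, `quickdraw(3,seed=2)` for SSDA); `TargetSpec` and `parse_target` are hypothetical names, and the repository may parse the flag differently.

```python
import re
from typing import NamedTuple, Optional


class TargetSpec(NamedTuple):
    subset: str
    labels_per_class: Optional[int]  # None for plain domain adaptation
    seed: Optional[int]              # None for plain domain adaptation


# Assumed grammar: "subset" or "subset(labels,seed=N)".
_TARGET_RE = re.compile(r"^(?P<subset>\w+)(?:\((?P<labels>\d+),seed=(?P<seed>\d+)\))?$")


def parse_target(value: str) -> TargetSpec:
    """Parse a --target value in either the DA or the SSDA form (illustrative only)."""
    match = _TARGET_RE.match(value)
    if match is None:
        raise ValueError(f"Unrecognized --target value: {value!r}")
    if match["labels"] is None:
        return TargetSpec(match["subset"], None, None)
    return TargetSpec(match["subset"], int(match["labels"]), int(match["seed"]))


if __name__ == "__main__":
    print(parse_target("quickdraw"))
    print(parse_target("quickdraw(3,seed=2)"))
```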

Specific to Semi-Supervised Learning (SSL)

  • --dataset combines both dataset and source from SSDA in a single option. For example: domainnet32_quickdraw(3,seed=2).

Fully supervised

# Baseline: source and target must be the same. Additionally, one can specify extra test sets.
python fully_supervised/baseline.py --dataset=domainnet32 --source=clipart --target=clipart\
    --logdir experiments/2021/02.12-32 --augment=CTA\(sm,sm\)\
    --test_extra=clipart,infograph,quickdraw,real,sketch,painting

Domain adaptation

# Baseline: does nothing with the unlabeled data except pass it through the network together with the labeled data as a single batch.
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/baseline.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# FixMatch
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# AdaMatch
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
python domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=1 --augment=\(sm,smc\)

# NoisyStudent
## Teacher
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\)\
    --id=0
## Student
python domain_adaptation/noisy_student.py --dataset=domainnet32 --source=clipart --target=quickdraw\
    --logdir experiments/2021/02.12-32 --pseudo_label_th=0.9 --augment=CTA\(sm,sm\) --id=1\
    --pseudo_label_file=experiments/2021/03.1-32/DA/domainnet32/clipart/quickdraw/CTA\(sm,sm\)/NoisyStudent/archwrn28-2_batch64_lr0.03_lr_decay0.25_wd0.001/0/predictions.npy

# MCD
python domain_adaptation/mcd.py --dataset=domainnet64 --source=clipart --target=quickdraw\
    --uratio=1 --arch=wrn28-2 --augment='CTA(sm,sm)' --train_mimg=8 --logdir experiments/2021/02.12-32 --lr_decay 0.25
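The NoisyStudent commands above pass --pseudo_label_th=0.9 and, for the student, a predictions.npy file produced by the teacher. This README does not describe the file's contents or exactly how the threshold is applied, so the following is only a generic sketch of confidence thresholding. It assumes each prediction is a list of class probabilities, and `pseudo_labels` is a hypothetical helper name.

```python
from typing import List, Optional


def pseudo_labels(probabilities: List[List[float]],
                  threshold: float = 0.9) -> List[Optional[int]]:
    """Keep the most likely class when its probability reaches the threshold.

    Illustrative only: returns None for examples whose top probability is
    below the threshold, meaning no pseudo-label is assigned.
    """
    labels: List[Optional[int]] = []
    for row in probabilities:
        best = max(range(len(row)), key=row.__getitem__)
        labels.append(best if row[best] >= threshold else None)
    return labels


if __name__ == "__main__":
    teacher_outputs = [[0.95, 0.05], [0.60, 0.40], [0.05, 0.95]]
    print(pseudo_labels(teacher_outputs))
```

A higher threshold keeps fewer but more reliable pseudo-labels; a lower one keeps more examples at the cost of more label noise.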

Multi-source domain adaptation

Prefix a subdomain with no_ to use every domain except that one as the source (for example, no_quickdraw).

python domain_adaptation/adamatch.py --dataset=domainnet32 --source=no_quickdraw --target=quickdraw\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)
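The effect of the no_ prefix can be sketched as a simple set difference over the available subsets. The subset names below are taken from the --test_extra example in the fully supervised section; `expand_source` is a hypothetical helper, not a function from this repository.

```python
from typing import List, Sequence

# Subset names as they appear in the --test_extra example of this README.
DOMAINS = ("clipart", "infograph", "quickdraw", "real", "sketch", "painting")


def expand_source(source: str, domains: Sequence[str] = DOMAINS) -> List[str]:
    """Expand a --source value into the list of subsets it refers to (illustrative only)."""
    if source.startswith("no_"):
        excluded = source[len("no_"):]
        if excluded not in domains:
            raise ValueError(f"Unknown subset: {excluded!r}")
        return [name for name in domains if name != excluded]
    return [source]


if __name__ == "__main__":
    print(expand_source("no_quickdraw"))
    print(expand_source("clipart"))
```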

Semi-Supervised Domain adaptation

# FixMatch
python semi_supervised_domain_adaptation/fixmatch_da.py --dataset=domainnet32 --source=clipart\
    --target=quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

# AdaMatch
python semi_supervised_domain_adaptation/adamatch.py --dataset=domainnet32 --source=clipart\
    --target=quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

Semi-Supervised Learning

# FixMatch
python semi_supervised/fixmatch_da.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

# AdaMatch
python semi_supervised/adamatch.py --dataset=domainnet32_quickdraw\(10,seed=1\)\
    --logdir experiments/2021/02.12-32 --uratio=3 --augment=CTA\(sm,sm\)

Tensorboard

tensorboard --logdir experiments

adamatch's Issues

Security Policy violation Binary Artifacts

This issue was automatically created by Allstar.

Security Policy Violation
Project is out of compliance with Binary Artifacts policy: binaries present in source code

Rule Description
Binary Artifacts are an increased security risk in your repository. Binary artifacts cannot be reviewed, allowing the introduction of possibly obsolete or maliciously subverted executables. For more information see the Security Scorecards Documentation for Binary Artifacts.

Remediation Steps
To remediate, remove the generated executable artifacts from the repository.

First 10 Artifacts Found

  • fully_supervised/__pycache__/__init__.cpython-38.pyc
  • fully_supervised/libml/__pycache__/__init__.cpython-38.pyc
  • fully_supervised/libml/__pycache__/models.cpython-38.pyc
  • fully_supervised/libml/__pycache__/train.cpython-38.pyc
  • fully_supervised/libml/__pycache__/util.cpython-38.pyc
  • fully_supervised/libml/augment/__pycache__/__init__.cpython-38.pyc
  • fully_supervised/libml/augment/__pycache__/augment.cpython-38.pyc
  • fully_supervised/libml/augment/__pycache__/core.cpython-38.pyc
  • fully_supervised/libml/augment/__pycache__/ctaugment.cpython-38.pyc
  • fully_supervised/libml/data/__pycache__/__init__.cpython-38.pyc
  • Run a Scorecards scan to see full list.

Additional Information
This policy is drawn from Security Scorecards, which is a tool that scores a project's adherence to security best practices. You may wish to run a Scorecards scan directly on this repository for more details.


Allstar has been installed on all Google managed GitHub orgs. Policies are gradually being rolled out and enforced by the GOSST and OSPO teams. Learn more at http://go/allstar

This issue will auto resolve when the policy is in compliance.

Issue created by Allstar. See https://github.com/ossf/allstar/ for more information. For questions specific to the repository, please contact the owner or maintainer.

Why are your settings not in line with the standard SSL, UDA and SSDA benchmarks?

Thank you for your excellent work. I'm curious why you didn't keep the same experimental settings as previous work. For example, in SSL the commonly used datasets are CIFAR-10 and CIFAR-100, as well as SVHN and STL-10. In UDA, the commonly used ones are Office-Home and Office-31. In SSDA, they are DomainNet and Office-Home. If you change the settings, it is difficult to compare fairly with previous work.
