FixMatch-pytorch

Unofficial pytorch code for "FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence," NeurIPS'20.
This implementation reproduces the results reported in the paper on CIFAR10 and CIFAR100.
It also provides models trained in both semi-supervised and fully supervised settings (download links below).

Requirements

  • python 3.6
  • pytorch 1.6.0
  • torchvision 0.7.0
  • tensorboard 2.3.0
  • pillow

Results: Classification Accuracy (%)

In addition to the semi-supervised results reported in the paper, we also include results for fully supervised learning (50000 labels, sup only) and for fully supervised learning with consistency regularization (50000 labels, sup + consistency).
Consistency regularization improves classification accuracy even when all labels are provided.
Evaluation uses an exponential moving average (EMA) of the model weights along the SGD training trajectory.
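The EMA evaluation described above can be illustrated with a minimal, framework-free sketch. It stores weights in plain dicts of floats instead of tensors; the helper name `ema_update` and the decay values are illustrative assumptions, not this repository's code or its defaults.

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Update EMA weights in place: ema = decay * ema + (1 - decay) * current.

    Both arguments map parameter names to floats (a stand-in for tensors).
    """
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


if __name__ == "__main__":
    # Start the EMA copy from the initial model weights.
    model = {"w": 0.0}
    ema = dict(model)
    # Simulate SGD steps that move the raw weight to 1.0 right away;
    # the EMA copy follows it gradually.
    for step in range(3):
        model["w"] = 1.0
        ema_update(ema, model, decay=0.5)
    print(ema["w"])  # 0.875
```

In a PyTorch model the same update would be applied to each parameter tensor (typically under `torch.no_grad()`), and evaluation then runs on the EMA copy rather than on the raw weights.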

CIFAR10

| #Labels | 40 | 250 | 4000 | 50000 (sup + consistency) | 50000 (sup only) |
|---|---|---|---|---|---|
| Paper (RA) | 86.19 ± 3.37 | 94.93 ± 0.65 | 95.74 ± 0.05 | - | - |
| kekmodel | - | - | 94.72 | - | - |
| valencebond | 89.63 (85.65) | 93.08 | 94.72 | - | - |
| Ours | 87.11 | 94.61 | 95.62 | 96.86 | 94.98 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

CIFAR100

| #Labels | 400 | 2500 | 10000 | 50000 (sup + consistency) | 50000 (sup only) |
|---|---|---|---|---|---|
| Paper (RA) | 51.15 ± 1.75 | 71.71 ± 0.11 | 77.40 ± 0.12 | - | - |
| kekmodel | - | - | - | - | - |
| valencebond | 53.74 | 67.3169 | 73.26 | - | - |
| Ours | 48.96 | 71.50 | 78.27 | 83.86 | 80.57 |
| Trained Models | checkpoint | checkpoint | checkpoint | checkpoint | checkpoint |

For CIFAR100 with 400 labels, our result does not reach the paper's result and falls outside its confidence interval.
Note, however, that accuracy with very few labels depends heavily on which samples are labeled and on other hyperparameters.
For example, we found that changing the batch normalization momentum can give better results, close to the reported accuracies.
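For context, the BN momentum enters the running-statistic update as in the sketch below, which follows PyTorch's convention: momentum is the weight on the new batch statistic (the default in `torch.nn.BatchNorm2d` is 0.1). The helper name is hypothetical, and other frameworks define momentum as the weight on the old value instead.

```python
def update_running_mean(running_mean, batch_mean, momentum=0.1):
    """PyTorch-style BN running-statistic update.

    momentum weights the *new* batch statistic; 1 - momentum weights the old value.
    """
    return (1.0 - momentum) * running_mean + momentum * batch_mean


if __name__ == "__main__":
    # Feed a constant batch mean of 1.0 for 10 steps and compare momenta.
    for momentum in (0.1, 0.001):
        running = 0.0
        for _ in range(10):
            running = update_running_mean(running, 1.0, momentum)
        print(momentum, round(running, 4))
```

A smaller PyTorch momentum makes the running statistics, which are used at evaluation time, change more slowly.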

Evaluation of Checkpoints

Download Checkpoints

Here we provide Google Drive links that include the training logs and the trained models.
Because of Google Drive's security restrictions, downloading the checkpoints in the result tables with curl/wget may fail.
In that case, use gdown instead.
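If you prefer to script the downloads, the sketch below turns a Google Drive share link into the `https://drive.google.com/uc?id=<FILE_ID>` form that gdown accepts and then calls `gdown.download`. The share-link formats handled and the helper names are assumptions for illustration; substitute the links from the result tables.

```python
import re

# Share links usually look like https://drive.google.com/file/d/<ID>/view?usp=sharing
# or https://drive.google.com/open?id=<ID>.
_ID_PATTERNS = (
    re.compile(r"/file/d/([A-Za-z0-9_-]+)"),
    re.compile(r"[?&]id=([A-Za-z0-9_-]+)"),
)


def drive_file_id(share_url):
    """Extract the file ID from a Google Drive share link."""
    for pattern in _ID_PATTERNS:
        match = pattern.search(share_url)
        if match:
            return match.group(1)
    raise ValueError(f"no Google Drive file ID found in {share_url!r}")


def gdown_url(share_url):
    """Build the direct-download URL form that gdown accepts."""
    return "https://drive.google.com/uc?id=" + drive_file_id(share_url)


def download_checkpoint(share_url, output_path):
    """Download one checkpoint archive with gdown (pip install gdown)."""
    import gdown  # imported lazily so the helpers above work without gdown

    return gdown.download(gdown_url(share_url), output_path, quiet=False)
```

gdown also ships a command-line tool, so `gdown "https://drive.google.com/uc?id=<FILE_ID>"` is an alternative to the Python call.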

All checkpoints are included in this directory.

Evaluation Example

After unzipping the checkpoints into your own path, run:

python eval.py --load_path saved_models/cifar10_400/model_best.pth --dataset cifar10 --num_classes 10

How to Train

Important Notes

For detailed explanations of the arguments, see here.

  • During training, the model is saved to os.path.join(args.save_dir, args.save_name), which is created as a new directory. If the path already exists, the code raises an error to prevent trained models from being overwritten by mistake. To overwrite the files, pass --overwrite.
  • By default, FixMatch uses hard (one-hot) pseudo labels. To use soft pseudo labels with sharpening, pass --hard_label False and set the sharpening temperature with --T [YOUR_OWN_VALUE].
  • The code treats the whole run as a single epoch of 2**20 iterations.
  • To resume training, use --resume --load_path [YOUR_CHECKPOINT_PATH]. The checkpoint is loaded into the model, and training continues from the iteration where it stopped (see here and the related method).
  • The number of DataLoader workers is set for distributed training on a single node with 4 V100 GPUs.
  • To change the confidence threshold used to generate masks for consistency regularization, set --p_cutoff.
  • With 4 GPUs, BN running statistics are not gathered across processes during distributed training, which keeps updates fast. However, using more GPUs with the same batch size may affect overall accuracy. In that case, you can 1) replace BN with SyncBN (see here) or 2) use torch.distributed.all_reduce on the BN buffers before this line.
  • We found that SyncBN slightly improves accuracy but considerably increases training time, so this code does not include it.
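The pseudo-labeling and masking controlled by --hard_label, --T, and --p_cutoff can be sketched as below. This is a framework-free illustration of the FixMatch unlabeled loss as described in the paper, not this repository's implementation: per-sample logits are plain Python lists, and the function names are hypothetical. The confidence mask comes from the weakly augmented view, and the cross-entropy is applied to the strongly augmented view.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of a list of floats, with logits divided by a temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def log_softmax(logits):
    """Numerically stable log-softmax of a list of floats."""
    peak = max(logits)
    log_total = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_total for x in logits]


def unlabeled_loss(weak_logits, strong_logits, p_cutoff=0.95, hard_label=True, T=1.0):
    """FixMatch-style masked consistency loss over a batch of unlabeled samples.

    weak_logits / strong_logits: per-sample logit lists for the weakly and
    strongly augmented views. Returns (mean loss, mask ratio).
    """
    losses, mask = [], []
    for weak, strong in zip(weak_logits, strong_logits):
        probs = softmax(weak)  # confidence is measured on the weak view
        confidence = max(probs)
        keep = 1.0 if confidence >= p_cutoff else 0.0
        log_p = log_softmax(strong)
        if hard_label:
            # One-hot pseudo label: the weak view's most likely class.
            target = probs.index(confidence)
            ce = -log_p[target]
        else:
            # Soft pseudo label, sharpened by temperature T (T < 1 sharpens).
            soft = softmax(weak, temperature=T)
            ce = -sum(q * lp for q, lp in zip(soft, log_p))
        losses.append(ce * keep)
        mask.append(keep)
    n = len(losses)
    # Average over all unlabeled samples, including masked ones.
    return sum(losses) / n, sum(mask) / n
```

The second return value is the fraction of unlabeled samples whose confidence passes the threshold. The default `p_cutoff=0.95` here matches the paper's threshold and may differ from this repository's default.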

Use single GPU

python train.py --rank 0 --gpu [0/1/...] @@@other args@@@

Use multi-GPUs (with DataParallel)

python train.py --world-size 1 --rank 0 @@@other args@@@

Use multi-GPUs (with distributed training)

When using multiple GPUs, we strongly recommend distributed training (even on a single node) for better performance.

With 4 V100 GPUs, CIFAR10 training takes about 16 hours (0.7 days), and CIFAR100 training takes about 62 hours (2.6 days).
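For a rough per-iteration budget, the sketch below divides the reported wall-clock times by the 2**20 training iterations. It assumes the times cover the full run and ignores evaluation and checkpointing overhead, so treat the results as estimates.

```python
TOTAL_ITERS = 2 ** 20  # iterations per training run, per the notes above


def seconds_per_iter(hours, iters=TOTAL_ITERS):
    """Average wall-clock seconds per iteration for a run lasting `hours` hours."""
    return hours * 3600 / iters


if __name__ == "__main__":
    print(f"CIFAR10:  {seconds_per_iter(16):.3f} s/iter")  # 0.055
    print(f"CIFAR100: {seconds_per_iter(62):.3f} s/iter")  # 0.213
```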

  • single node
python train.py --world-size 1 --rank 0 --multiprocessing-distributed @@@other args@@@
  • multiple nodes (assuming two nodes)
# at node 0
python train.py --world-size 2 --rank 0 --dist_url [rank 0's url] --multiprocessing-distributed @@@other args@@@
# at node 1
python train.py --world-size 2 --rank 1 --dist_url [rank 0's url] --multiprocessing-distributed @@@other args@@@
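In the multi-node commands above, --world-size is set to the number of nodes and --rank to the node index, and --multiprocessing-distributed launches one process per GPU. Assuming this repository follows the PyTorch ImageNet example's convention (not confirmed here), the total number of processes and each process's global rank would be computed as in this sketch; the helper name is hypothetical.

```python
def global_ranks(num_nodes, node_rank, gpus_per_node):
    """Return (total world size, global ranks of this node's GPU processes).

    Assumes one process per GPU and that node_rank is the value passed via --rank.
    """
    world_size = num_nodes * gpus_per_node
    ranks = [node_rank * gpus_per_node + gpu for gpu in range(gpus_per_node)]
    return world_size, ranks


if __name__ == "__main__":
    # Two nodes with 4 GPUs each, matching the two-node commands above.
    for node in (0, 1):
        print(node, global_ranks(num_nodes=2, node_rank=node, gpus_per_node=4))
```

Under this convention, every process on node 1 must be able to reach the address given in --dist_url, which is why both nodes point to rank 0's URL.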

Run Examples (with single node & multi-GPUs)

CIFAR10

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 4000 --save_name cifar10_4000 --dataset cifar10 --num_classes 10

CIFAR100

python train.py --world-size 1 --rank 0 --multiprocessing-distributed --num_labels 10000 --save_name cifar100_10000 --dataset cifar100 --num_classes 100 --widen_factor 8 --weight_decay 0.001

To reproduce the CIFAR100 results, set --widen_factor 8 (see this issue in the official repo) and --weight_decay 0.001.
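To see what --widen_factor changes, the sketch below computes the channel widths of a standard WideResNet, where each of the three residual groups holds (depth - 4) / 6 blocks and the widths scale as [16, 16k, 32k, 64k] for widen factor k. Depth 28 and this exact layout are assumptions based on the usual WRN-28-k design; check models/net/wrn.py for the actual network.

```python
def wrn_layout(depth=28, widen_factor=2):
    """Channel widths (stem + 3 groups) and residual blocks per group for a standard WRN."""
    if (depth - 4) % 6 != 0:
        raise ValueError("WideResNet depth must satisfy depth = 6n + 4")
    blocks_per_group = (depth - 4) // 6
    widths = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]
    return widths, blocks_per_group


if __name__ == "__main__":
    print(wrn_layout(28, 2))  # ([16, 32, 64, 128], 4)
    print(wrn_layout(28, 8))  # ([16, 128, 256, 512], 4)
```

Because convolution weights scale with input times output channels, going from widen factor 2 to 8 grows the parameter count roughly sixteenfold.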

Change the backbone networks

In this repo, we use WideResNet with LeakyReLU activations, implemented in models/net/wrn.py.
When using the WideResNet, you can change widen_factor, leaky_slope, and dropRate through the corresponding arguments.

For example, to use ReLU, pass --leaky_slope 0.0 in the arguments.
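Setting --leaky_slope 0.0 recovers ReLU because LeakyReLU multiplies negative inputs by the slope (PyTorch exposes this as `torch.nn.LeakyReLU(negative_slope=...)`). A minimal scalar check, with hypothetical helper names:

```python
def leaky_relu(x, negative_slope):
    """Scalar LeakyReLU: identity for x >= 0, negative_slope * x otherwise."""
    return x if x >= 0 else negative_slope * x


def relu(x):
    """Scalar ReLU."""
    return max(0.0, x)


if __name__ == "__main__":
    inputs = [-2.0, -0.5, 0.0, 1.5]
    # With slope 0.0 every output matches ReLU (-0.0 compares equal to 0.0).
    print(all(leaky_relu(x, 0.0) == relu(x) for x in inputs))  # True
    # With slope 0.1 (an example value), negative inputs leak through, scaled down.
    print([leaky_relu(x, 0.1) for x in inputs])
```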

We also support the backbone networks in torchvision.models.
To use one of them, pass the arguments
--net [MODEL's NAME in torchvision] --net_from_name True

When --net_from_name True is set, all model arguments except --net are ignored.

Mixed Precision Training

To speed up training with mixed precision, add --amp to the arguments.
In our runs, this reduced the training time per iteration by about 20-30%.

Tensorboard

We track various metrics, including training accuracy, prefetch and run times, the mask ratio of unlabeled data, and learning rates (see details here). You can view them in TensorBoard:

tensorboard --logdir=[SAVE PATH] --port=[YOUR PORT]


Contributors

leedoyup, yeongjae
