Got back from vacation (happy new year!).
Looking into the baseline, the default flag values we use in the CIFAR script are indeed the ones used for reporting the numbers in the CIFAR-100 leaderboard. The only flags we override on the command line are: `dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH`.
You mention using only 1 GPU, which sounds like the likely culprit. Have you verified that the global batch size you're using for each gradient step is equivalent to the default script's setup? This is `FLAGS.num_cores * FLAGS.per_core_batch_size`.
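The arithmetic behind that check can be sketched as follows. A minimal illustration only: the flag names mirror the uncertainty-baselines CIFAR script, and the single-GPU values are the ones discussed later in this thread, not defaults.

```python
# Minimal sketch of the global-batch-size check described above.
# num_cores and per_core_batch_size mirror the CIFAR script's flags;
# the values are illustrative (default TPU setup vs. the single-GPU one).
default_global = 8 * 64      # num_cores=8, per_core_batch_size=64 (defaults)
single_gpu_global = 1 * 512  # num_cores=1, per_core_batch_size=512

# Both setups consume the same number of examples per gradient step.
print(default_global, single_gpu_global)  # 512 512
```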
from uncertainty-baselines.
Happy new year!
When we run

```shell
python3 /dashavoronkova8/ens_project/mimo/cifar.py \
  --output_dir '/dashavoronkova8/ens_project/mimo/cifar' \
  --seed 0 --use_gpu --dataset cifar100 \
  --per_core_batch_size 512 --num_cores 1 --batch_repetitions 4 \
  --corruptions_interval -1 --ensemble_size 3 --width_multiplier 10 \
  --base_learning_rate 0.1 --train_epochs 250 --lr_decay_ratio 0.1 \
  --lr_warmup_epochs 1 --num_bins 15 --input_repetition_probability 0. \
  --l2 3e-4 --checkpoint_interval 50
```

we get 80.76% accuracy. Our deviations from the default parameters are:
- `ensemble_size 3` (vs. default 4); according to the paper, accuracy is even better at this value
- `per_core_batch_size 512` (vs. default 64)
- `num_cores 1` (vs. default 8); together with `per_core_batch_size=512` this gives a total batch size of 512, as in the paper and in the original script
- `use_gpu True` (vs. default False), because we don't use a TPU
- `lr_decay_ratio 0.1` (vs. default 0.2), as reported in the paper (is 0.2 better to use?)
- `corruptions_interval -1` (vs. default 250), to skip evaluation on corrupted data
Hi, thanks for the interest in the paper!
It could be the `lr_decay_ratio`. The one in the code (`lr_decay_ratio=0.2`) is definitely the correct one. It's possible that I put the incorrect value in the paper; I will update it ASAP.
As Dustin said, the results that we report come from running the code flags with `dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH`, with `seed` varying from 0 to 9 and averaging the 10 runs.
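The seed-averaging protocol above can be sketched as follows. The metric values here are made-up placeholders, not real run outputs; in practice each number would be parsed from one run's final test summary.

```python
from statistics import mean, stdev

# Hypothetical sketch of the reporting protocol: one test_accuracy per seed
# (seeds 0..9), then averaged over the 10 runs. These values are
# illustrative placeholders, not real results.
per_seed_accuracy = {seed: 0.815 + 0.001 * (seed % 3) for seed in range(10)}

values = list(per_seed_accuracy.values())
print(f"test_accuracy: {mean(values):.4f} +/- {stdev(values):.4f}")
```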
Did you also try cifar10? Did it work as expected?
With CIFAR-10 we got 96.17% accuracy vs. 96.40% in the paper, with the same parameters as for CIFAR-100; that should be outside the standard-deviation range. Thanks, we'll try the 0.2 value!
Another potential issue is that in the default setup, batch norm is applied separately over the 64 data points on each core. Applying BN across the full large batch removes that implicit regularization benefit, which can sometimes be important.
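A NumPy sketch (not the actual TF implementation) of what per-core batch norm statistics mean here: with 8 cores of 64 examples each, every core normalizes with its own noisier mean rather than the full 512-example statistics, and that shard-to-shard noise is the implicit regularization in question.

```python
import numpy as np

rng = np.random.default_rng(0)
batch = rng.normal(loc=1.0, scale=2.0, size=(512, 16))  # 512 examples, 16 features

# Full-batch BN statistics: one mean per feature over all 512 examples.
full_mean = batch.mean(axis=0)

# Per-core BN statistics: 8 shards of 64 examples, each with its own mean.
per_core_means = batch.reshape(8, 64, 16).mean(axis=1)  # shape (8, 16)

# Each shard's mean deviates from the full-batch mean; this variation
# disappears when BN sees the whole 512-example batch at once.
scatter = np.abs(per_core_means - full_mean).mean()
print(scatter > 0.0)  # True: per-core statistics genuinely differ
```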
> ensemble_size 3 (vs default 4), according to the paper the accuracy is even better at this value
>
> lr_decay_ratio 0.1 (vs default 0.2), as reported in the paper (is 0.2 better to use?)
That's a good catch. I relaunched the code and here are the results, each averaged over 10 seeds:
default flags + `ensemble_size=3,dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH`

| test_log_likelihood | test_accuracy | test_ece | test_nll_mean_corrupted | test_accuracy_mean_corrupted | test_ece_mean_corrupted |
|---|---|---|---|---|---|
| -0.683279 | 0.8195 | 0.020577 | 2.285121 | 0.537657 | 0.129027 |
default flags + `dataset=cifar100,cifar100_c_path=$CIFAR100_C_PATH`

| test_log_likelihood | test_accuracy | test_ece | test_nll_mean_corrupted | test_accuracy_mean_corrupted | test_ece_mean_corrupted |
|---|---|---|---|---|---|
| -0.685981 | 0.815759 | 0.021685 | 2.25507 | 0.533822 | 0.112056 |
default flags + `ensemble_size=3` (so CIFAR-10)

| test_log_likelihood | test_accuracy | test_ece | test_nll_mean_corrupted | test_accuracy_mean_corrupted | test_ece_mean_corrupted |
|---|---|---|---|---|---|
| -0.125419 | 0.96287 | 0.011104 | 0.909436 | 0.767745 | 0.107714 |
So yeah, the table for both CIFAR datasets is in fact reported with `ensemble_size=3` and `lr_decay_ratio=0.2` as per default. I sent a PR to fix `ensemble_size`'s default (#269).
Closing, feel free to reopen if needed!