Comments (12)
Hey @himkt , thanks for the prompt response!
> (we can define an arbitrary objective function)
> (https://stackoverflow.com/questions/63224426/how-can-i-cross-validate-by-pytorch-and-optuna)
Indeed... I'll give it a try :)
> The PR adding cross-validation to AllenNLP is ongoing; I'll support it in allennlp-optuna after that PR is merged.
Sounds wonderful. Thank you!
from allennlp-optuna.
Hi @himkt , thanks for the tips.
I found the problem in my code. I used a trainer in my objective function but didn't update the serialization_dir accordingly. The first fold stores a model there, and every subsequent fold just reloads the stored model instead of training a new one.
trainer = GradientDescentTrainer(
    model=model,
    optimizer=optimizer,
    data_loader=train_data_loader,
    validation_data_loader=validation_data_loader,
    validation_metric=val_metrics,  # metrics will be summed to make the is_best decision
    patience=None,  # `patience=None` since it could conflict with AllenNLPPruningCallback
    num_epochs=EPOCHS,
    cuda_device=CUDA_DEVICE,
    serialization_dir=serialization_dir,
    callbacks=[AllenNLPPruningCallback(trial, "validation_accuracy_doc")],
)
metrics = trainer.train()
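A minimal sketch of one way to avoid this problem (the helper name and directory layout are my assumptions; the only real requirement is that each fold gets a fresh serialization_dir):

```python
import os
import tempfile

def fold_serialization_dir(base_dir, trial_number, fold):
    """Build a unique serialization_dir per (trial, fold) pair so a fold
    never reloads the model checkpointed by a previous fold."""
    path = os.path.join(base_dir, f"trial_{trial_number}", f"fold_{fold}")
    os.makedirs(path, exist_ok=True)
    return path

# One distinct directory per fold of a given trial.
base = tempfile.mkdtemp(prefix="optuna_cv_")
dirs = [fold_serialization_dir(base, trial_number=0, fold=k) for k in range(5)]
```

Inside the objective, the per-fold path would then be passed as `serialization_dir=...` when constructing the trainer for that fold.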
Thanks @chuyuanli !
I'm trying to adapt it to the tune command, to take advantage of the integration and AllenNLP config files.
Hello @chuyuanli,
Without allennlp-optuna, we can run optimization + cross-validation by following the example in the link you referred to (we can define an arbitrary objective function):
(https://stackoverflow.com/questions/63224426/how-can-i-cross-validate-by-pytorch-and-optuna)
However, it is non-trivial to implement cross-validation with allennlp-optuna. Since allennlp-optuna is a simple wrapper around AllenNLP and Optuna, I don't have a plan to support cross-validation in this library. The PR adding cross-validation to AllenNLP is ongoing; I'll support it in allennlp-optuna after that PR is merged.
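To make the shape of such an arbitrary objective concrete, here is a stdlib-only sketch of the fold loop it would run; `train_and_score` is a hypothetical stand-in for the actual per-fold model training:

```python
def kfold_indices(n_samples, n_folds):
    """Yield (train_idx, valid_idx) index lists for contiguous folds."""
    start = 0
    for i in range(n_folds):
        size = n_samples // n_folds + (1 if i < n_samples % n_folds else 0)
        valid = list(range(start, start + size))
        train = [j for j in range(n_samples) if j < start or j >= start + size]
        yield train, valid
        start += size

def cross_validated_score(train_and_score, n_samples, n_folds):
    # An Optuna objective would return this mean as the trial's value.
    scores = [train_and_score(tr, va) for tr, va in kfold_indices(n_samples, n_folds)]
    return sum(scores) / len(scores)

# Dummy scorer: score = validation-fold size / dataset size.
avg = cross_validated_score(lambda tr, va: len(va) / 100, n_samples=100, n_folds=5)
print(avg)  # 0.2
```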
Hi again :)
I tried the code proposed here; however, I think it does not strictly calculate the average over the k folds. If I'm not wrong, for each trial the objective function only optimizes the very first fold, and for the other folds it just copies the same score. I first found it fishy because it didn't take long to finish optimizing all the folds, so I printed out the score for each fold: 0, 1, 2, 3, 4 are my 5 folds, and 0.7142857142857143 is the accuracy score obtained on fold 0.
[I 2021-04-18 18:27:41,372] A new study created in RDB with name: train4-18-cv
0
objective-cv.py:137: ExperimentalWarning: AllenNLPPruningCallback is experimental (supported from v2.0.0). The interface can change in the future.
callbacks=[AllenNLPPruningCallback(trial, "validation_accuracy_doc")],
accuracy_turn: 0.1920, accuracy_block: 0.5070, accuracy_doc: 0.4231, loss_turn: 1.6889, loss_block: 0.6916, loss_doc: 0.6997, batch_loss: 1.0244, loss: 1.0020 ||: 100%|##########| 26/26 [00:11<00:00, 2.28it/s]
accuracy_turn: 0.1572, accuracy_block: 0.3879, accuracy_doc: 0.7143, loss_turn: 1.8102, loss_block: 0.6991, loss_doc: 0.6260, batch_loss: 1.0528, loss: 1.0155 ||: 100%|##########| 7/7 [00:01<00:00, 4.25it/s]
accuracy_turn: 0.1890, accuracy_block: 0.4958, accuracy_doc: 0.4615, loss_turn: 1.6335, loss_block: 0.7201, loss_doc: 0.6732, batch_loss: 1.0136, loss: 1.0082 ||: 100%|##########| 26/26 [00:11<00:00, 2.29it/s]
accuracy_turn: 0.1585, accuracy_block: 0.3848, accuracy_doc: 0.7143, loss_turn: 1.7802, loss_block: 0.6995, loss_doc: 0.6264, batch_loss: 1.0431, loss: 1.0089 ||: 100%|##########| 7/7 [00:01<00:00, 4.21it/s]
0.7142857142857143
1
0.7142857142857143
2
0.7142857142857143
3
0.7142857142857143
4
0.7142857142857143
So after fold 0, the rest didn't go through an optimizing trial. I feel like the reason is optuna.trial.TrialState or something like that, meaning that once a trial is complete it won't test on other data... but for cross-validation, we expect one trial to repeat n times and test on all folds. Not sure, just a guess ^^
I'm not sure if I'll dig deeper into this topic, as you say it's non-trivial. I have some "silly" but doable ways to get the average of cross-validation, so I'll make that happen first :)
I also tested the colab notebook; objective() returned different values at each round in the trial.
https://colab.research.google.com/gist/himkt/664174dc6d3bf6ddb51a023d19760995/copy-of-optuna-pytorch-cross-validation.ipynb?authuser=1

> and for other folds it just copies the same score.

I think that is not the case. I wonder whether the accuracy could be 0.7143 in a specific situation (e.g. the model predicting one category for all the examples).
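For what it's worth, the repeated score is exactly 5/7, which would be consistent with such a degenerate model on a validation fold with a 5-to-2 class split (the 7-example, 5/2 split is my assumption for illustration, not something visible in the logs):

```python
from fractions import Fraction

# The suspicious repeated accuracy is exactly 5/7.
assert float(Fraction(5, 7)) == 0.7142857142857143

# Hypothetical fold: 7 examples, 5 of class "A", and a model that
# predicts "A" for everything.
labels = ["A"] * 5 + ["B"] * 2
predictions = ["A"] * len(labels)
accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)
print(accuracy)  # 0.7142857142857143
```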
p.s. One thing I noticed in your notebook: it would be better to call trial.suggest_xxx in objective_cv.
Before

def objective(trial, train_loader, valid_loader):
    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
After

def objective(trial, train_loader, valid_loader, optimizer_name, lr):
    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)
    ....

def objective_cv(trial):
    # Get the MNIST dataset.
    dataset = datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor())

    # Sample hyperparameters here!!
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
If you have some further questions about Optuna, please open an issue in Optuna!
Hi @chuyuanli! Do you have updated, functioning code following @himkt's suggestions? Do you mind sharing?
Thanks!
Hi @pvcastro !
Actually the updated code is from @himkt :) Have you tried this one? https://colab.research.google.com/gist/himkt/664174dc6d3bf6ddb51a023d19760995/copy-of-optuna-pytorch-cross-validation.ipynb?authuser=1
Best,
Great, thanks @chuyuanli , I'll take a look.
Have you run your own experiments with this code? Do you think you were able to get a better model, one that generalized better to your test set? (Considering the test set was not used in the cross-validation.)
No problem @pvcastro
Yes, I ran with this code (slightly different because I used a different model), but the objective and objective_cv functions were the same. I think that for now, if you want to use CV within Optuna, this script is quite good already.
Let me close this issue but please reopen if you have some questions. 😸