
Comments (12)

chuyuanli commented on June 4, 2024

Hey @himkt , thanks for the prompt response!

> (we can define an arbitrary objective function)
> (https://stackoverflow.com/questions/63224426/how-can-i-cross-validate-by-pytorch-and-optuna)

Indeed... I'll give it a try :)

> The PR adding cross-validation to AllenNLP is ongoing; I'll support it in allennlp-optuna after that PR is merged.

Sounds wonderful. Thank you!

from allennlp-optuna.

chuyuanli commented on June 4, 2024

Hi @himkt, thanks for the tips.
I found the problem in my code. I used a trainer in my objective function but I didn't update the serialization_dir accordingly. The first fold stores a model there, and every subsequent fold just reloads that stored model instead of training a new one.

    trainer = GradientDescentTrainer(
        model=model,
        optimizer=optimizer,
        data_loader=train_data_loader,
        validation_data_loader=validation_data_loader,
        validation_metric=val_metrics,  # metrics are summed to make the is_best decision
        patience=None,  # `patience=None` since it could conflict with AllenNLPPruningCallback
        num_epochs=EPOCHS,
        cuda_device=CUDA_DEVICE,
        serialization_dir=serialization_dir,
        callbacks=[AllenNLPPruningCallback(trial, "validation_accuracy_doc")],
    )
    metrics = trainer.train()
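The fix described above, giving each fold its own serialization_dir, can be sketched as follows. The helper name and directory layout are illustrative, not part of allennlp-optuna:

```python
import os

def fold_serialization_dir(base_dir: str, trial_number: int, fold: int) -> str:
    # Give each (trial, fold) pair its own directory so AllenNLP never
    # resumes from a checkpoint written by a previous fold.
    return os.path.join(base_dir, f"trial_{trial_number}", f"fold_{fold}")
```

Passing `serialization_dir=fold_serialization_dir(base_dir, trial.number, fold)` to the trainer then forces a fresh model to be trained for every fold.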

pvcastro commented on June 4, 2024

Thanks @chuyuanli !
I'm trying to adapt it to the tune command, to take advantage of the integration and AllenNLP config files.

himkt commented on June 4, 2024

Hello @chuyuanli,

Without allennlp-optuna, we can run optimization + cross-validation by following the example in the link you referred to (we can define an arbitrary objective function):
(https://stackoverflow.com/questions/63224426/how-can-i-cross-validate-by-pytorch-and-optuna)

However, it is non-trivial to implement cross-validation with allennlp-optuna. Since allennlp-optuna is a simple wrapper around AllenNLP and Optuna, I don't plan to support cross-validation in this library itself. The PR adding cross-validation to AllenNLP is ongoing; I'll support it in allennlp-optuna after that PR is merged.

chuyuanli commented on June 4, 2024

Hi again :)
I tried the code proposed here, however I think it does not strictly calculate the average over the k folds.

If I'm not wrong, for each trial the objective function only optimizes the very first fold, and for the other folds it just copies the same score.
I first found it fishy because it didn't take long to finish optimizing all the folds, so I printed out the score for each fold:
0, 1, 2, 3, 4 are my 5 folds, and 0.7142857142857143 is the accuracy score obtained on fold 0.

[I 2021-04-18 18:27:41,372] A new study created in RDB with name: train4-18-cv
0
objective-cv.py:137: ExperimentalWarning: AllenNLPPruningCallback is experimental (supported from v2.0.0). The interface can change in the future.
callbacks=[AllenNLPPruningCallback(trial, "validation_accuracy_doc")],
accuracy_turn: 0.1920, accuracy_block: 0.5070, accuracy_doc: 0.4231, loss_turn: 1.6889, loss_block: 0.6916, loss_doc: 0.6997, batch_loss: 1.0244, loss: 1.0020 ||: 100%|##########| 26/26 [00:11<00:00, 2.28it/s]
accuracy_turn: 0.1572, accuracy_block: 0.3879, accuracy_doc: 0.7143, loss_turn: 1.8102, loss_block: 0.6991, loss_doc: 0.6260, batch_loss: 1.0528, loss: 1.0155 ||: 100%|##########| 7/7 [00:01<00:00, 4.25it/s]
accuracy_turn: 0.1890, accuracy_block: 0.4958, accuracy_doc: 0.4615, loss_turn: 1.6335, loss_block: 0.7201, loss_doc: 0.6732, batch_loss: 1.0136, loss: 1.0082 ||: 100%|##########| 26/26 [00:11<00:00, 2.29it/s]
accuracy_turn: 0.1585, accuracy_block: 0.3848, accuracy_doc: 0.7143, loss_turn: 1.7802, loss_block: 0.6995, loss_doc: 0.6264, batch_loss: 1.0431, loss: 1.0089 ||: 100%|##########| 7/7 [00:01<00:00, 4.21it/s]
0.7142857142857143
1
0.7142857142857143
2
0.7142857142857143
3
0.7142857142857143
4
0.7142857142857143

So after fold 0, the rest didn't go through an optimizing trial.
I suspect it's related to optuna.trial.TrialState or something similar, meaning that once a trial is complete it won't test on other data... but for cross-validation, we expect each trial to repeat n times and test on all folds. Not sure, just a guess ^^

I'm not sure if I'll dig deeper into this topic; as you say, it's non-trivial. I have some "silly" but doable ways to get the average over the folds, so I'll make that happen first :)

himkt commented on June 4, 2024

I also tested the colab notebook; objective() returned different values at each round in the trial.
https://colab.research.google.com/gist/himkt/664174dc6d3bf6ddb51a023d19760995/copy-of-optuna-pytorch-cross-validation.ipynb?authuser=1

> and for other folds it just copies the same score.

I think that is not the case. I wonder whether the accuracy could be 0.7143 in a specific situation (e.g. the model predicting one category for all the examples).


p.s. One thing I noticed in your notebook: it would be better to call trial.suggest_xxx in objective_cv.

Before

def objective(trial, train_loader, valid_loader):

    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

After

def objective(trial, train_loader, valid_loader, optimizer_name, lr):

    # Generate the model.
    model = define_model(trial).to(DEVICE)

    # Generate the optimizers.
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)

    ....

def objective_cv(trial):

    # Get the MNIST dataset.
    dataset = datasets.MNIST(DIR, train=True, download=True, transform=transforms.ToTensor())

    # Sample hyperparameters here!!
    optimizer_name = trial.suggest_categorical("optimizer", ["Adam", "RMSprop", "SGD"])
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

himkt commented on June 4, 2024

If you have some further questions about Optuna, please open an issue in Optuna!

pvcastro commented on June 4, 2024

Hi @chuyuanli! Do you have updated, functioning code following @himkt's suggestions? Would you mind sharing?
Thanks!

chuyuanli commented on June 4, 2024

Hi @pvcastro !
Actually the updated code is from @himkt :) Have you tried this one? https://colab.research.google.com/gist/himkt/664174dc6d3bf6ddb51a023d19760995/copy-of-optuna-pytorch-cross-validation.ipynb?authuser=1

Best,

pvcastro commented on June 4, 2024

Great, thanks @chuyuanli , I'll take a look.
Have you run your own experiments with this code? Do you think you were able to get a better model, one that generalized better to your test set? (Considering the test set was not used in the cross-validation.)

chuyuanli commented on June 4, 2024

No problem @pvcastro!
Yes, I ran with this code (slightly different since I used a different model), but the objective and objective_cv functions were the same. I think that for now, if you want to use cross-validation within Optuna, this script is already quite good.

himkt commented on June 4, 2024

Let me close this issue but please reopen if you have some questions. 😸
