
Continual-Learning-Benchmark

Evaluate three types of task shifting with popular continual learning algorithms.

This repository implements and modularizes the following algorithms with PyTorch:

  • EWC: code, paper (Overcoming catastrophic forgetting in neural networks)
  • Online EWC: code, paper
  • SI: code, paper (Continual Learning Through Synaptic Intelligence)
  • MAS: code, paper (Memory Aware Synapses: Learning what (not) to forget)
  • GEM: code, paper (Gradient Episodic Memory for Continual Learning)
  • (More are coming)

All of the above algorithms are compared to baselines with the same static memory overhead.

Key result tables: see the figures in the repository's fig folder (e.g., fig/results_split_cifar100.png).

If this repository helps your work, please cite:

@inproceedings{Hsu18_EvalCL,
  title={Re-evaluating Continual Learning Scenarios: A Categorization and Case for Strong Baselines},
  author={Yen-Chang Hsu and Yen-Cheng Liu and Anita Ramasamy and Zsolt Kira},
  booktitle={NeurIPS Continual Learning Workshop},
  year={2018},
  url={https://arxiv.org/abs/1810.12488}
}

Preparation

This repository was tested with Python 3.6 and PyTorch 1.0.1.post2. Some of the cases were also tested with PyTorch 1.5.1 and gave the same results.

pip install -r requirements.txt

Demo

The scripts for reproducing the results of this paper are under the scripts folder.

  • Example: Run all algorithms in the incremental domain scenario with split MNIST.
./scripts/split_MNIST_incremental_domain.sh 0
# The last number is gpuid
# Outputs will be saved in ./outputs
  • Example outputs: Summary of repeats
===Summary of experiment repeats: 3 / 3 ===
The regularization coefficient: 400.0
The last avg acc of all repeats: [90.517 90.648 91.069]
mean: 90.74466666666666 std: 0.23549144829955856
  • Example outputs: The grid search for the regularization coefficient
reg_coef: 0.1 mean: 76.08566666666667 std: 1.097717733400629
reg_coef: 1.0 mean: 77.59100000000001 std: 2.100847606721314
reg_coef: 10.0 mean: 84.33933333333334 std: 0.3592671553160509
reg_coef: 100.0 mean: 90.83800000000001 std: 0.6913701372395712
reg_coef: 1000.0 mean: 87.48566666666666 std: 0.5440161353816179
reg_coef: 5000.0 mean: 68.99133333333333 std: 1.6824762174313899

Usage

  • Enable the grid search for the regularization coefficient: use the option with a list of values, e.g., -reg_coef 0.1 1 10 100 ... (see the example command below)
  • Repeat the experiment N times: use the option -repeat N

Look up the available options:

python iBatchLearn.py -h
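For example, the grid search and repeat options can be combined in one run. The command below adapts the CIFAR100 invocation quoted in the issues further down; the particular coefficient values are illustrative, not the ones behind the reported results:

```
python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid 0 --repeat 3 \
    --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap \
    --first_split_size 20 --other_split_size 20 --schedule 80 120 160 \
    --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet \
    --agent_type customization --agent_name EWC_online --lr 0.001 \
    --reg_coef 0.1 1 10 100
```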

Other results

The CIFAR100 results are provided as figures in the repository (e.g., fig/results_split_cifar100.png). Please refer to the scripts for details.

Contributors

yenchanghsu


Issues

[Question] Calculation of Importance of the weights for EWC

Thanks for your great work.

I have a question about the regularization methods; let's take EWC as an example.

In https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/agents/regularization.py#L43-L44, you calculate the importance weights at every batch, but I think this calculation is unnecessary during training; the importance weights only need to be computed after training on a task finishes.

What is the reason for this? Calculating them at every batch seems time-consuming.

In https://github.com/srvCodes/continual_learning_with_vit/blob/main/src/approach/ewc.py#L117-L132, they calculate the importance weight only at the end of the training process for one task.
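For reference, here is a minimal sketch of that after-task variant, where the diagonal Fisher is estimated once after training on a task finishes (the function, its arguments, and the loop are illustrative assumptions, not this repository's API):

```python
import torch
import torch.nn.functional as F

def estimate_fisher(model, dataloader, device):
    """Diagonal Fisher estimate, computed once after a task's training ends."""
    fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    n_batches = 0
    for inputs, targets in dataloader:
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        log_probs = F.log_softmax(model(inputs), dim=1)
        F.nll_loss(log_probs, targets).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                # Accumulate squared gradients as the Fisher diagonal.
                # With batch_size=1 this is the per-example estimator; see
                # the "EWC mini-batch sampling" issue below for the caveat.
                fisher[n] += p.grad.detach() ** 2
        n_batches += 1
    return {n: f / n_batches for n, f in fisher.items()}
```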

Reproducibility issue

Hello,

Thanks for the CIFAR100 scripts. I am having some issues trying to reproduce the naive rehearsal results for the class-incremental learning task. I executed the naive rehearsal 5600 baseline with the following command:

python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid 0 --repeat 5 --incremental_class --optimizer Adam --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 80 120 160 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name Naive_Rehearsal_5600 --lr 0.001

Comparing with the Naive Rehearsal-C result reported in https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/master/fig/results_split_cifar100.png, I obtained the following:
With pytorch 1.0.0 and torchvision 0.2.2, the final result is around 40%, instead of the reported 51.28.
With pytorch 1.4.0 and torchvision 0.5.0, the final result is around 20%, instead of the reported 51.28.

This happens consistently across various GPUs (RTX 2080 Ti, Titan Xp, and Tesla P4/K80).

Could you share the environment configuration you used to obtain reported results on CIFAR100?

Grid search script for hyperparameters of EWC, SI, L2

Hello,

Thanks a lot for uploading the additional scripts for CIFAR-100. I'm trying to replicate the results for various methods with your code. First of all, I am very thankful for the scripts, as they spare me the burden of finding the best results through a hyperparameter sweep (although I may try that in the future).

One thing I want to ask: how did you perform the search for the regularization coefficients across the various datasets and architectures? Did you set intuitive limits for the grid search by looking at the train and test curves? I'm asking because I want to replicate this on other datasets, say CIFAR-10. I can add scripts for CIFAR-10 and ImageNet to your repository once I finish experimenting.

It would be great if you could share your grid search script.

Code of DGR, RtF?

Hello! Thank you for open-sourcing this nice codebase.

The benchmark results include DGR and RtF; however, there is no code for them.
Could I get the DGR and RtF code for this benchmark?

GEM uses memory data in active training set and updates memory twice

The GEM implementation here inherits from Naive_Rehearsal and calls super(GEM, self).learn_batch(train_loader, val_loader). It therefore uses Naive_Rehearsal's learn_batch method, which mixes memory data with the new data when computing the original gradients, before checking conflicts with the memory gradients. From my understanding, that is not the intent of the original paper and might affect the training results.

Additionally, Naive_Rehearsal's learn_batch method already updates the memory and task_count, but GEM does this a second time once the call returns.
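For comparison, here is a minimal sketch of the gradient check the original GEM paper describes, simplified to a single averaged memory gradient (as in A-GEM) rather than one constraint per past task, so no quadratic-program solver is needed; the function name and the assumption of flattened 1-D gradient vectors are mine:

```python
import torch

def project_gradient(grad: torch.Tensor, mem_grad: torch.Tensor) -> torch.Tensor:
    """Single-constraint simplification of GEM's projection: if the proposed
    update conflicts with the memory gradient (negative dot product),
    project it onto the closest non-conflicting direction."""
    dot = torch.dot(grad, mem_grad)
    if dot < 0:
        grad = grad - (dot / torch.dot(mem_grad, mem_grad)) * mem_grad
    return grad

# Per the paper: `grad` comes from the *new* batch only and `mem_grad` from
# the episodic memory; memory data is not mixed into the new batch.
```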

torchvision.datasets.cifar has no training labels, test labels

Hi,
I'm using pytorch=1.1.0 and torchvision=0.3.0. With these versions there is a small bug in the CIFAR100 and CIFAR10 data loaders.

It seems that the CIFAR classes have no attribute training_labels or test_labels, which are used in dataloaders/base.py. The documentation also lists no such attributes for CIFAR, although MNIST has them.

Is this a bug, or was it working for you with previous versions? It's strange that torchvision lacks homogeneity among its datasets.

Let me know if it's a bug; I can raise a pull request.
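For reference, torchvision 0.3+ exposes the labels of MNIST and CIFAR alike through a unified `targets` attribute, so a version-tolerant loader could read them roughly like this (a sketch, not this repository's code):

```python
def get_labels(dataset, train=True):
    # torchvision >= 0.3 unifies label access under `targets`.
    if hasattr(dataset, 'targets'):
        return list(dataset.targets)
    # Older releases used split-specific attribute names instead.
    return dataset.train_labels if train else dataset.test_labels
```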

EWC implementation

Hi! Thanks for your awesome code! It really helps!
I have a problem with the implementation of EWC in this repo.

According to the DeepMind paper, EWC adopts the diagonal of the Fisher information matrix to approximate the posterior of the previous task:

$F_i = \mathbb{E}\left[\left(\frac{\partial \log p(\mathcal{D}_A \mid \theta)}{\partial \theta_i}\right)^2\right]$

I think the $p(\mathcal D_A|\theta)$ should be evaluated at the ground-truth label. If the class with the maximum predicted probability is used instead, as in
https://github.com/GT-RIPL/Continual-Learning-Benchmark/blob/d78b9973b6ec0059b2d2577872db355ae2489f6b/agents/regularization.py#L133
there is a chance of computing the likelihood of a wrong class when estimating the Fisher information.
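A sketch of the two choices side by side (the function and its arguments are hypothetical, not this repository's code; `output` are the logits for one batch, `target` the ground-truth labels):

```python
import torch
import torch.nn.functional as F

def fisher_nll(output: torch.Tensor, target: torch.Tensor,
               use_ground_truth: bool) -> torch.Tensor:
    """Negative log-likelihood whose squared gradient feeds the Fisher diagonal."""
    log_probs = F.log_softmax(output, dim=1)
    if use_ground_truth:
        # Suggested fix: evaluate the likelihood at the ground-truth label
        # (the "empirical Fisher"), which cannot select a wrong class.
        return F.nll_loss(log_probs, target)
    # What the issue describes: take the class with maximum predicted
    # probability, which can be a wrong class on misclassified samples.
    return F.nll_loss(log_probs, log_probs.argmax(dim=1))
```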

EWC mini-batch sampling

Hi, Thank you so much for this awesome repo. It's the clearest implementation I found out there :)

I have a question regarding the mini-batch sampling. In the code, it is commented that it gives similar performance to sub-sampling with batch_size=1 (i.e., the mathematically correct way), but I'm worried that the two are quite different.
So I'm curious: are there papers that used this sampling instead and confirmed its similar performance?

The reason for my doubt is that, in general, the expectation of the squared per-example gradients of the log-likelihood (which is the estimator for the diagonal of the Fisher matrix) is not the same as the square of the batch-averaged gradient: $\mathbb{E}[g^2] \neq (\mathbb{E}[g])^2$.

Thank you for your consideration,
Arash
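Concretely, the two estimators differ because the mean of squares is not the square of the mean. A toy sketch of both (hypothetical names; assumes every parameter receives a gradient):

```python
import torch

def fisher_diag_per_example(model, loss_fn, xs, ys):
    """Mean of squared per-example gradients: the batch_size=1 estimator."""
    sq = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    for x, y in zip(xs, ys):
        model.zero_grad()
        loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0)).backward()
        for n, p in model.named_parameters():
            sq[n] += p.grad.detach() ** 2  # square each sample's gradient
    return {n: s / len(xs) for n, s in sq.items()}

def fisher_diag_batched(model, loss_fn, xs, ys):
    """Square of the batch-averaged gradient: generally a smaller quantity,
    since E[g^2] >= (E[g])^2."""
    model.zero_grad()
    loss_fn(model(xs), ys).backward()
    return {n: p.grad.detach() ** 2 for n, p in model.named_parameters()}
```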

Request for detailed hyper parameters used

Hello,
Thanks a lot for your repository. I'm trying to reproduce your results to use as baselines for my experiments. However, I have trouble reproducing the exact numbers reported in the paper for MNIST with the MLP network. The values of the following hyperparameters seem to be missing from both the paper and the code:

  1. The L2 regularization coefficient used for the L2 baseline.
  2. Was learning-rate decay used? If yes, can you share some details about it?

It would be great if you could share a config file for your best results. The paper presents some details about these, which is very helpful, but I feel the complete configurations are missing.

Thanks in advance

EWC online weight explosion using SGD

The problem can be reproduced by running:

python -u iBatchLearn.py --dataset CIFAR100 --train_aug --gpuid 0 --repeat 1 --incremental_class --optimizer SGD --force_out_dim 100 --no_class_remap --first_split_size 20 --other_split_size 20 --schedule 1 --batch_size 128 --model_name WideResNet_28_2_cifar --model_type resnet --agent_type customization --agent_name EWC_online --lr 0.001 --reg_coef 100
