mattdl / clsurvey


Continual Hyperparameter Selection Framework. Compares 11 state-of-the-art Lifelong Learning methods and 4 baselines. Official Codebase of "A continual learning survey: Defying forgetting in classification tasks." in IEEE TPAMI.

Home Page: https://ieeexplore.ieee.org/abstract/document/9349197

License: Other

Python 98.82% Shell 1.18%
continual-learning tpami defy-forgetting classification-tasks deep-learning neural-networks framework hyperparameter-tuning inaturalist tinyimagenet

clsurvey's Introduction

Matthias De Lange ~ Professional Portfolio

clsurvey's People

Contributors

mattdl

clsurvey's Issues

Bug about finetuning

When I run the fine-tuning experiment, the code fails with errors. Could you publish the main_tinyimagenet.sh script for fine-tuning?
Thank you.

How to run the entire pipeline?

Hi,
Thanks for your great work. This code is a great help to the community and to people who are just getting started.
I have checked main_tinyimagenet.sh and only found runs for SI and EBLL. Forgive me, I am a newcomer and I am lost in the code.
I would like to know how to run the whole pipeline with the 11 methods reported in your paper.
Thanks!

Request for a template for evaluating other sotas

Thanks for sharing the CL survey framework! However, I wonder how to use iCaRL with it. I modified main_tinyimagenet.sh by adding the static_hyperparams term to the arguments, but still got several errors. Could you please provide a shell script for running iCaRL? I would appreciate it.

Multihead classification

Hi, thank you for the great baseline repo.
I am trying to set up each dataset (with a different number of classes) as a task and perform continual learning. However, I am a little lost regarding how the models handle multiple heads with potentially different numbers of output features. Do you have any suggestions on how this might be addressed?

Currently I am using something like:

```python
import torch.nn as nn
import torchvision
from torchvision.models.resnet import BasicBlock, ResNet

# task_list, classes_per_task and kaiming_normal_init are defined elsewhere in my script.


class MultiTaskModel(ResNet):
    def __init__(self):
        super(MultiTaskModel, self).__init__(BasicBlock, [2, 2, 2, 2])
        resnet = torchvision.models.resnet34(pretrained=True)
        self.in_feature = resnet.fc.in_features
        self.tasks = []
        self.fc = None  # no head until set_task() is called

        # add all layers that are not fc or classifier to the model
        self.shared = nn.Sequential()
        for name, module in resnet.named_children():
            if name != 'fc' and name != 'classifiers':
                self.shared.add_module(name, module)
        # self.classifiers.append(resnet.fc)

    def set_task(self, task):
        print("Setting task to", task)
        self.tasks.append(task)
        print(f"tasks are {self.tasks}")
        print(f"task index is {task_list.index(task)}")
        # create a fresh fc layer for the new task (this replaces the previous head)
        self.fc = nn.Linear(self.in_feature, classes_per_task[task_list.index(task)])
        self.fc.apply(kaiming_normal_init)
        print(f"fc is {self.fc}")

    def forward(self, x):
        x = self.shared(x)         # shared feature extractor
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        x = self.fc(x)             # head of the current task
        return x
```

However, I am unable to replicate the results for some methods, such as EWC. If I use the default lambda = 400 as in the repo, the loss becomes NaN.
Currently, I keep the task order of the RecogSeq dataset, but I observe catastrophic forgetting for EWC and LwF. To be precise, I measured the performance on each task after training on the last task.
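
For context, EWC adds a quadratic penalty to the new task's loss, weighted by `lambda` and a diagonal Fisher estimate: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2. The sketch below is a minimal plain-Python illustration of that formula, not the repository's implementation. `ewc_penalty` and `diagonal_fisher` are hypothetical helper names, and all numbers are toy values chosen only to make the arithmetic easy to follow.

```python
def diagonal_fisher(per_sample_grads):
    """Empirical diagonal Fisher: mean of squared per-sample gradients."""
    n = len(per_sample_grads)
    dim = len(per_sample_grads[0])
    return [sum(g[i] ** 2 for g in per_sample_grads) / n for i in range(dim)]


def ewc_penalty(params, anchor_params, fisher_diag, reg_lambda):
    """Quadratic EWC penalty: (lambda / 2) * sum_i F_i * (theta_i - theta*_i)^2."""
    return 0.5 * reg_lambda * sum(
        f * (p - a) ** 2 for p, a, f in zip(params, anchor_params, fisher_diag)
    )


# Toy values (assumptions for illustration only).
grads = [[1.0, 0.0], [3.0, 2.0]]   # per-sample gradients on the previous task
fisher = diagonal_fisher(grads)    # [5.0, 2.0]
anchor = [0.5, -1.0]               # parameters after the previous task
current = [1.5, -1.0]              # parameters while training the new task

print(ewc_penalty(current, anchor, fisher, reg_lambda=1))    # 2.5
print(ewc_penalty(current, anchor, fisher, reg_lambda=400))  # 1000.0
```

The penalty and its gradient scale linearly with lambda, so the same parameter drift costs 400 times more at lambda = 400 than at lambda = 1. Whether that scaling accounts for the NaN loss described above is not established here.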

This is my training loop:

```python
import time

import torch
import tqdm

# get_dataloaders, fine_tune_EWC_acuumelation, dataset_names, task_list, device,
# batch_size and num_epochs are defined elsewhere in my script.
all_train_loaders, all_val_loaders = {}, {}

for task in dataset_names:
    train_loader, val_loader, _, _, _ = get_dataloaders(task, 0.8, batch_size)
    all_train_loaders[task] = train_loader
    all_val_loaders[task] = val_loader

for idx, task in tqdm.tqdm(enumerate(task_list)):
    current_train_loader = all_train_loaders[task]
    current_val_loader = all_val_loaders[task]
    model = MultiTaskModel().to(device)
    if idx > 0:
        # resume from the checkpoint saved after the previous task
        print(f"Previous task: {task_list[idx-1]}")
        ckpt = torch.load(f"epoch_{task_list[idx-1]}.pth.tar")
        model.load_state_dict(ckpt['state_dict'])
        model = model.to(device)

    start_time = time.time()
    model, acc = fine_tune_EWC_acuumelation(
        current_train_loader, current_val_loader, model,
        reg_lambda=1, num_epochs=num_epochs, lr=0.008,
        batch_size=batch_size, weight_decay=0, current_task=task,
    )
```

Here's a sample output:

```bash
Model loaded for task svhn

Performance of previous task: flowers
fc set to Linear(in_features=512, out_features=103, bias=True)
Accuracy of the network on the 1311 test images: 3.75

Performance of previous task: scenes
fc set to Linear(in_features=512, out_features=67, bias=True)
Accuracy of the network on the 3123 test images: 5.013020833333333

Performance of previous task: birds
fc set to Linear(in_features=512, out_features=201, bias=True)
Accuracy of the network on the 2358 test images: 0.9982638888888888

Performance of previous task: cars
fc set to Linear(in_features=512, out_features=196, bias=True)
Accuracy of the network on the 1621 test images: 0.6875

Performance of previous task: aircraft
fc set to Linear(in_features=512, out_features=56, bias=True)
Accuracy of the network on the 2000 test images: 15.574596774193548

Performance of previous task: chars
fc set to Linear(in_features=512, out_features=63, bias=True)
Accuracy of the network on the 12599 test images: 43.247767857142854

Performance of previous task: svhn
fc set to Linear(in_features=512, out_features=10, bias=True)
Accuracy of the network on the 26032 test images: 96.13223522167488

Running accuracy for task svhn is [3.75, 5.013020833333333, 0.9982638888888888, 0.6875, 15.574596774193548, 43.247767857142854, 96.13223522167488]
Mean accuracy for task svhn is 23.629054939319072
```
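
One common way to handle heads with different output sizes is to keep a shared feature extractor and store one classifier per task, selecting the head by task name at forward time. The following is a minimal sketch under that assumption, not the repository's implementation: `MultiHeadNet`, `add_task`, and the small stand-in backbone are hypothetical, and only the output sizes are taken from the log above.

```python
import torch
import torch.nn as nn


class MultiHeadNet(nn.Module):
    """Shared feature extractor with one linear head per task (hypothetical sketch)."""

    def __init__(self, backbone: nn.Module, feature_dim: int):
        super().__init__()
        self.backbone = backbone
        self.feature_dim = feature_dim
        self.heads = nn.ModuleDict()  # task name -> classifier head

    def add_task(self, task: str, num_classes: int) -> None:
        # Create a head only once so earlier heads keep their trained weights.
        if task not in self.heads:
            self.heads[task] = nn.Linear(self.feature_dim, num_classes)

    def forward(self, x: torch.Tensor, task: str) -> torch.Tensor:
        features = torch.flatten(self.backbone(x), 1)
        return self.heads[task](features)


# Stand-in backbone for the demo (an assumption); a pretrained ResNet without
# its fc layer would take its place in practice.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 512), nn.ReLU())
model = MultiHeadNet(backbone, feature_dim=512)

# Output sizes taken from the sample log above.
for task, num_classes in [("flowers", 103), ("scenes", 67), ("svhn", 10)]:
    model.add_task(task, num_classes)

x = torch.randn(4, 3, 8, 8)
print(model(x, "flowers").shape)  # torch.Size([4, 103])
print(model(x, "svhn").shape)     # torch.Size([4, 10])
```

Because every head stays registered on the module, evaluating an earlier task reuses its trained head instead of a freshly initialized one. With a strict `load_state_dict`, the heads need to be added before loading a checkpoint so that the parameter keys match.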

Hello, I want to know some details about the code

Excuse me, I'd like to ask some questions.

  1. Is it true that all data sets are kept at the same size in the experiment?
  2. Which labels are adopted in the VOC dataset?
  3. Does the experiment provide a multi-head setting for each method?

If you have time, could you help me answer these questions? Thank you very much.
