

Learn, Unlearn and Relearn: An Online Learning Paradigm for Deep Neural Networks

This repository contains the official PyTorch implementation of the TMLR paper Learn, Unlearn and Relearn: An Online Learning Paradigm for Deep Neural Networks [Paper] by Vijaya Raghavan T Ramkumar, Elahe Arani and Bahram Zonooz.

Abstract

Deep neural networks (DNNs) are often trained on the premise that the complete training data set is provided ahead of time. However, in real-world scenarios, data often arrive in chunks over time. This raises important questions about the optimal strategy for training DNNs, such as whether to fine-tune them on each incoming chunk of data (warm-start) or to retrain them from scratch on the entire corpus whenever a new chunk becomes available. While retraining from scratch can be resource-intensive, recent work has pointed out that warm-started models generalize poorly. Therefore, to strike a balance between efficiency and generalization, we introduce Learn, Unlearn, and Relearn (LURE), an online learning paradigm for DNNs. LURE alternates between an unlearning phase, which selectively forgets undesirable information in the model through data-dependent weight reinitialization, and a relearning phase, which emphasizes learning generalizable features. We show that our training paradigm provides consistent performance gains across datasets in both classification and few-shot settings. We further show that it leads to more robust and better-calibrated models.
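The unlearning phase can be illustrated with a small, framework-agnostic sketch. It assumes, as the `--use_snip` flag in the training commands suggests, that weight importance is scored with a SNIP-style sensitivity |w * dL/dw| computed on incoming data, and that the most salient weights are kept while the rest are reinitialized. The function names, the `keep_fraction` parameter and the Gaussian `init_fn` are hypothetical and are not taken from this repository; refer to `LURE_main.py` for the actual implementation.

```python
import random
from typing import Callable, List


def snip_scores(weights: List[float], grads: List[float]) -> List[float]:
    # SNIP-style connection sensitivity: |w * dL/dw| for each weight.
    return [abs(w * g) for w, g in zip(weights, grads)]


def unlearn(
    weights: List[float],
    grads: List[float],
    keep_fraction: float,
    init_fn: Callable[[], float],
) -> List[float]:
    """Keep the most salient weights and reinitialize the rest.

    keep_fraction (hypothetical) is the share of weights, ranked by
    SNIP score, that survive the unlearning phase.
    """
    scores = snip_scores(weights, grads)
    n_keep = int(round(keep_fraction * len(weights)))
    # Indices of the n_keep highest-scoring weights.
    ranked = sorted(range(len(weights)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:n_keep])
    # Salient weights are copied; all others get a fresh initial value.
    return [w if i in keep else init_fn() for i, w in enumerate(weights)]


# Toy example: 5 weights, keep the 40% (2 weights) with the highest scores.
rng = random.Random(0)
weights = [0.5, -0.1, 2.0, 0.01, -1.5]
grads = [0.2, 0.3, -0.1, 5.0, 0.0]
new_weights = unlearn(weights, grads, keep_fraction=0.4,
                      init_fn=lambda: rng.gauss(0.0, 0.1))
```

Here the scores are 0.1, 0.03, 0.2, 0.05 and 0.0, so the weights 0.5 and 2.0 are retained and the other three are redrawn; the relearning phase would then continue training from this partially reset model.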


For more details, please see the Paper and Presentation.

Requirements

The code has been implemented and tested with Python 3.8 and PyTorch 1.12.1. To install the required packages:

$ pip install -r requirements.txt

Training

Run LURE_main.py to train the model in the Anytime framework with selective forgetting on CIFAR10 and CIFAR100. Run ALMA.py to train the model without selective forgetting, i.e. the warm-started model.

$ python ./LURE_main.py --data <data_dir> --log-dir <log_dir> --run <name_of_the_experiment> --dataset cifar10 --arch resnet18 \
--seed 10 --epochs 50 --decreasing_lr 20,40 --batch_size 64 --weight_decay 1e-4 --meta_batch_size 6250 --meta_batch_number 8 --snip_size 0.20 \
--save_dir <save-dir> --sparsity_level 1 -wb --gamma 0.1 --use_snip

To train the model on Restricted ImageNet (R-ImageNet):

$ python ./LURE_main.py --data <data_dir> --imagenet_path <imagenet data path> --run <name_of_the_experiment> --dataset restricted_imagenet --arch resnet50 \
--seed 10 --epochs 50 --decreasing_lr 20,40 --batch_size 128 --weight_decay 1e-4 --meta_batch_size 6250 --meta_batch_number 8 --snip_size 0.20 \
--save_dir <save-dir> --sparsity_level 1 -wb --gamma 0.1 --use_snip

Note: Use -buffer_replay or -no_replay to train the model with buffer data or without any replayed data, respectively. If neither flag is given, the model is trained in the full-replay setting by default.
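With `--meta_batch_size 6250 --meta_batch_number 8`, the 50,000 training images of CIFAR10 or CIFAR100 split exactly into 8 chunks of 6,250 (8 × 6,250 = 50,000). The sketch below illustrates what data each replay setting might expose when a new chunk arrives. It assumes chunks are consecutive index ranges and that the buffer is a uniform random sample of earlier chunks; the mode names, `buffer_size`, and both helper functions are hypothetical and may differ from how `LURE_main.py` actually builds its data loaders.

```python
import random
from typing import List, Optional


def make_chunks(n_samples: int, chunk_size: int, n_chunks: int) -> List[List[int]]:
    # Split sample indices into consecutive, equally sized chunks ("meta batches").
    assert chunk_size * n_chunks <= n_samples, "not enough samples for the requested chunks"
    return [list(range(k * chunk_size, (k + 1) * chunk_size)) for k in range(n_chunks)]


def training_indices(
    chunks: List[List[int]],
    step: int,
    mode: str = "full",
    buffer_size: int = 0,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Sample indices available for training when chunk `step` (0-based) arrives."""
    current = chunks[step]
    past = [i for chunk in chunks[:step] for i in chunk]
    if mode == "full":    # full replay (default): every sample seen so far
        return past + current
    if mode == "none":    # no replay: only the newest chunk
        return list(current)
    if mode == "buffer":  # buffer replay: newest chunk plus a random subset of the past
        rng = rng or random.Random(0)
        return rng.sample(past, min(buffer_size, len(past))) + current
    raise ValueError(f"unknown replay mode: {mode}")


chunks = make_chunks(n_samples=50_000, chunk_size=6250, n_chunks=8)
full_replay = training_indices(chunks, step=3, mode="full")
no_replay = training_indices(chunks, step=3, mode="none")
buffer_replay = training_indices(chunks, step=3, mode="buffer", buffer_size=1000)
```

When the fourth chunk arrives (`step=3`), full replay trains on 4 × 6,250 = 25,000 samples, no replay on the 6,250 new samples only, and buffer replay on the new chunk plus 1,000 earlier samples, which shows the trade-off between compute and access to past data that the three settings represent.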

Reference & Citing this work

If you use this code in your research, please cite the original work [Paper]:

@article{
ramkumar2023learn,
title={Learn, Unlearn and Relearn: An Online Learning Paradigm for Deep Neural Networks},
author={Vijaya Raghavan T Ramkumar and Elahe Arani and Bahram Zonooz},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2023},
url={https://openreview.net/forum?id=WN1O2MJDST}
}
