
Dataset Pruning for Transfer Learning [NeurIPS 2023]

Welcome to the official implementation of the paper Selectivity Drives Productivity: Efficient Dataset Pruning for Enhanced Transfer Learning. This work introduces two dataset pruning techniques, Label Mapping (LM) and Feature Mapping (FM), which leverage source-target domain mapping to prune the source dataset for a given downstream task.

Requirements

You can install the necessary Python packages with:

pip install -r requirements.txt

To accelerate model training, this repository is built on FFCV; please refer to the official FFCV website for its installation instructions. Our argument system is built on fastargs, of which we provide a revised version here; the latest fastargs is installed automatically by the command above.
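All commands below combine a JSON config file (--config-file) with fastargs-style command-line overrides of the form --section.param VALUE. As a quick orientation, here is a minimal, self-contained sketch of such an argument system; the section and parameter names are illustrative and are not the ones defined in this repository:

from argparse import ArgumentParser
from fastargs import Param, Section, get_current_config
from fastargs.decorators import param

# Declare a parameter section (illustrative names only).
Section('training', 'Training hyperparameters').params(
    lr=Param(float, 'learning rate', default=0.1),
    epochs=Param(int, 'number of epochs', default=16),
)

@param('training.lr')
@param('training.epochs')
def main(lr, epochs):
    print(f'lr={lr}, epochs={epochs}')

if __name__ == '__main__':
    config = get_current_config()
    parser = ArgumentParser()
    config.augment_argparse(parser)       # adds --config-file and the --training.* flags
    config.collect_argparse_args(parser)  # config-file values first, then CLI overrides
    config.validate(mode='stderr')
    main()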

Datasets

We study 9 commonly used transfer learning datasets and use FFCV to accelerate data loading and preprocessing. For most datasets, we provide the preprocessed data (.beton files) at this link; please download them and put them in the data folder. The remaining datasets are downloaded automatically by PyTorch.

For Flowers102, DTD, UCF101, Food101, EuroSAT, OxfordPets, StanfordCars, and SUN397, we use the dataset split configuration from CoOp. For the other datasets we use the official splits provided by PyTorch.
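If you need to build a .beton file yourself (e.g., for a split we do not provide), the sketch below shows how this can be done with FFCV's DatasetWriter. The dataset, paths, and field settings (such as max_resolution) are illustrative and may differ from those used to produce the released .beton files:

import os
import torchvision
from ffcv.writer import DatasetWriter
from ffcv.fields import RGBImageField, IntField

# Any indexed dataset returning (PIL image, integer label) pairs works here.
dataset = torchvision.datasets.OxfordIIITPet(root='./data/oxfordpets', split='trainval', download=True)

os.makedirs('./data/oxfordpets/ffcv', exist_ok=True)
writer = DatasetWriter('./data/oxfordpets/ffcv/train.beton', {
    'image': RGBImageField(max_resolution=256),  # illustrative resolution
    'label': IntField(),
}, num_workers=8)
writer.from_indexed_dataset(dataset)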

Code Structure

The source code is organized as follows:

  • configs: contains the default parameters for each dataset
  • src: contains the source code for the proposed methods
    • algorithm: contains the mathematical algorithms used for our method or the baselines
    • auxiliary: contains the executable files to generate the intermediate results, e.g., the pruned data, the image features, etc.
    • data: contains the data loader for each dataset
    • experiments: contains the main executable files to run the experiments
    • tools: contains the tools and utilities for the experiments
  • arguments: contains the data arguments

Usage

In this section, we provide the instructions to reproduce the results in our paper.

Pretrain on ImageNet

We first pretrain the surrogate model (ResNet-18) on ImageNet using the following command:

python src/experiment/imagenet_train_from_scratch.py --config-file configs/imagenet_train_from_scratch/rn18_16.json 

You can change the type of the surrogate model by changing the --network.architecture argument.

Prune the Source Dataset using LM

We then prune the source dataset from 10% to 90% (in steps of 10%) using LM with the following command:

python src/auxiliary/lm_selection_for_imagenet.py --cfg.data_path PATH_TO_DOWNSTREAM_TRAINING_DATA --cfg.source_train_label_path PATH_TO_IMAGENET_TRAINING_LABEL --cfg.source_val_label_path PATH_TO_IMAGENET_VALIDATION_LABEL --cfg.architecture resnet18 --cfg.pretrained_ckpt PATH_TO_PRETRAINED_CKPT --cfg.retain_class_nums 900,800,700,600,500,400,300,200,100 --cfg.write_path files/class_selection/oxfordpets

Please note that the first parameter is the path to the training data (.beton file) of the target dataset. The second and third parameters are the paths to the generated label indices for each ImageNet data sample; these files are downloaded automatically along with the ImageNet .beton files (see the Datasets section).

You can also generate your own label index files using the src/auxiliary/get_label_and_indices.py file.
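For intuition, LM scores each source (ImageNet) class by how often the surrogate model predicts it on the target training set and keeps the top-scoring classes. The snippet below is a conceptual sketch of this idea, not the exact implementation in src/auxiliary/lm_selection_for_imagenet.py:

import torch

@torch.no_grad()
def select_source_classes(surrogate, target_loader, num_source_classes=1000, retain=400, device='cuda'):
    # Count how often each source class is predicted on the target training data.
    surrogate.eval().to(device)
    votes = torch.zeros(num_source_classes, device=device)
    for images, _ in target_loader:  # target labels are not needed
        preds = surrogate(images.to(device)).argmax(dim=1)
        votes += torch.bincount(preds, minlength=num_source_classes).float()
    # Keep the source classes most frequently "voted" for by the target data.
    return votes.topk(retain).indices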

Prune the Source Dataset using FM

We can also prune the source dataset using FM. Unlike LM, FM first requires extracting the features of every data sample in both the source and the target dataset with the surrogate model. Below we provide an example of how to generate the features of the source dataset.

python src/auxiliary/feature_gen.py --cfg.data_path PATH_TO_IMAGENET_TRAINING_DATA --cfg.dataset imagenet --cfg.architecture resnet18 --cfg.pretrained_ckpt PATH_TO_PRETRAINED_CKPT --cfg.write_path PATH_TO_FEATURES
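Conceptually, feature generation records the surrogate model's penultimate-layer representation of every sample. A minimal sketch of this step is shown below; it assumes the checkpoint stores a plain ResNet-18 state_dict, and the actual feature_gen.py may use a different storage format:

import numpy as np
import torch
import torchvision

@torch.no_grad()
def extract_features(loader, ckpt_path, out_path, device='cuda'):
    # Load the surrogate and drop its classifier to expose the 512-d penultimate features.
    model = torchvision.models.resnet18()
    model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model.fc = torch.nn.Identity()
    model.eval().to(device)

    feats = []
    for images, _ in loader:
        feats.append(model(images.to(device)).cpu())
    np.save(out_path, torch.cat(feats).numpy())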

Next, with the features of the source and target dataset, we can prune the source dataset using FM with the following command:

python src/auxiliary/fm_selection_for_imagenet.py --dataset.src_train_fx_path PATH_TO_SOURCE_TRAINING_FEATURES --dataset.tgt_train_fx_path PATH_TO_TARGET_TRAINING_FEATURES --dataset.src_train_id_path PATH_TO_SOURCE_DATA_CLUSTER_MAPPING --dataset.src_val_id_path PATH_TO_TARGET_DATA_CLUSTER_MAPPING 

Note that the first two parameters are generated by src/auxiliary/feature_gen.py. The last two parameters are the clustering results, which indicate the cluster that each data sample belongs to.
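For intuition, FM groups the source features into clusters, measures how often target features fall into each cluster, and keeps the source samples from the clusters the target data uses most. The sketch below illustrates this idea with k-means; it is not the exact implementation, which consumes precomputed cluster mappings via the --dataset.*_id_path arguments:

import numpy as np
from sklearn.cluster import KMeans

def select_source_samples(src_feats, tgt_feats, n_clusters=1000, retain_clusters=400):
    # Cluster the source features and assign each target feature to its nearest cluster.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10).fit(src_feats)
    src_cluster_ids = kmeans.labels_             # cluster id of every source sample
    tgt_cluster_ids = kmeans.predict(tgt_feats)  # nearest source cluster per target sample

    # Keep the source samples belonging to the clusters most used by the target data.
    votes = np.bincount(tgt_cluster_ids, minlength=n_clusters)
    kept_clusters = np.argsort(votes)[::-1][:retain_clusters]
    kept_mask = np.isin(src_cluster_ids, kept_clusters)
    return np.nonzero(kept_mask)[0]              # indices of retained source samples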

Model Pretrain with Pruned Source Dataset

We then pretrain the larger model on the pruned source dataset obtained by either LM or FM, using the same file as for the surrogate model. The only differences are that we need to pass --dataset.prune 1 to indicate that the source dataset is pruned, and supply the selected training and testing data indices via --dataset.indices.training and --dataset.indices.testing. Below we provide an example of pretraining with the pruned source dataset obtained by LM, where ${cls_num} is the number of retained ImageNet classes chosen in the selection step (use --network.architecture to switch to a larger model such as ResNet-101).

python src/experiment/imagenet_train_from_scratch.py --config-file configs/imagenet_train_from_scratch/rn18_16.json --dataset.prune 1 --dataset.indices.training files/class_selection/oxfordpets_flm_train_top${cls_num}.indices --dataset.indices.testing files/class_selection/oxfordpets_flm_val_top${cls_num}.indices
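The .indices files enumerate the retained ImageNet samples. As an illustration of what pruned pretraining amounts to (the FFCV pipeline in this repository handles the subsetting internally; the flat-array file format and the .npy path below are assumptions made for the sake of the example):

import numpy as np
import torch
import torchvision

# Illustrative only: restrict ImageNet to the retained sample indices and train on the subset.
train_set = torchvision.datasets.ImageNet('./data/imagenet', split='train',
                                          transform=torchvision.transforms.ToTensor())
kept = np.load('files/class_selection/oxfordpets_flm_train_top400.indices.npy')  # hypothetical path/format
pruned_train_set = torch.utils.data.Subset(train_set, kept.tolist())
loader = torch.utils.data.DataLoader(pruned_train_set, batch_size=256, shuffle=True, num_workers=8)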

Downstream Finetune with the Pretrained Model

We then finetune the pretrained model on the target dataset. Below we provide an example of how to finetune the pretrained ResNet-101 on OxfordPets.

python src/experiment/imagenet_transfer_to_downstream.py --config-file configs/imagenet_transfer_to_downstream/oxfordpets_rn101_ff.json --dataset.train_path ./data/oxfordpets/ffcv/train_400_10_90.beton --dataset.test_path ./data/oxfordpets/ffcv/test_400_10_90.beton --network.pretrained_ckpt PATH_TO_PRETRAINED_CKPT --exp.identifier oxfordpets_rn101_ff
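Under the hood, this step loads the checkpoint pretrained on the pruned source dataset and retrains it on the target task. A minimal plain-PyTorch sketch of the setup is given below; it assumes the checkpoint stores a plain 1000-way ResNet-101 state_dict, and the hyperparameters are illustrative rather than the ones in the config file:

import torch
import torchvision

# Load the pruned-ImageNet checkpoint, then swap in a new head for OxfordPets (37 classes).
model = torchvision.models.resnet101()
model.load_state_dict(torch.load('PATH_TO_PRETRAINED_CKPT', map_location='cpu'))
model.fc = torch.nn.Linear(model.fc.in_features, 37)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
criterion = torch.nn.CrossEntropyLoss()
# ... a standard supervised training loop over the OxfordPets FFCV loaders follows ...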

Downstream Train from Scratch

We also provide the option to train the model from scratch on the target dataset. Below we provide an example of how to train the ResNet-101 from scratch on OxfordPets.

python src/experiment/downstream_train_from_scratch.py --config-file configs/downstream_train_from_scratch/oxfordpets_rn101.json --dataset.train_path ./data/oxfordpets/ffcv/train_400_10_90.beton --dataset.test_path ./data/oxfordpets/ffcv/test_400_10_90.beton

