diff-mix

Enhance Image Classification Via Inter-Class Image Mixup With Diffusion Model

Paper PDF

Introduction 👋

This repository implements several generative data augmentation strategies that use Stable Diffusion to create synthetic datasets for improving image classification.

Requirements

The key packages and their versions are listed below. The code is tested on a single node with 4 NVIDIA RTX3090 GPUs.

torch==2.0.1+cu118
diffusers==0.25.1
transformers==4.36.2
datasets==2.16.1
accelerate==0.26.1
numpy==1.24.4

Datasets

For convenience, the well-structured datasets hosted on Hugging Face can be used. The fine-grained datasets we experimented with, CUB and Aircraft, can be downloaded from Multimodal-Fatima/CUB_train and Multimodal-Fatima/FGVC_Aircraft_train, respectively. If you run into network connection problems during training, pre-download the data from the website and set HUG_LOCAL_IMAGE_TRAIN_DIR in dataset/instance/cub.py to the local path where the data was saved.
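
The pre-download step could look roughly like the sketch below. It assumes the Hugging Face datasets library from the requirements list; the helper names predownload and cache_dir_for and the data/hf_cache directory are illustrative and not part of the repository. The sub-folder that the library creates inside the cache depends on its version, so inspect the directory before setting HUG_LOCAL_IMAGE_TRAIN_DIR.

```python
from pathlib import Path

# Hub repository IDs named in this README.
HUB_REPOS = {
    "cub": "Multimodal-Fatima/CUB_train",
    "aircraft": "Multimodal-Fatima/FGVC_Aircraft_train",
}


def cache_dir_for(dataset: str, cache_root: str = "data/hf_cache") -> Path:
    """Return the (hypothetical) local cache directory for one dataset."""
    if dataset not in HUB_REPOS:
        raise KeyError(f"unknown dataset: {dataset!r}")
    return Path(cache_root) / dataset


def predownload(dataset: str, cache_root: str = "data/hf_cache") -> Path:
    """Download a dataset once so that later runs can read it from disk."""
    # Imported lazily so the helpers above work without the library installed.
    from datasets import load_dataset

    target = cache_dir_for(dataset, cache_root)
    # cache_dir tells the library where to store the downloaded files.
    load_dataset(HUB_REPOS[dataset], cache_dir=str(target))
    return target
```

Calling predownload("cub") on a machine with network access fills data/hf_cache/cub, which can then be copied to the training node.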

Fine-tune on a dataset 🔥

Pre-trained LoRA weights

We provide LoRA weights fine-tuned on the full datasets so that results on the given datasets can be reproduced quickly. Download them using the links in the table below and unzip the file into the ckpts directory. The file structure should look like:

ckpts
├── cub
│   └── shot-1-lora-rank10
│       ├── learned_embeds-steps-last.bin
│       └── pytorch_lora_weights.safetensors
└── put_finetuned_ckpts_here.txt

Dataset  | Data                     | Ckpts (full-shot)
CUB      | huggingface (train/test) | google drive
Flower   | official website         | google drive
Aircraft | huggingface (train/test) | google drive
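
After unzipping, a quick check that a checkpoint directory contains the two files shown in the tree can catch a misplaced archive early. This is a minimal sketch; the function name missing_ckpt_files is hypothetical and not part of the repository.

```python
from pathlib import Path

# File names taken from the directory tree in this README.
EXPECTED_FILES = (
    "learned_embeds-steps-last.bin",
    "pytorch_lora_weights.safetensors",
)


def missing_ckpt_files(ckpt_dir: str) -> list:
    """Return the expected file names that are absent from ckpt_dir."""
    root = Path(ckpt_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]


if __name__ == "__main__":
    missing = missing_ckpt_files("ckpts/cub/shot-1-lora-rank10")
    print("missing files:", missing or "none")
```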

Customized fine-tuning

The scripts/finetune.sh script lets users fine-tune on their own datasets. By default, it implements a fine-tuning strategy that combines DreamBooth and Textual Inversion. Set the examples_per_class argument to fine-tune the model on a dataset with {examples_per_class} shots. Fine-tuning takes about 4 hours on 4 RTX3090 GPUs for the full-shot CUB dataset.

MODEL_NAME="runwayml/stable-diffusion-v1-5"
DATASET='cub'
SHOT=-1 # set -1 for full shot
OUTPUT_DIR="ckpts/${DATASET}/shot${SHOT}-lora-rank10"

accelerate launch --mixed_precision='fp16' --main_process_port 29507 \
    train_lora.py \
    --pretrained_model_name_or_path=$MODEL_NAME \
    --dataset_name=$DATASET \
    --resolution=224 \
    --random_flip \
    --max_train_steps=35000 \
    --num_train_epochs=10 \
    --checkpointing_steps=5000 \
    --learning_rate=5e-05 \
    --lr_scheduler='constant' \
    --lr_warmup_steps=0 \
    --seed=42 \
    --rank=10 \
    --local_files_only \
    --examples_per_class $SHOT  \
    --train_batch_size 2 \
    --output_dir=$OUTPUT_DIR \
    --report_to='tensorboard'
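
To illustrate what an {examples_per_class}-shot subset means, here is a hedged sketch of per-class subsampling. It is not the repository's implementation; the function few_shot_indices and its seeding scheme are assumptions made for illustration, with -1 treated as the full-shot setting as in the script.

```python
import random
from collections import defaultdict


def few_shot_indices(labels, examples_per_class, seed=42):
    """Pick up to examples_per_class sample indices for every label.

    A value of -1 keeps every sample (the full-shot setting).
    """
    if examples_per_class == -1:
        return list(range(len(labels)))

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    chosen = []
    for label in sorted(by_class):
        indices = by_class[label]
        rng.shuffle(indices)
        # Classes with fewer samples than requested contribute all of them.
        chosen.extend(indices[:examples_per_class])
    return sorted(chosen)
```

With labels [0, 0, 0, 1, 1, 2] and examples_per_class set to 2, the subset holds two samples from class 0, two from class 1, and the single sample from class 2.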

Construct synthetic data

scripts/sample.sh provides a script that synthesizes augmented images with multiple processes. Each item in GPU_IDS denotes a process running on the GPU with that index. The simplified command for sampling a $5\times$ synthetic subset in an inter-class translation manner (diff-mix) with strength $s=0.7$ is:

DATASET='cub'
# set -1 for full shot
SHOT=-1
FINETUNED_CKPT="ckpts/cub/shot${SHOT}-lora-rank10"
# ['diff-mix', 'diff-aug', 'diff-gen', 'real-mix', 'real-aug', 'real-gen', 'ti_mix', 'ti_aug']
SAMPLE_STRATEGY='diff-mix'
STRENGTH=0.7
# ['fixed', 'uniform']. 'fixed': use fixed $STRENGTH, 'uniform': sample from [0.3, 0.5, 0.7, 0.9]
STRENGTH_STRATEGY='fixed'
# expand the dataset by 5 times
MULTIPLIER=5
# spawn 4 processes
GPU_IDS=(0 1 2 3)

python scripts/sample_mp.py \
--model-path='runwayml/stable-diffusion-v1-5' \
--output_root='outputs/aug_samples' \
--dataset=$DATASET \
--finetuned_ckpt=$FINETUNED_CKPT \
--syn_dataset_mulitiplier=$MULTIPLIER \
--strength_strategy=$STRENGTH_STRATEGY \
--sample_strategy=$SAMPLE_STRATEGY \
--examples_per_class=$SHOT \
--resolution=512 \
--batch_size=1 \
--aug_strength=$STRENGTH \
--gpu-ids=${GPU_IDS[@]}
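
The two strength strategies and the per-GPU split can be sketched as follows. This is an illustration under assumptions, not the code in scripts/sample_mp.py: the function names pick_strength and plan_tasks are hypothetical, and the round-robin assignment of tasks to GPUs is only one plausible way to distribute the work.

```python
import random

# Candidate strengths for the 'uniform' strategy, as listed in the script comment.
UNIFORM_CHOICES = (0.3, 0.5, 0.7, 0.9)


def pick_strength(strategy, fixed_strength, rng=None):
    """Return the translation strength for one synthetic sample."""
    if strategy == "fixed":
        return fixed_strength
    if strategy == "uniform":
        rng = rng or random
        return rng.choice(UNIFORM_CHOICES)
    raise ValueError(f"unknown strength strategy: {strategy!r}")


def plan_tasks(num_real_images, multiplier, gpu_ids):
    """Assign multiplier * num_real_images sampling tasks to GPUs.

    Assumption: tasks are dealt out round-robin, one worker per GPU id.
    """
    total = num_real_images * multiplier
    plan = {gpu: [] for gpu in gpu_ids}
    for task in range(total):
        plan[gpu_ids[task % len(gpu_ids)]].append(task)
    return plan
```

For example, plan_tasks(10, 5, [0, 1, 2, 3]) spreads 50 sampling tasks over four workers, and pick_strength("fixed", 0.7) always returns 0.7.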

The synthetic output directory will be located at aug_samples/cub/shot-1_diff-mix_fixed_0.7. To create a 5-shot setting, set the examples_per_class argument to 5; the output directory will then be aug_samples/cub/shot5_diff-mix_fixed_0.7. Make sure that the finetuned_ckpt was also fine-tuned under the same 5-shot setting.
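
The directory name can be derived from the sampling arguments. The sketch below follows the naming pattern quoted in the comment of scripts/classification.sh, "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"; the helper name syn_dir_name is hypothetical.

```python
def syn_dir_name(examples_per_class, sample_strategy, strength_strategy, aug_strength):
    """Build the synthetic-data directory name from the sampling arguments."""
    return (
        f"shot{examples_per_class}_{sample_strategy}"
        f"_{strength_strategy}_{aug_strength}"
    )


if __name__ == "__main__":
    # Full-shot and 5-shot diff-mix runs with a fixed strength of 0.7.
    print(syn_dir_name(-1, "diff-mix", "fixed", 0.7))
    print(syn_dir_name(5, "diff-mix", "fixed", 0.7))
```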

Downstream classification

After completing the sampling process, you can integrate the synthetic data into downstream classification and initiate training using the script scripts/classification.sh:

GPU=1
DATASET="cub"
SHOT=-1
# "shot{args.examples_per_class}_{args.sample_strategy}_{args.strength_strategy}_{args.aug_strength}"
SYNDATA_DIR="aug_samples/cub/shot${SHOT}_diff-mix_fixed_0.7" # shot-1 denotes full shot
SYNDATA_P=0.1
GAMMA=0.8

python downstream_tasks/train_hub.py \
    --dataset $DATASET \
    --syndata_dir $SYNDATA_DIR \
    --syndata_p $SYNDATA_P \
    --model "resnet50" \
    --gamma $GAMMA \
    --examples_per_class $SHOT \
    --gpu $GPU \
    --amp 2 \
    --note $(date +%m%d%H%M) \
    --group_note "fullshot" \
    --nepoch 120 \
    --res_mode 224 \
    --lr 0.05 \
    --seed 0 \
    --weight_decay 0.0005 
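
This README does not define SYNDATA_P or GAMMA. The sketch below assumes, purely for illustration, that SYNDATA_P is the per-sample probability of drawing a synthetic image in place of a real one during training; GAMMA is left out because its meaning is not stated here. The function mix_real_and_synthetic is hypothetical and is not taken from downstream_tasks/train_hub.py.

```python
import random


def mix_real_and_synthetic(real_items, synthetic_items, syndata_p, seed=0):
    """Build one epoch in which each slot is synthetic with probability syndata_p.

    Assumption: replacement is decided independently for every real sample.
    """
    rng = random.Random(seed)
    epoch = []
    for item in real_items:
        if synthetic_items and rng.random() < syndata_p:
            # Swap the real sample for a randomly chosen synthetic one.
            epoch.append(("synthetic", rng.choice(synthetic_items)))
        else:
            epoch.append(("real", item))
    return epoch
```

Under this assumption, SYNDATA_P=0.1 would mean that roughly one training sample in ten comes from the synthetic directory.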

We also provide scripts for the robustness test and for long-tail classification in scripts/classification_waterbird.sh and scripts/classification_imb.sh, respectively.

Acknowledgements

This project is built upon the Da-fusion and diffusers repositories. Special thanks to their contributors.

diff-mix's Issues

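Question about the dataset path

Hello author, I ran into a problem with the dataset path while reproducing your code.
The files in the Hugging Face repository you provided differ somewhat from the path given in dataset/cub.py. After downloading the dataset repository you linked, I could not find a Multimodal-Fatima___parquet folder. Line 25 sets HUG_LOCAL_IMAGE_TRAIN_DIR = Multimodal-Fatima___parquet/Multimodal-Fatima--CUB_train-bc20d158956ded0c; which file in the Hugging Face repository does this correspond to?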

Reproduce results

Hello,

I'd like to reproduce the results.
I just want to make sure I use the correct params: how can I know the exact params to use for each scenario? Are the default ones in scripts/classification.sh the ones to use?

The split in the training process

Hi, great work!

I wanna ask one question, for the conventional image classification task, do you use the validation set or directly train and get the best performance on the test set?
