
Discover and Cure: Concept-aware Mitigation of Spurious Correlation (ICML 2023)

License: MIT

Reference

If you found this code/work useful in your own research, please consider citing the following:

@inproceedings{wu23disc,
    title={Discover and Cure: Concept-aware Mitigation of Spurious Correlation},
    author={Shirley Wu and Mert Yuksekgonul and Linjun Zhang and James Zou},
    booktitle={ICML},
    year={2023},
}

Overview

What is DISC?

DISC is an algorithm for image classification tasks that adaptively discovers and removes spurious correlations during model training, using a concept bank generated by Stable Diffusion.

Why DISC?

  • 🔑 Effectively removes strong spurious correlations and makes models generalize better! Go for the green decision boundary!

  • 🔎 No more ambiguous interpretations! DISC tells you exactly which attributes contribute to the spurious correlation and how significant their contributions are.

  • 🌱 Monitor how models learn spurious correlations!



How does DISC do it?



  • Build a concept bank with multiple categories.
  • In each iteration, discover spurious concepts by computing concept sensitivity.
  • In each iteration, mix up concept images with the training dataset guided by the concept sensitivity, and update model parameters on the balanced dataset.
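
The per-iteration loop above can be sketched in a few lines. This is a toy NumPy sketch, not the repository's implementation: `concept_sensitivity` and `mix_with_concepts` are hypothetical simplifications of the paper's gradient-based sensitivity and sensitivity-guided mixup.

```python
import numpy as np

rng = np.random.default_rng(0)

def concept_sensitivity(grad, concept_vectors):
    # Toy sensitivity: magnitude of alignment between the loss gradient
    # and each concept direction (a hypothetical stand-in for the
    # paper's concept sensitivity metric).
    return np.abs(concept_vectors @ grad)

def mix_with_concepts(x, concept_imgs, sens, lam=0.3):
    # Sample one concept image per training example, with probability
    # proportional to concept sensitivity, then mix them up.
    p = sens / sens.sum()
    idx = rng.choice(len(concept_imgs), size=len(x), p=p)
    return (1 - lam) * x + lam * concept_imgs[idx]

# Toy data: 4 training "images" and 3 concept "images", flattened to 8 dims.
x = rng.normal(size=(4, 8))
concepts = rng.normal(size=(3, 8))
grad = rng.normal(size=8)

sens = concept_sensitivity(grad, concepts)
mixed = mix_with_concepts(x, concepts, sens)
print(mixed.shape)  # prints (4, 8)
```

In the actual method the mixing is applied per discovered spurious concept and the model is then updated on the rebalanced batch; see the paper for the real procedure.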

See our paper for details!


Get Started

Installation

See requirements.txt or install the environment via

conda create -n disc python=3.9
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch
pip install scikit-learn transformers wilds umap-learn diffusers nltk
pip install gdown # Used for data download (tarfile and zipfile ship with Python)

Data download

(Recommended) Download all the datasets via the commands below!

cd disc
python download_datasets.py

Manual download (if the automatic download fails):

  • MetaShift: Download the dataset from here. Unzipping it should produce a folder metashifts, which should be moved to $ROOT/data/metashifts, where $ROOT is your root directory.

  • Waterbirds: Download the dataset from here. Unzipping this should result in a folder waterbird_complete95_forest2water2. Place this folder under $ROOT/data/cub/.

  • FMoW: Dataset download is automatic and can be found in $ROOT/data/fmow/fmow_v1.1. We recommend following the setup instructions provided by the official WILDS website.

  • ISIC: Download the dataset from here. Unzipping it should produce a folder isic, which should be moved to $ROOT/data/isic, where $ROOT is your root directory.
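
The unzip-and-place step for a manually downloaded archive can be scripted with the standard library. A minimal sketch, where `install_dataset` and the throwaway demo archive are illustrative placeholders rather than part of the repo:

```python
import tempfile
import zipfile
from pathlib import Path

def install_dataset(zip_path, root, subdir=""):
    # Unzip a manually downloaded archive so the extracted folder
    # lands under $ROOT/data/<subdir>/, matching the layout above.
    dest = Path(root) / "data" / subdir
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as zf:
        zf.extractall(dest)
    return dest

# Demo with a throwaway archive standing in for the real download.
tmp = Path(tempfile.mkdtemp())
demo_zip = tmp / "metashifts.zip"
with zipfile.ZipFile(demo_zip, "w") as zf:
    zf.writestr("metashifts/README.txt", "placeholder")

dest = install_dataset(demo_zip, tmp / "ROOT")
print((dest / "metashifts").exists())  # prints True
```

For Waterbirds you would pass subdir="cub" so the extracted folder ends up under $ROOT/data/cub/.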

Prepare Concept Bank

(Recommended) Download the concept bank we have already generated via the commands below!

cd concept_bank
python download.py

Manual generation (for customizing your own concept bank):

  • Define the concept bank in synthetic_concepts/metadata.json
  • Run the generation using Stable Diffusion v1-4:
    cd concept_bank
    python generate_concept_bank.py --n_samples 200 
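
As a rough illustration, a category-to-concepts mapping of the kind metadata.json holds might look like the following. The keys and concept names here are hypothetical; see synthetic_concepts/metadata.json in the repo for the real schema and concepts.

```json
{
  "nature": ["forest", "water", "sky"],
  "indoor": ["sofa", "shelf", "carpet"],
  "texture": ["stripes", "dots"]
}
```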
    

Run the Training Process

ERM

We provide commands under the scripts folder. For example, to train an ERM model on MetaShift:

SEED=0
ROOT=./DISC # Set your code root here
python run_expt.py \
-s confounder -d MetaDatasetCatDog -t cat -c background --lr 0.001 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100 --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog --save_best --save_last --seed $SEED 

DISC

We provide commands under the scripts folder. For example, given a trained ERM model on MetaShift, you can train the DISC model via:

SEED=0
N_CLUSTERS=2
ROOT=./DISC # Set your code root here
python run_expt.py \
-s confounder -d MetaDatasetCatDog -t cat -c background --lr 0.0005 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100  --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog \
--erm_path <path to the ERM model, ending with .pth> \
--concept_img_folder $ROOT/synthetic_concepts --concept_categories everything \
--n_clusters $N_CLUSTERS --augment_data --save_last --save_best --seed $SEED --disc 

Contact Us

Feel free to create an issue under this repo or contact [email protected] if you have any questions!


Issues

Problem of results on MetaShift

Dear authors,
Thank you so much for your work! I have tried the code in your scripts for both the ERM and DISC models. Here is exactly what I ran:

SEED=0
ROOT=/data/data/DISC # Set your code root here
python run_expt.py \
-shift_type confounder -dataset MetaDatasetCatDog -target_name cat -confounder_names background --lr 0.001 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100 --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog --save_best --save_last --seed $SEED 

SEED=0
N_CLUSTERS=2
ROOT=/data/data/DISC # Set your code root here
CUDA_VISIBLE_DEVICES=1 python run_expt.py \
--shift_type confounder --dataset MetaDatasetCatDog --target_name cat --confounder_names background \
--lr 0.0005 --batch_size 16 \
--weight_decay 0.0001 --model resnet50 --n_epochs 100  --log_dir $ROOT/output/ \
--root_dir $ROOT/data/metashifts/MetaDatasetCatDog \
--erm_path output/MetaDatasetCatDog/ERM/reweight_groups=0-augment=0-lr=0.001-batch_size=16-n_epochs=100-seed=0/best_model.pth \
--concept_img_folder $ROOT/synthetic_concepts --concept_categories everything \
--n_clusters $N_CLUSTERS --augment_data --save_last --save_best --seed $SEED --disc 

However, the results shown in the output (below) are quite different from Table 1, where ERM should achieve 72.9% and DISC 75.5% average accuracy.

ERM:
worst_group_acc     0.686567
mean_differences    0.155132
group_avg_acc       0.764133
avg_acc             0.773913

DISC:
worst_group_acc     0.687259
mean_differences    0.034134
group_avg_acc       0.704326
avg_acc             0.702174

Could you please help me with this? Also, do you have pretrained weights for ERM and DISC, and the corresponding evaluation results?

Thanks,

ISIC experiment

Hi @Wuyxin, thank you for sharing your great work.

I ran GroupDRO and other methods on the ISIC dataset. Could you please explain more about how you pick the model for evaluation on the test set (e.g., based on validation AUC or worst-group accuracy)?

Training issue - name 'is_training' is not defined

When trying to run the training, I hit the following error:
disc/disc/utils/loss.py", line 198, in get_stats
group_str = self.group_str(idx, is_training) if 'Meta' in self.args.dataset else self.group_str(idx)
NameError: name 'is_training' is not defined

Can you assist? This is indeed not defined, so I am unsure how it worked for you.
Thank you!
