
Feasible Counterfactual Explanations

Code accompanying the paper Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers, selected for an Oral Spotlight at the NeurIPS 2019 Workshop on Machine Learning and Causal Inference for Improved Decision Making.

DiCE

This work is also being integrated with DiCE, an open-source library for explaining ML models. Please check this tutorial and follow DiCE for updates.

Cite

@article{mahajan2019preserving,
  title={Preserving Causal Constraints in Counterfactual Explanations for Machine Learning Classifiers},
  author={Mahajan, Divyat and Tan, Chenhao and Sharma, Amit},
  journal={arXiv preprint arXiv:1912.03277},
  year={2019}
}

Code Structure

generativecf

Contains the code for experiments on the Simple-BN, Sangiovese, and Adult datasets

generativecf-mnist

Contains the code for experiments on MNIST

generativecf

  • models/

    • Contains pre-trained models for the different methods across datasets
  • data/

    • Contains the processed data files for all the datasets; download the data files from this link
  • master_evalute.py

    • Uses the pre-trained models (models/) and datasets (data/) to reproduce the results reported in the paper. The results are stored in the directory /results

    • It also generates the file 'plot_dict.json' in the directory r_plots/. For better graphs, convert it to a plotdf.csv file and then run the 'plot_figures.R' script; the graphs are stored in the directory /results

  • base-generative-cf.py

    • Implementation of BaseGenCF for all datasets

    • Usage: python3 base-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --margin 0.1 --validity_reg 10

  • ae-base-generative-cf.py

    • Implementation of AEGenCF for all datasets

    • Usage: python3 ae-base-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --ae_path bn1-64-50-target-class--1-auto-encoder.pth --margin 0.1 --validity_reg 10 --ae_reg 10

  • oracle-generative-cf.py

    • Implementation of OracleGenCF for all datasets

    • Usage: python3 oracle-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth --oracle_data bn1-fine-tune-size-100-upper-lim-10-good-cf-set.json --margin 0.1 --validity_reg 10 --oracle_reg 10

  • model-approx-generative-cf.py

    • Implementation of ModelApproxGenCF for Simple-BN dataset

    • Usage: python3 model-approx-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --ae_reg 0 --ae_path bn1-64-50-target-class--1-auto-encoder.pth --margin 0.1 --validity_reg 10 --constraint_reg 10

  • model-approx-generative-cf-bnlearn.py

    • Implementation of ModelApproxGenCF for Sangiovese dataset

    • Usage: python3 model-approx-generative-cf-bnlearn.py --htune 0 --batch_size 512 --epoch 50 --dataset_name sangiovese --ae_reg 0 --ae_path sangiovese-512-50-target-class--1-auto-encoder.pth --margin 0.1 --validity_reg 10 --constraint_reg 10 --constrained_nodes 'BunchN'

  • unary-const-generative-cf.py

    • Implementation of ModelApproxGenCF for the Adult dataset, C1 constraint (Non-Decreasing Age)

    • Usage: python3 unary-const-generative-cf.py --htune 0 --batch_size 2048 --epoch 50 --dataset_name adult --margin 0.1 --validity_reg 10 --constraint_reg 10

  • unary-ed-const-generative-cf.py

    • Implementation of ModelApproxGenCF for the Adult dataset, C2 constraint (Age-Ed Causal Constraint)

    • Usage: python3 unary-ed-const-generative-cf.py --htune 0 --batch_size 2048 --epoch 50 --dataset_name adult --margin 0.1 --validity_reg 10 --constraint_reg 10

  • scm-generative-cf.py

    • Implementation of SCMGenCF for Simple-BN dataset

    • Usage: python3 scm-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --margin 0.1 --validity_reg 10 --scm_reg 10

  • scm-generative-cf-bnlearn.py

    • Implementation of SCMGenCF for Sangiovese dataset

    • Usage: python3 scm-generative-cf-bnlearn.py --htune 0 --batch_size 512 --epoch 50 --dataset_name sangiovese --validity_reg 10 --scm_reg 10 --constraint_node 'BunchN'

  • contrastive_explanations.py

    • Implementation of CEM for all datasets

    • Usage: python3 contrastive_explanations.py --dataset_name bn1 --htune 0 --train_case_pred 0 --train_case_ae 0 --explain_case 1 --sample_size 3 --timeit 0 --c_init 10 --max_iterations 1000 --beta 0.1 --kappa 0.1 --gamma 1 --c_steps 2

  • timeit-base-generative-cf.py

    • Computes the training and evaluation time of BaseGenCF

    • Usage: python3 timeit-base-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --margin 0.1 --validity_reg 10 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth

  • timeit-oracle-generative-cf.py

    • Computes the training and evaluation time of Example-based CF

    • Usage: python3 timeit-oracle-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth --oracle_data bn1-fine-tune-size-100-upper-lim-10-good-cf-set.json --margin 0.1 --validity_reg 10 --oracle_reg 10
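The --ae_path, --cf_path and --oracle_data values in the usage lines above appear to encode the flags of the run that produced each file. The helpers below rebuild those names under that assumption; the patterns are inferred only from the example filenames in this README, and the function names are hypothetical, not part of the repository. Note that the example cf_path (margin 0.014, validity_reg 54.0) differs from the --margin 0.1 --validity_reg 10 values in the BaseGenCF usage line, which suggests it refers to a tuned, pre-trained checkpoint.

```python
# Hypothetical helpers: the naming patterns are inferred from the example
# filenames in this README, not taken from the repository's source code.


def ae_checkpoint_name(dataset: str, batch_size: int, epoch: int, target_class: int = -1) -> str:
    """Auto-encoder checkpoint written by auto-encoder-train.py (assumed pattern)."""
    # target_class=-1 produces the double hyphen seen in "target-class--1"
    return f"{dataset}-{batch_size}-{epoch}-target-class-{target_class}-auto-encoder.pth"


def base_gen_checkpoint_name(dataset: str, margin: float, validity_reg: float, epoch: int) -> str:
    """BaseGenCF checkpoint passed as --cf_path (assumed pattern)."""
    # validity_reg is rendered as a float (e.g. 54.0), matching the example names
    return f"{dataset}-margin-{margin}-validity_reg-{float(validity_reg)}-epoch-{epoch}-base-gen.pth"


def good_cf_set_name(dataset: str, fine_tune_size: int, upper_limit: int) -> str:
    """Labelled-query file from good-cf-set-gen.py, passed as --oracle_data (assumed pattern)."""
    return f"{dataset}-fine-tune-size-{fine_tune_size}-upper-lim-{upper_limit}-good-cf-set.json"


if __name__ == "__main__":
    print(ae_checkpoint_name("bn1", 64, 50))             # bn1-64-50-target-class--1-auto-encoder.pth
    print(base_gen_checkpoint_name("bn1", 0.014, 54, 50))  # bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth
    print(good_cf_set_name("bn1", 100, 10))              # bn1-fine-tune-size-100-upper-lim-10-good-cf-set.json
```

The same patterns also reproduce the Sangiovese names used in this README (sangiovese-512-50-target-class--1-auto-encoder.pth and sangiovese-margin-0.161-validity_reg-94.0-epoch-50-base-gen.pth).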

generativecf/scripts/

  • blackboxmodel.py

    • Contains the architecture of the ML model to be explained across datasets
  • vae_model.py

    • Contains the architecture of the BaseGenCF and AutoEncoder models
  • blackbox-model-train.py

    • Trains the ML model to be explained across datasets

    • Usage: python3 blackbox-model-train.py bn1

  • auto-encoder-train.py

    • Trains the auto-encoder model used in AEGenCF and for computing the IM metric

    • Usage: python3 auto-encoder-train.py --dataset_name bn1 --batch_size 64 --epoch 50 --target_class -1

  • good-cf-set-gen.py

    • Contains the code for generating labelled queries for OracleGenCF for the Simple-BN and Adult datasets

    • Usage: python3 good-cf-set-gen.py --dataset_name bn1 --fine_tune_size 100 --upper_limit 10 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth

  • good-cf-set-gen-bnlearn.py

    • Contains the code for generating labelled queries for OracleGenCF for the Sangiovese dataset

    • Usage: python3 good-cf-set-gen-bnlearn.py --dataset_name sangiovese --fine_tune_size 100 --upper_limit 10 --cf_path sangiovese-margin-0.161-validity_reg-94.0-epoch-50-base-gen.pth --constraint_node BunchN

  • datagen.py

    • Creates the train, validation, and test splits, along with other processed data, for all datasets

    • Usage: python3 datagen.py bn1

  • evaluation_functions.py

    • Contains evaluation metrics such as Target-Class Validity and Constraint Feasibility Score for all datasets
  • bnlearn_parser.py

    • Reads the sangiovese-scm.txt and creates the SCM
  • helpers.py

    • Contains code for generating the Adult dataset
  • sangiovese-data-gen.py

    • Contains code for processing the Sangiovese dataset
  • simple-bn-gen.py

    • Contains the code for generating the Simple-BN dataset
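evaluation_functions.py reports metrics such as Target-Class Validity and the Constraint Feasibility Score. Below is a minimal pure-Python sketch of two of them, assuming validity is the percentage of counterfactuals that the black-box model assigns to the desired class, and feasibility for the Adult C1 constraint is the percentage of counterfactuals whose age does not decrease relative to the input. The function names are hypothetical, and the repository's implementation may differ, for example by computing feasibility only over valid counterfactuals.

```python
from typing import Sequence


def target_class_validity(cf_preds: Sequence[int], desired: Sequence[int]) -> float:
    """Percentage of counterfactuals whose predicted class equals the desired class."""
    if len(cf_preds) != len(desired):
        raise ValueError("cf_preds and desired must have the same length")
    if len(cf_preds) == 0:
        raise ValueError("no counterfactuals to evaluate")
    hits = sum(1 for pred, want in zip(cf_preds, desired) if pred == want)
    return 100.0 * hits / len(cf_preds)


def non_decreasing_age_score(x_age: Sequence[float], cf_age: Sequence[float]) -> float:
    """Percentage of counterfactuals satisfying C1: the age feature must not decrease."""
    if len(x_age) != len(cf_age):
        raise ValueError("x_age and cf_age must have the same length")
    if len(x_age) == 0:
        raise ValueError("no counterfactuals to evaluate")
    ok = sum(1 for age, cf in zip(x_age, cf_age) if cf >= age)
    return 100.0 * ok / len(x_age)


if __name__ == "__main__":
    # Toy values for illustration only, not results from the paper
    print(target_class_validity([1, 1, 0, 1], [1, 1, 1, 1]))  # 75.0
    print(non_decreasing_age_score([30, 40, 50], [31, 39, 50]))
```

If ages are min-max normalized, the C1 check gives the same result as on raw ages, because that scaling preserves order.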

generativecf-mnist

The files follow the same structure as the generativecf files described above; the only difference is that evaluation is done on the MNIST dataset.
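To reproduce the Simple-BN (bn1) experiments from scratch rather than from the pre-trained models, the generativecf steps can be chained in order. The sketch below only reuses the usage lines from this README; the ordering is inferred from which files each step consumes, and it assumes the commands are launched from generativecf/ with the helper scripts under scripts/ (the scripts may expect a different working directory). The --cf_path and --oracle_data values point to the files named in the README, so a BaseGenCF trained with other hyperparameters would need its own path. By default it only prints the commands.

```python
import shlex
import subprocess
from typing import List

# Usage lines copied from this README for the bn1 dataset. The ORDER is an
# assumption inferred from data dependencies (data -> black box -> auto-encoder
# -> BaseGenCF / AEGenCF -> labelled queries -> OracleGenCF); the "scripts/"
# prefix assumes the commands are launched from the generativecf/ directory.
BN1_PIPELINE: List[str] = [
    "python3 scripts/datagen.py bn1",
    "python3 scripts/blackbox-model-train.py bn1",
    "python3 scripts/auto-encoder-train.py --dataset_name bn1 --batch_size 64 --epoch 50 --target_class -1",
    "python3 base-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --margin 0.1 --validity_reg 10",
    "python3 ae-base-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --ae_path bn1-64-50-target-class--1-auto-encoder.pth --margin 0.1 --validity_reg 10 --ae_reg 10",
    "python3 scripts/good-cf-set-gen.py --dataset_name bn1 --fine_tune_size 100 --upper_limit 10 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth",
    "python3 oracle-generative-cf.py --htune 0 --batch_size 64 --epoch 50 --dataset_name bn1 --cf_path bn1-margin-0.014-validity_reg-54.0-epoch-50-base-gen.pth --oracle_data bn1-fine-tune-size-100-upper-lim-10-good-cf-set.json --margin 0.1 --validity_reg 10 --oracle_reg 10",
]


def pipeline_argv(commands: List[str]) -> List[List[str]]:
    """Split each command string into an argv list (no shell is involved)."""
    return [shlex.split(cmd) for cmd in commands]


def run_pipeline(commands: List[str], dry_run: bool = True) -> None:
    """Print each step (dry run) or execute it, stopping at the first failing step."""
    for argv in pipeline_argv(commands):
        if dry_run:
            print(" ".join(shlex.quote(arg) for arg in argv))
        else:
            # check=True raises CalledProcessError if a step exits non-zero
            subprocess.run(argv, check=True)


if __name__ == "__main__":
    run_pipeline(BN1_PIPELINE, dry_run=True)
```

Passing argv lists to subprocess.run avoids shell quoting issues, and check=True makes a failed step stop the chain instead of feeding missing files to later steps.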
