Code Monkey home page Code Monkey logo

jonathancrabbe / simplex Goto Github PK

View Code? Open in Web Editor NEW
22.0 3.0 8.0 1.01 MB

This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.

License: Apache License 2.0

Python 65.48% Jupyter Notebook 34.52%
machine-learning artificial-intelligence explainable-ai latent-space latent-representations interpretability method example-based-explanations feature-importance neurips-2021

simplex's Introduction

SimplEx - Explaining Latent Representations with a Corpus of Examples

Tests pdf License: Apache 2.0

image

Code Author: Jonathan Crabbé ([email protected])

This repository contains the implementation of SimplEx, a method to explain the latent representations of black-box models with the help of a corpus of examples. For more details, please read our NeurIPS 2021 paper: 'Explaining Latent Representations with a Corpus of Examples'.

🚀 Installation

The library can be installed from PyPI using

$ pip install simplexai

or from source, using

$ pip install .

Toy example

Bellow, you can find a toy demonstration where we make a corpus decomposition of test examples representations. All the relevant code can be found in the file simplex.

from simplexai.explainers.simplex import Simplex
from simplexai.models.image_recognition import MnistClassifier
from simplexai.experiments.mnist import load_mnist

# Get a model
model = MnistClassifier() # Model should have the BlackBox interface

# Load corpus and test inputs
corpus_loader = load_mnist(subset_size=100, train=True, batch_size=100) # MNIST train loader
test_loader = load_mnist(subset_size=10, train=True, batch_size=10) # MNIST test loader
corpus_inputs, _ = next(iter(corpus_loader)) # A tensor of corpus inputs
test_inputs, _ = next(iter(test_loader)) # A set of inputs to explain

# Compute the corpus and test latent representations
corpus_latents = model.latent_representation(corpus_inputs).detach()
test_latents = model.latent_representation(test_inputs).detach()

# Initialize SimplEX, fit it on test examples
simplex = Simplex(corpus_examples=corpus_inputs,
                  corpus_latent_reps=corpus_latents)
simplex.fit(test_examples=test_inputs,
            test_latent_reps=test_latents,
            reg_factor=0)

# Get the weights of each corpus decomposition
weights = simplex.weights

We get a tensor weights that can be interpreted as follows: weights[i,c] = weight of corpus example c in the decomposition of example i.

We can get the importance of each corpus feature for the decomposition of a given example i in the following way:

import torch

# Compute the Integrated Jacobian for a particular example
i = 4
input_baseline = torch.zeros(corpus_inputs.shape) # Baseline tensor of the same shape as corpus_inputs
simplex.jacobian_projection(test_id=i, model=model, input_baseline=input_baseline)
result = simplex.decompose(i)

We get a list result where each element of the list corresponds to a corpus example. This list is sorted by decreasing order of importance in the corpus decomposition. Each element of the list is a tuple structured as follows:

w_c, x_c, proj_jacobian_c = result[c]

Where w_c corresponds to the weight weights[i,c], x_c corresponds to corpus_inputs[c] and proj_jacobian is a tensor such that proj_jacobian_c[k] is the Projected Jacobian of feature k from corpus example c.

Reproducing the paper results

Reproducing MNIST Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing Prostate Cancer Approximation Quality Experiment

This experiment requires the access to the private datasets CUTRACT and SEER decribed in the paper.

  1. Copy the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.prostate.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots are saved here.

Reproducing Prostate Cancer Outlier Experiment

This experiment requires the access to the private datasets CUTRACT and SEER decribed in the paper.

  1. Make sure that the files cutract_internal_all.csv and seer_external_imputed_new.csv are in the folder data/Prostate Cancer
  2. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.prostate_cancer -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.prostate.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots are saved here.

Reproducing MNIST Jacobian Projection Significance Experiment

  1. Run the following script
python -m simplexai.experiments.mnist -experiment "jacobian_corruption"

2.The resulting plots and data are saved here.

Reproducing MNIST Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 9)
python -m simplexai.experiments.mnist -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing MNIST Influence Function Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.mnist -experiment "influence" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.mnist.influence.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Note: some problems can appear with the package Pytorch Influence Functions. If this is the case, please change calc_influence_function.py in the following way:

343: influences.append(tmp_influence) ==> influences.append(tmp_influence.cpu())
438: influences_meta['test_sample_index_list'] = sample_list ==> #influences_meta['test_sample_index_list'] = sample_list

Reproducing AR Approximation Quality Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "approximation_quality" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.ar.quality.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

Reproducing AR Outlier Detection Experiment

  1. Run the following script for different values of CV (the results from the paper were obtained by taking all integer CV between 0 and 4)
python -m simplexai.experiments.time_series -experiment "outlier_detection" -cv CV
  1. Run the following script by adding all the values of CV from the previous step
python -m simplexai.experiments.results.ar.outlier.plot_results -cv_list CV1 CV2 CV3 ...
  1. The resulting plots and data are saved here.

🔨 Tests

Install the testing dependencies using

pip install .[testing]

The tests can be executed using

pytest -vsx

Citing

If you use this code, please cite the associated paper:

@inproceedings{Crabbe2021Simplex,
 author = {Crabbe, Jonathan and Qian, Zhaozhi and Imrie, Fergus and van der Schaar, Mihaela},
 booktitle = {Advances in Neural Information Processing Systems},
 editor = {M. Ranzato and A. Beygelzimer and Y. Dauphin and P.S. Liang and J. Wortman Vaughan},
 pages = {12154--12166},
 publisher = {Curran Associates, Inc.},
 title = {Explaining Latent Representations with a Corpus of Examples},
 url = {https://proceedings.neurips.cc/paper/2021/file/65658fde58ab3c2b6e5132a39fae7cb9-Paper.pdf},
 volume = {34},
 year = {2021}
}

simplex's People

Contributors

bcebere avatar jonathancrabbe avatar

Stargazers

 avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar  avatar

Watchers

 avatar  avatar  avatar

simplex's Issues

README typo

Small typo in the code for your toy example.

This:
simplex.jacobian_projections(test_id=i, model=model, input_baseline=input_baseline)

Should be this:
simplex.jacobian_projection(test_id=i, model=model, input_baseline=input_baseline)

Recommend Projects

  • React photo React

    A declarative, efficient, and flexible JavaScript library for building user interfaces.

  • Vue.js photo Vue.js

    🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.

  • Typescript photo Typescript

    TypeScript is a superset of JavaScript that compiles to clean JavaScript output.

  • TensorFlow photo TensorFlow

    An Open Source Machine Learning Framework for Everyone

  • Django photo Django

    The Web framework for perfectionists with deadlines.

  • D3 photo D3

    Bring data to life with SVG, Canvas and HTML. 📊📈🎉

Recommend Topics

  • javascript

    JavaScript (JS) is a lightweight interpreted programming language with first-class functions.

  • web

    Some thing interesting about web. New door for the world.

  • server

    A server is a program made to process requests and deliver data to clients.

  • Machine learning

    Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.

  • Game

    Some thing interesting about game, make everyone happy.

Recommend Org

  • Facebook photo Facebook

    We are working to build community through open source technology. NB: members must have two-factor auth.

  • Microsoft photo Microsoft

    Open source projects and samples from Microsoft.

  • Google photo Google

    Google ❤️ Open Source for everyone.

  • D3 photo D3

    Data-Driven Documents codes.