
SAGE

SAGE (Shapley Additive Global importancE) is a game-theoretic approach for understanding black-box machine learning models. It summarizes each feature's importance based on the predictive power it contributes, and it accounts for complex feature interactions using the Shapley value.

SAGE was introduced in this paper, but if you're new to using Shapley values you might want to start by reading this blog post.

Install

The easiest way to get started is to clone the repository and install the package into your Python environment:

pip install .

Usage

SAGE is model-agnostic, so you can use it with any kind of machine learning model (linear models, GBMs, neural networks, etc.). All you need to do is set up an imputer to handle held-out features and run a Shapley value estimator:

import sage

# Get data
x, y = ...
feature_names = ...

# Get model
model = ...

# Set up imputer for missing features
imputer = sage.MarginalImputer(model, x[:512])

# Set up estimator
estimator = sage.PermutationEstimator(imputer, 'mse')

# Calculate SAGE values
sage_values = estimator(x, y)
sage_values.plot(feature_names)

The result is a bar chart of each feature's SAGE value with confidence intervals.

Our implementation supports several features to make Shapley value calculation more practical:

  • Uncertainty estimation. Confidence intervals are provided for each feature's importance value.
  • Convergence. Convergence is determined automatically based on the size of the confidence intervals, and a progress bar displays the estimated time until convergence.
  • Model conversion. Our back-end requires models that are converted into a consistent format, and the conversion step is performed automatically for XGBoost, CatBoost, LightGBM, sklearn and PyTorch models. If you're using a different kind of model, it must be converted to a callable function (see here for examples).
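To illustrate the callable-function requirement, here is a minimal sketch of wrapping a model with a nonstandard prediction API. `CustomModel` and `score_rows` are hypothetical stand-ins, not part of this package; the assumed interface is simply a callable that maps a 2-D array of inputs to an array of predictions.

```python
import numpy as np

class CustomModel:
    """Stand-in for a model type the package doesn't auto-convert."""
    def score_rows(self, rows):
        # Toy scoring rule: weighted sum of the features.
        return rows @ np.array([0.5, -1.0, 2.0])

model = CustomModel()

# Wrap the model in a plain callable: (n, d) array in, (n,) array out.
model_fn = lambda x: model.score_rows(np.asarray(x))

x = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 1.0]])
print(model_fn(x))  # → [0.5 1. ]
```

The wrapped callable can then be passed wherever a supported model would be.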

Examples

Check out the following notebooks to get started:

  • Bike is a simple example using XGBoost, and it shows how to calculate SAGE values and Shapley Effects (an alternative explanation when no labels are available)
  • Credit shows how to generate explanations with a surrogate model to approximate the conditional distribution (using CatBoost)
  • Airbnb shows an example where SAGE values are calculated with grouped features (using a PyTorch MLP)
  • Bank shows a model monitoring example that uses SAGE to identify features that hurt the model's performance (using CatBoost)
  • MNIST shows several strategies to accelerate convergence for datasets with many features (feature grouping, different imputing setups)

If you want to replicate any experiments described in our paper, see this separate repository.

More details

This package offers flexibility in how explanations are generated; you can make several choices.

1. Feature removal approach

The original SAGE paper proposes marginalizing out missing features using their conditional distribution. Since this is challenging to implement in practice, several approximations are available. The choices include:

  1. Use default values for missing features (see MNIST for an example). This is a fast but low-quality approximation.
  2. Sample features from the marginal distribution (see Bike for an example). This approximation is discussed in the SAGE paper.
  3. Train a supervised surrogate model (see Credit for an example). This approach is described in this paper, and it can provide a better approximation than the other approaches. However, it requires training an additional model (typically a neural network).
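Option 2 (marginal sampling) can be sketched in a few lines: held-out features are replaced with draws from a background dataset, and the prediction is averaged over those draws. This is a toy illustration with an assumed linear model, not the package's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear model and background data for the marginal distribution.
weights = np.array([1.0, 2.0, -1.0])
model = lambda x: x @ weights
background = rng.normal(size=(256, 3))

def predict_with_subset(x_row, observed_mask):
    """Marginal imputation: held-out features are drawn from the
    background data; the prediction is averaged over those samples."""
    samples = background.copy()
    samples[:, observed_mask] = x_row[observed_mask]  # fix observed features
    return model(samples).mean()

x_row = np.array([1.0, 1.0, 1.0])
mask = np.array([True, False, True])  # feature 1 is "missing"
print(predict_with_subset(x_row, mask))
```

Because the background samples approximate the marginal (not conditional) distribution, this ignores feature correlations; that trade-off is discussed in the SAGE paper.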

2. Explanation type

Two types of explanations can be calculated, both based on Shapley values:

  1. SAGE. This approach quantifies each feature's role in improving the model's performance (the default explanation here).
  2. Shapley Effects. Described in this paper, this explanation method quantifies the model's sensitivity to each feature. Since Shapley Effects is a variation on SAGE (see details in this paper), our implementation generates this type of explanation when labels are not provided. See the Bike notebook for an example.

3. Shapley value estimator

Shapley values are computationally costly to calculate, so we implemented four different estimators:

  1. Permutation sampling. This is the approach described in the original paper (see PermutationEstimator).
  2. KernelSAGE. This is a linear regression-based estimator that is similar to KernelSHAP (see KernelEstimator). It is described in this paper, and the Bank notebook shows an example use-case.
  3. Iterated sampling. This is a variation on the permutation sampling approach where we calculate Shapley values for each feature sequentially (see IteratedEstimator). This permits faster convergence for features with low variance, but it can result in wider confidence intervals.
  4. Sign estimation. This method estimates SAGE values to a lower precision by focusing on their sign (i.e., whether they help or hurt performance). It is implemented in SignEstimator, and the Bank notebook shows an example.

All four approaches are unbiased estimators, so given enough samples they converge to the same values; only their convergence speed differs. Permutation sampling is a good approach to start with. KernelSAGE converges a bit faster, but its uncertainty is spread more evenly among the features (rather than being highest for the most important ones).
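The permutation sampling idea can be sketched on a toy cooperative game. Here `v` is a hypothetical value function giving the performance explained by a feature subset; in SAGE proper, evaluating `v` involves the imputer and the loss. Because this toy `v` is additive, the estimate recovers the per-feature contributions exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 3

# Toy value function v(S): performance explained by feature subset S.
# Additive here, so the Shapley values equal the per-feature terms.
contrib = np.array([3.0, 1.0, 0.5])
def v(subset):
    return contrib[list(subset)].sum()

def permutation_shapley(v, d, n_perms=200):
    values = np.zeros(d)
    for _ in range(n_perms):
        perm = rng.permutation(d)
        prev, subset = v(()), []
        for j in perm:
            subset.append(j)
            curr = v(tuple(subset))
            values[j] += curr - prev  # marginal contribution of j
            prev = curr
    return values / n_perms

print(permutation_shapley(v, d))  # → [3.  1.  0.5]
```

Note that each permutation's marginal contributions sum exactly to v(all features) - v(empty set), so the estimate always satisfies the Shapley efficiency property.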

4. Grouped features

Rather than removing features individually, you can specify groups of features to be removed together. This will likely speed up convergence because there are fewer feature subsets to consider. See the Airbnb notebook for an example.
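Grouping can be sketched by treating each group as a single player in the cooperative game, which shrinks the number of players (here from four features to two groups). As above, `v` is a hypothetical additive value function for illustration, not the package's API.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy additive value function over 4 features.
contrib = np.array([2.0, 1.0, 0.5, 0.25])
def v(features):
    return contrib[list(features)].sum()

# Treat each group as a single "player": 4 features -> 2 players.
groups = [[0, 1], [2, 3]]

def grouped_permutation_shapley(v, groups, n_perms=100):
    values = np.zeros(len(groups))
    for _ in range(n_perms):
        order = rng.permutation(len(groups))
        prev, included = v(()), []
        for g in order:
            included.extend(groups[g])  # add the whole group at once
            curr = v(tuple(included))
            values[g] += curr - prev
            prev = curr
    return values / n_perms

print(grouped_permutation_shapley(v, groups))  # → [3.   0.75]
```

With g groups instead of d features, far fewer subsets need to be sampled, which is why grouping tends to speed up convergence.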


References

Ian Covert, Scott Lundberg, Su-In Lee. "Understanding Global Feature Contributions With Additive Importance Measures." NeurIPS 2020

Ian Covert, Scott Lundberg, Su-In Lee. "Explaining by Removing: A Unified Framework for Model Explanation." arXiv preprint arXiv:2011.14878

Ian Covert, Su-In Lee. "Improving KernelSHAP: Practical Shapley Value Estimation via Linear Regression." AISTATS 2021

Art Owen. "Sobol' Indices and Shapley Value." SIAM/ASA Journal on Uncertainty Quantification 2014
