
vimpy's Introduction

Python/vimpy: inference on algorithm-agnostic variable importance

License: MIT

Software author: Brian Williamson

Methodology authors: Brian Williamson, Peter Gilbert, Noah Simon, Marco Carone

R package: https://github.com/bdwilliamson/vimp

Introduction

In predictive modeling applications, it is often of interest to determine the relative contribution of subsets of features in explaining an outcome; this is commonly called variable importance. It is useful to consider variable importance as a function of the unknown, underlying data-generating mechanism rather than of the specific predictive algorithm used to fit the data. This package provides functions that, given fitted values from predictive algorithms, compute nonparametric estimates of variable importance based on $R^2$, deviance, classification accuracy, and area under the receiver operating characteristic curve, along with asymptotically valid confidence intervals for the true importance.
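As a rough illustration of the $R^2$-based measure: the importance of a feature subset s can be viewed as the drop in $R^2$ when the features in s are removed, i.e. the $R^2$ of a regression on all features minus the $R^2$ of a regression that omits the features in s. The minimal plug-in sketch below assumes you already have fitted values from both regressions; the helper names are hypothetical and are not part of vimpy's API. It also ignores the sample splitting and cross-fitting that vimpy uses to obtain valid inference.

```python
import numpy as np


def r_squared(y, fitted):
    """Empirical R^2: 1 - MSE(fitted) / Var(y)."""
    y = np.asarray(y, dtype=float)
    fitted = np.asarray(fitted, dtype=float)
    mse = np.mean((y - fitted) ** 2)
    return 1.0 - mse / np.var(y)


def r_squared_importance(y, full_fit, reduced_fit):
    """Plug-in difference in R^2 between the full and the reduced fit."""
    return r_squared(y, full_fit) - r_squared(y, reduced_fit)


# Toy usage: the reduced fit carries no information beyond the mean of y
y = np.array([1.0, 2.0, 3.0, 4.0])
full_fit = np.array([1.1, 1.9, 3.2, 3.8])
reduced_fit = np.full(4, y.mean())
print(r_squared_importance(y, full_fit, reduced_fit))
```

Here the full fit explains almost all of the variance while the reduced fit explains none, so the estimated importance is close to 1.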

For more details, please see the accompanying manuscripts "Nonparametric variable importance assessment using machine learning techniques" by Williamson, Gilbert, Carone, and Simon (Biometrics, 2020), "A unified approach for inference on algorithm-agnostic variable importance" by Williamson, Gilbert, Simon, and Carone (arXiv, 2020), and "Efficient nonparametric statistical inference on population feature importance using Shapley values" by Williamson and Feng (arXiv, 2020; to appear in the Proceedings of the Thirty-seventh International Conference on Machine Learning [ICML 2020]).

Installation

You may install a stable release of vimpy using pip by running pip install vimpy (or python -m pip install vimpy) from a terminal window. Alternatively, you may install it within a virtualenv environment.

You may install the current dev release of vimpy by downloading this repository directly.

Issues

If you encounter any bugs or have any specific feature requests, please file an issue.

Example

This example shows how to use vimpy in a simple setting with simulated data and using a single regression function. For more examples and detailed explanation, please see the R vignette.

## load required libraries
import numpy as np
import vimpy
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.model_selection import GridSearchCV

## -------------------------------------------------------------
## problem setup
## -------------------------------------------------------------
## define a function for the conditional mean of Y given X
def cond_mean(x = None):
    f1 = np.where(np.logical_and(-2 <= x[:, 0], x[:, 0] < 2), np.floor(x[:, 0]), 0)
    f2 = np.where(x[:, 1] <= 0, 1, 0)
    f3 = np.where(x[:, 2] > 0, 1, 0)
    f6 = np.absolute(x[:, 5]/4) ** 3
    f7 = np.absolute(x[:, 6]/4) ** 5
    f11 = (7./3)*np.cos(x[:, 10]/2)
    ret = f1 + f2 + f3 + f6 + f7 + f11
    return ret

## create data
np.random.seed(4747)
n = 100
p = 15
s = 1 # importance desired for X_1, i.e. column index 1 of x (zero-based)
x = np.zeros((n, p))
for i in range(0, x.shape[1]) :
    x[:,i] = np.random.normal(0, 2, n)

y = cond_mean(x) + np.random.normal(0, 1, n)

## -------------------------------------------------------------
## preliminary step: get regression estimators
## -------------------------------------------------------------
## use grid search to get optimal number of trees and learning rate
ntrees = np.arange(100, 500, 100)
lr = np.arange(.01, .1, .05)

param_grid = [{'n_estimators':ntrees, 'learning_rate':lr}]

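## set up cv objects
## ('squared_error' is the current scikit-learn name for least-squares loss; the old 'ls' alias was removed in version 1.2)
cv_full = GridSearchCV(GradientBoostingRegressor(loss = 'squared_error', max_depth = 1), param_grid = param_grid, cv = 5)
cv_small = GridSearchCV(GradientBoostingRegressor(loss = 'squared_error', max_depth = 1), param_grid = param_grid, cv = 5)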

## -------------------------------------------------------------
## get variable importance estimates
## -------------------------------------------------------------
# set seed
np.random.seed(12345)
## set up the vimp object
vimp = vimpy.vim(y = y, x = x, s = 1, pred_func = cv_full, measure_type = "r_squared")
## get the point estimate of variable importance
vimp.get_point_est()
## get the influence function estimate
vimp.get_influence_function()
## get a standard error
vimp.get_se()
## get a confidence interval
vimp.get_ci()
## do a hypothesis test, compute p-value
vimp.hypothesis_test(alpha = 0.05, delta = 0)
## display the estimates, etc.
print(vimp.vimp_)
print(vimp.se_)
print(vimp.ci_)
print(vimp.p_value_)
print(vimp.hyp_test_)

## -------------------------------------------------------------
## using precomputed fitted values
## -------------------------------------------------------------
np.random.seed(12345)
folds_outer = np.random.choice(a = np.arange(2), size = n, replace = True, p = np.array([0.5, 0.5]))
## fit the full regression
cv_full.fit(x[folds_outer == 1, :], y[folds_outer == 1])
full_fit = cv_full.best_estimator_.predict(x[folds_outer == 1, :])

## fit the reduced regression
x_small = np.delete(x[folds_outer == 0, :], s, 1) # delete the columns in s
cv_small.fit(x_small, y[folds_outer == 0])
small_fit = cv_small.best_estimator_.predict(x_small)
## get variable importance estimates
np.random.seed(12345)
vimp_precompute = vimpy.vim(y = y, x = x, s = 1, f = full_fit, r = small_fit, measure_type = "r_squared", folds = folds_outer)
## get the point estimate of variable importance
vimp_precompute.get_point_est()
## get the influence function estimate
vimp_precompute.get_influence_function()
## get a standard error
vimp_precompute.get_se()
## get a confidence interval
vimp_precompute.get_ci()
## do a hypothesis test, compute p-value
vimp_precompute.hypothesis_test(alpha = 0.05, delta = 0)
## display the estimates, etc.
print(vimp_precompute.vimp_)
print(vimp_precompute.se_)
print(vimp_precompute.ci_)
print(vimp_precompute.p_value_)
print(vimp_precompute.hyp_test_)

## -------------------------------------------------------------
## get variable importance estimates using cross-validation
## -------------------------------------------------------------
np.random.seed(12345)
## set up the vimp object
vimp_cv = vimpy.cv_vim(y = y, x = x, s = 1, pred_func = cv_full, V = 5, measure_type = "r_squared")
## get the point estimate
vimp_cv.get_point_est()
## get the influence function estimate
vimp_cv.get_influence_function()
## get a standard error
vimp_cv.get_se()
## get a confidence interval
vimp_cv.get_ci()
## do a hypothesis test, compute p-value
vimp_cv.hypothesis_test(alpha = 0.05, delta = 0)
## display estimates, etc.
print(vimp_cv.vimp_)
print(vimp_cv.se_)
print(vimp_cv.ci_)
print(vimp_cv.p_value_)
print(vimp_cv.hyp_test_)
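The get_influence_function(), get_se(), get_ci(), and hypothesis_test() steps above follow the usual recipe for asymptotically linear estimators: estimate the influence function at each observation, turn it into a standard error, then form a Wald-type interval and test. The sketch below illustrates that generic recipe only; it is not vimpy's internal code, the variance estimator vimpy uses may differ, and the one-sided test of H0: importance <= delta is an assumption about how delta is interpreted. The helper names are hypothetical.

```python
import math
from statistics import NormalDist

import numpy as np


def wald_se_ci(estimate, influence_values, level=0.95):
    """Standard error and Wald-type CI for an asymptotically linear estimator.

    The standard error is sqrt(mean(IF^2) / n), where IF holds the estimated
    (mean-zero) influence function evaluated at each of the n observations.
    """
    ic = np.asarray(influence_values, dtype=float)
    n = ic.shape[0]
    se = math.sqrt(np.mean(ic ** 2) / n)
    z = NormalDist().inv_cdf(1 - (1 - level) / 2)
    return se, (estimate - z * se, estimate + z * se)


def one_sided_p_value(estimate, se, delta=0.0):
    """p-value for H0: importance <= delta against H1: importance > delta."""
    return 1 - NormalDist().cdf((estimate - delta) / se)


# Toy usage with made-up influence function values (not vimpy output)
se, ci = wald_se_ci(0.5, [1.0, -1.0, 1.0, -1.0])
print(se, ci)
print(one_sided_p_value(0.5, se))
```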

Logo

The logo was created using hexSticker, lisa, and a python image distributed under the CC0 license. Many thanks to the maintainers of these packages and the Color Lisa team.

vimpy's People

Contributors

bdwilliamson, jennyleestat, jjfeng


vimpy's Issues

Warning when using "vim" for groups of covariates

Hello - I am trying to use your package for an analysis. I loaded the "Boston" dataset and tried running the "vim" function for groups of covariates. I get the following warning:

> neigh.vim <- vim(full.fit, fit ~ x, data = Boston3, y = Boston3$medv,
+                  s = c(1, 2, 3, 4, 10, 11, 12, 13), SL.library = learners.2)

Warning messages:
1: In if (standardized) { :
  the condition has length > 1 and only the first element will be used
2: In if (standardized) { :
  the condition has length > 1 and only the first element will be used
3: In if (standardized) { :
  the condition has length > 1 and only the first element will be used

Would appreciate your advice on whether this is something to worry about.

Best,

Installation of version 2.1 and errors in spvim()

Hi, I'm currently trying to use spvim() from vimpy because it can accommodate arbitrary prediction functions, as opposed to sp_vim() in R, where, as far as I can see, only learners from the SL library can be used. However, when trying to install version 2.1, I encountered the following error:
ERROR: Could not find a version that satisfies the requirement scipy.stats (from vimpy) (from versions: none)
ERROR: No matching distribution found for scipy.stats

This seems to be caused by the 'scipy.stats' entry on line 20 of setup.py. Maybe it is there for a reason, but after removing it the installation worked fine.

Additionally, when using spvim() I encountered a few errors. It could be that I'm using it incorrectly; however, here are a few potential errors in vimpy/vimpy/spvim.py (fixing them unfortunately did not entirely resolve the problems):

In the method get_influence_function():

  • [line 109] in self.v.shape[0], the underscore after v is missing; it should be self.v_.shape[0]
    After including the underscore:
  • [line 109] in self.v_[self.v_.shape[0]], the index is out of bounds; maybe this should be self.v_[self.v_.shape[0]-1]?
  • [line 110] self.z_counts_ does not exist; it probably needs to be initialized in __init__ and set during get_point_est()?

After incorporating these changes, get_influence_function() worked; however, the methods get_ses() and get_cis() seem to have further issues, e.g., problems with the indices in the shapley_se() function.

While I'm not sure about all of the above suggestions, they might still be of some use.

Kind regards.
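The out-of-bounds index reported in this issue is a standard NumPy pitfall: an array of length m has valid indices 0 through m - 1, so a[a.shape[0]] raises IndexError, while a[a.shape[0] - 1] (equivalently a[-1]) returns the last element. A minimal illustration, using a hypothetical array as a stand-in for self.v_ (not vimpy's actual data):

```python
import numpy as np

v = np.array([0.1, 0.25, 0.4])  # hypothetical stand-in for self.v_

out_of_bounds = False
try:
    v[v.shape[0]]  # index 3 on a length-3 array
except IndexError as err:
    out_of_bounds = True
    print("IndexError:", err)

last = v[v.shape[0] - 1]  # the last element, same as v[-1]
print(out_of_bounds, last)
```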

extend code to feature groups

Correct me if I'm wrong, but I don't believe the current code is set up to calculate values for feature groups.

Can you confirm I'm understanding this correctly? To extend the code for groups, we would want to select subsets over feature groups rather than individual features. Then when measuring predictiveness, we include all features that are part of the selected feature groups. So for example, if we have groups:

vitals = [blood_pressure, heart_rate]
labs = [sodium, potassium, sugar]
diagnoses = [kidney, heart, liver]

If S = [0, 1], then we train a model with blood_pressure, heart_rate, sodium, potassium, and sugar.

Would we need to normalize anything?
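The grouping described in this issue can be sketched as follows, assuming groups are given as lists of feature names. The helper names are hypothetical and not part of vimpy. The groups, rather than individual features, play the role of players: subsets are enumerated over group indices, and each subset is expanded into the union of its features (and then into column indices) before a model is fit.

```python
from itertools import combinations
from typing import List, Sequence, Tuple


def features_for_group_subset(groups: Sequence[Sequence[str]], subset: Sequence[int]) -> List[str]:
    """Union of the features in the selected groups, kept in group order."""
    selected: List[str] = []
    for g in sorted(set(subset)):
        for feature in groups[g]:
            if feature not in selected:
                selected.append(feature)
    return selected


def all_group_subsets(n_groups: int) -> List[Tuple[int, ...]]:
    """Every subset of group indices, from the empty set to all groups."""
    return [s for k in range(n_groups + 1) for s in combinations(range(n_groups), k)]


groups = [
    ["blood_pressure", "heart_rate"],  # vitals
    ["sodium", "potassium", "sugar"],  # labs
    ["kidney", "heart", "liver"],      # diagnoses
]
feature_names = [f for group in groups for f in group]

cols = features_for_group_subset(groups, [0, 1])
col_idx = [feature_names.index(f) for f in cols]  # column indices into X
print(cols)
print(col_idx)
print(len(all_group_subsets(len(groups))))
```

With groups as the players, any subset-size weighting (as in a Shapley-value computation) would be based on the number of groups rather than the number of features; whether vimpy needs any further normalization is the open question raised above.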
