Skater's Introduction

Skater

Skater is a unified framework that enables model interpretation for all kinds of models, helping you build the interpretable machine learning systems often needed for real-world use cases (we are actively working toward enabling faithful interpretability for all forms of models). It is an open-source Python library designed to demystify the learned structures of a black-box model both globally (inference on the basis of a complete data set) and locally (inference about an individual prediction).

The project started as a research idea to find ways to enable better interpretability (preferably human interpretability) of predictive "black boxes", both for researchers and practitioners. The project is still in beta.
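
For orientation, here is a minimal sketch of global interpretation with Skater, using the API as it appears in the issues below; the import paths are assumed from Skater's documentation, and the model and data choices are illustrative:

from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import GradientBoostingClassifier

from skater.core.explanations import Interpretation
from skater.model import InMemoryModel

# Fit any black-box model; here, a GBM on the breast cancer data set.
data = load_breast_cancer()
clf = GradientBoostingClassifier().fit(data.data, data.target)

# Global interpretation: feature importance computed over the whole data set.
interpreter = Interpretation()
interpreter.load_data(data.data, feature_names=data.feature_names)
model = InMemoryModel(clf.predict_proba, examples=data.data)
interpreter.feature_importance.plot_feature_importance(model)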

Installation

pip

    Option 1: without rule lists and without the deep interpreter
    pip install -U skater

    Option 2: without rule lists and with the deep interpreter:
    1. Ubuntu: pip3 install --upgrade tensorflow (follow the instructions at https://www.tensorflow.org/install/ for details and best practices)
    2. sudo pip install keras
    3. pip install -U skater==1.1.2

    Option 3: with everything included:
    1. conda install gxx_linux-64
    2. Ubuntu: pip3 install --upgrade tensorflow (follow the instructions at https://www.tensorflow.org/install/ for
       details and best practices)
    3. sudo pip install keras
    4. sudo pip install -U --no-deps --force-reinstall --install-option="--rl=True" skater==1.1.2

To get the latest changes, clone the repo and use the commands below to get started:


    1. conda install gxx_linux-64
    2. Ubuntu: pip3 install --upgrade tensorflow (follow the instructions at https://www.tensorflow.org/install/ for
       details and best practices)
    3. sudo pip install keras
    4. git clone the repo
    5. sudo python setup.py install --ostype=linux-ubuntu --rl=True

Testing

  1. If the repo is cloned: python skater/tests/all_tests.py
  2. If pip-installed: python -c "from skater.tests.all_tests import run_tests; run_tests()"

Usage and Examples

See examples folder for usage examples.

Contributing

This project welcomes contributions from the community. Before submitting a pull request, please review our contribution guide.

Security

Please consult the security guide for our responsible security vulnerability disclosure process.

License

Copyright (c) 2018, 2023 Oracle and/or its affiliates.

Released under the Universal Permissive License v1.0 as shown at https://oss.oracle.com/licenses/upl/.

Skater's People

Contributors

aikramer2, alvinthai, astupidbear, bacook17, benvandyke, darenr, deveshcode, glemaitre, jamesmyatt, limscoder, m-richards, nithanaroy, pramitchoudhary, rputhuma, silversurfer84, spavlusieva, totalamateurhour

Skater's Issues

Improperly scaled bandwidth parameter can ruin results

I tried running LIME on the sklearn breast_cancer dataset; as a result of the default kernel width, all the distances were huge, and in turn the sample weights were practically 0, so all the LIME coefficients were 0.

Two options (a small sketch of option 2 follows below):

  1. Use feature scaling for the default kernel width.
  2. Rescale the data before computing distances.
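
A minimal sketch of option 2, standardizing features before computing distances; the kernel formula mirrors LIME's default width of sqrt(n_features) * 0.75 with a LIME-style exponential kernel, and the dataset matches the report above:

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.preprocessing import StandardScaler

X = load_breast_cancer().data
x_ref = X[0].reshape(1, -1)

# Standardize so distances are on a scale the default kernel width expects.
scaler = StandardScaler().fit(X)
Xs, xs_ref = scaler.transform(X), scaler.transform(x_ref)

d_raw = np.linalg.norm(X - x_ref, axis=1)
d_scaled = np.linalg.norm(Xs - xs_ref, axis=1)

kernel_width = np.sqrt(X.shape[1]) * 0.75
w_raw = np.exp(-(d_raw ** 2) / kernel_width ** 2)        # practically 0 everywhere
w_scaled = np.exp(-(d_scaled ** 2) / kernel_width ** 2)  # non-degenerate weights
print(w_raw.mean(), w_scaled.mean())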

Minor code changes

  1. rename `i` (consider a clearer name)
  2. restructure the model module
  3. rename model.processor

Algos: More specialized interpretation support for tree-based models

This is a separate algorithm for interpreting trees and is useful only for the following types of models:

  • DecisionTreeRegressor
  • DecisionTreeClassifier
  • RandomForestRegressor
  • RandomForestClassifier

This could possibly be an extension of the model interpretation interface; an extension because it is model-agnostic only with respect to tree/ensemble-based models.
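
For illustration, a sketch of the kind of tree-specific inspection this would cover, using scikit-learn's public tree internals rather than any existing Skater API:

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

data = load_breast_cancer()
clf = DecisionTreeClassifier(max_depth=3).fit(data.data, data.target)

tree = clf.tree_
# decision_path returns a sparse indicator of the nodes each sample visits.
node_ids = clf.decision_path(data.data[:1]).indices
for node_id in node_ids:
    if tree.children_left[node_id] == -1:  # -1 marks a leaf node
        print("leaf value:", tree.value[node_id])
    else:
        name = data.feature_names[tree.feature[node_id]]
        print("split on %r at threshold %.3f" % (name, tree.threshold[node_id]))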

Installing from a custom source can cause issues with pip uninstalls

In requirements.txt, an entry of the form -e <vcs-url>#egg=packagename will let you import packagename, but it won't be recognized by pip as a package, so pip uninstall packagename doesn't work.
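
For concreteness, a hypothetical requirements.txt entry of the kind described; the URL is illustrative, not an actual fork:

# editable install from a VCS source: importable, but pip cannot
# cleanly uninstall it under the name scikit-learn
-e git+https://github.com/example/scikit-learn.git#egg=scikit-learn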

If users want to use sklearn==0.17 but our fork is based on 0.18, there will be conflicts.

If the authors update the package, we need to merge those changes into our fork, creating potential issues.

Calling third-party functions behind our own functions obscures the true function signature, which can cause errors and requires us to keep our pointer functions in sync with the true, underlying function after every update.

Algos/Performance: For tree-based models, delegate to sklearn's default PDP implementation

Sklearn's default PDP implementation uses a weighted tree traversal for tree-based models to improve performance.
Reference: http://scikit-learn.org/stable/modules/ensemble.html
"For each grid point a weighted tree traversal is performed: if a split node involves a ‘target’ feature, the corresponding left or right branch is followed, otherwise both branches are followed, each branch is weighted by the fraction of training samples that entered that branch. Finally, the partial dependence is given by a weighted average of all visited leaves. For tree ensembles the results of each individual tree are again averaged"

De-mean results for PDP?

Currently we have

mean(y_hat | X_i = x_j) for all x_j in the perturbed space.

So if, for instance, our PDP looks like this:

[figure: PDP of raw mean predictions]

sklearn returns this in deviation form:

[figure: the same PDP, de-meaned]

Which do we think is more helpful?
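
For concreteness, a tiny numeric illustration of the two conventions:

import numpy as np

pdp = np.array([0.62, 0.65, 0.71, 0.80])  # mean(y_hat) at each grid point
pdp_demeaned = pdp - pdp.mean()           # sklearn-style deviation form
print(pdp_demeaned)                       # [-0.075 -0.045  0.015  0.105]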

Enable support for R

So, R has some pretty neat implementations here.
Let's be pragmatic about how we decide on extending model interpretation. The goal is to make our library the de facto implementation people choose for every kind of interpretation need.

Better plotting API for setting axes and comparing multiple models

This pattern is enabled for feature importance (FI) but not for PDPs:

import matplotlib.pyplot as plt

# Import paths assumed from Skater's documented API.
from skater.core.explanations import Interpretation
from skater.model import InMemoryModel

# One subplot per model so feature importances can be compared side by side.
f, axes = plt.subplots(2, 2, figsize=(16, 16))

ax_dict = {
    'mlp': axes[0][0],
    'knn': axes[1][0],
    'reg': axes[0][1],
    'gb': axes[1][1],
}

# X_train, X_test, data, and the `models` dict come from the surrounding notebook.
interpreter = Interpretation()
interpreter.load_data(X_test, feature_names=data.feature_names)
for model_key in models:
    pyint_model = InMemoryModel(models[model_key].predict_proba, examples=X_train)
    ax = ax_dict[model_key]
    interpreter.feature_importance.plot_feature_importance(pyint_model, ax=ax)
    ax.set_title(model_key)

Plot bug

%matplotlib inline
from sklearn.datasets import load_boston, load_breast_cancer
from sklearn.ensemble import GradientBoostingRegressor, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression, LinearRegression
from pyinterpret.core.explanations import Interpretation
import pandas as pd
import numpy as np

classifier_data = load_breast_cancer()
classifier_X = classifier_data.data
classifier_y = classifier_data.target

classifier = GradientBoostingClassifier()
classifier.fit(classifier_X, classifier_y)

classifier_feature_id = [7]
classifier_feature_name = [classifier_data.feature_names[i] for i in classifier_feature_id]

classifier_feature_ids = [7, 23]
classifier_feature_names = [classifier_data.feature_names[i] for i in classifier_feature_ids]    

interpreter = Interpretation()
interpreter.load_data(classifier_X)
feature_ids = [classifier.feature_importances_.argsort()[-1]]
interpreter.partial_dependence.plot_partial_dependence(feature_ids, classifier.predict, with_variance=False)

fails:

TypeErrorTraceback (most recent call last)
<ipython-input-29-761529fb5028> in <module>()
      2 interpreter.load_data(classifier_X)
      3 feature_ids = [classifier.feature_importances_.argsort()[-1]]
----> 4 interpreter.partial_dependence.plot_partial_dependence(feature_ids, classifier.predict, with_variance=False)

/usr/local/lib/python2.7/dist-packages/pyinterpret-0.0.1-py2.7.egg/pyinterpret/core/global_interpretation/partial_dependence.pyc in plot_partial_dependence(self, feature_ids, predict_fn, class_id, grid, grid_resolution, grid_range, sample, sampling_strategy, n_samples, bin_count, samples_per_bin, with_variance)
    238                                       samples_per_bin=samples_per_bin)
    239 
--> 240         ax = self._plot_pdp_from_df(feature_ids, pdp, with_variance=with_variance)
    241         return ax
    242 

/usr/local/lib/python2.7/dist-packages/pyinterpret-0.0.1-py2.7.egg/pyinterpret/core/global_interpretation/partial_dependence.pyc in _plot_pdp_from_df(self, feature_ids, pdp, with_variance)
    291                 data = pdp.set_index(feature_name)
    292                 plane = data[mean_col]
--> 293                 plane.plot(ax=ax, color=color)
    294 
    295                 if with_variance:

/usr/local/lib/python2.7/dist-packages/pandas/tools/plotting.pyc in __call__(self, kind, ax, figsize, use_index, title, grid, legend, style, logx, logy, loglog, xticks, yticks, xlim, ylim, rot, fontsize, colormap, table, yerr, xerr, label, secondary_y, **kwds)
   3564                            colormap=colormap, table=table, yerr=yerr,
   3565                            xerr=xerr, label=label, secondary_y=secondary_y,
-> 3566                            **kwds)
   3567     __call__.__doc__ = plot_series.__doc__
   3568 

/usr/local/lib/python2.7/dist-packages/pandas/tools/plotting.pyc in plot_series(data, kind, ax, figsize, use_index, title, grid, legend, style, logx, logy, loglog, xticks, yticks, xlim, ylim, rot, fontsize, colormap, table, yerr, xerr, label, secondary_y, **kwds)
   2643                  yerr=yerr, xerr=xerr,
   2644                  label=label, secondary_y=secondary_y,
-> 2645                  **kwds)
   2646 
   2647 

/usr/local/lib/python2.7/dist-packages/pandas/tools/plotting.pyc in _plot(data, x, y, subplots, ax, kind, **kwds)
   2439         plot_obj = klass(data, subplots=subplots, ax=ax, kind=kind, **kwds)
   2440 
-> 2441     plot_obj.generate()
   2442     plot_obj.draw()
   2443     return plot_obj.result

/usr/local/lib/python2.7/dist-packages/pandas/tools/plotting.pyc in generate(self)
   1024     def generate(self):
   1025         self._args_adjust()
-> 1026         self._compute_plot_data()
   1027         self._setup_subplots()
   1028         self._make_plot()

/usr/local/lib/python2.7/dist-packages/pandas/tools/plotting.pyc in _compute_plot_data(self)
   1133         if is_empty:
   1134             raise TypeError('Empty {0!r}: no numeric data to '
-> 1135                             'plot'.format(numeric_data.__class__.__name__))
   1136 
   1137         self.data = numeric_data

TypeError: Empty 'DataFrame': no numeric data to plot

Example notebook for a deployed model

  • Build a simple regression model using Keras, deploy it, and then interpret it globally and locally.
  • Build a text model using spaCy (a basic LDA), deploy it, and then interpret it locally.

Notebook examples

  1. Binary classification example: LR, RF/GBM, SVM
  2. Regression example: linear, RF/GBM
  3. Multi-class classification example
  4. Multi-label classification example

Added SVM as well. Did I miss anything else?

n_classes inference: potential bug

We infer the number of classes for classifiers by running model.predict on the dataset.
Say, however, you had data where the classifier only tended to output class 1 or class 2, though class 3 is technically possible.
Then, if we perturbed the data in such a way as to get a class 3 prediction back, we'd have a bug (see the sketch below).
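
A small sketch of the failure mode, using a toy scikit-learn classifier (names and data are illustrative):

import numpy as np
from sklearn.linear_model import LogisticRegression

X = np.array([[0.0], [0.1], [5.0], [5.1], [9.0], [9.1]])
y = np.array([0, 0, 1, 1, 2, 2])
clf = LogisticRegression().fit(X, y)

X_observed = X[:4]                          # a slice where class 2 never wins
print(np.unique(clf.predict(X_observed)))   # inferred classes: [0 1]
print(clf.classes_)                         # the model can actually emit: [0 1 2]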

Potential solutions:

  1. In the event that we see a different class in the perturbed data, just restart the whole algorithm with the updated information.
  2. Make those three model types separate functions, and for classifiers require the user to specify the classes (I don't like this: what if the user messes up or doesn't know?).
  3. Always keep an additional null class. In the event that a new class is discovered (any new class), it simply inherits the effect of the null class, and the null class is displayed as "other classes".

Benchmark whether cythonizing compute_pd improves performance

I'm not convinced this will help, because regardless of defining C types, indexing, etc., we still need to rely on predict functions. Could be interesting, though. I have a cythonized version that I'd like to build on (a side project at home). I will PR here with benchmark results when ready.
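
A minimal harness of the kind such a benchmark might use; compute_pd_py is a hypothetical stand-in for the pure-Python implementation, not the real Skater function:

import timeit
import numpy as np

def compute_pd_py(predict_fn, X, feature, grid):
    # Brute-force PD: overwrite one feature with each grid value, average predictions.
    out = []
    for value in grid:
        Xp = X.copy()
        Xp[:, feature] = value
        out.append(predict_fn(Xp).mean())
    return np.array(out)

X = np.random.rand(10_000, 20)
grid = np.linspace(0.0, 1.0, 50)
predict_fn = lambda X: X.sum(axis=1)  # trivial stand-in for a model's predict

t = timeit.timeit(lambda: compute_pd_py(predict_fn, X, 0, grid), number=10)
print("pure python: %.4fs per call" % (t / 10))  # compare against the cython build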

A more organized package structure

Possible alternatives:

1. datascience.ai.interpretation(:+1)
2. datascience.model.mai (_not sure what sub-module model stands for here_)
3. datascience.ml.mai (:+1)
4. datascience.mi

Improvement: clean up the partial_dependence function

Too many if/else branches are ugly, error-prone, and difficult to extend; let's address this when time permits.
It will help reduce bugs, and we should do it sooner rather than later, because as we move forward we will forget the details. One possible shape for the refactor is sketched below.
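
A hedged sketch of one way to tame the branching: a dispatch table keyed on the case being distinguished. All names here are illustrative, not Skater's actual internals:

def _pd_numeric(feature, grid):
    return "numeric PD for %s" % feature

def _pd_categorical(feature, grid):
    return "categorical PD for %s" % feature

# One entry per case that used to be an if/else arm; adding a case is one line.
_DISPATCH = {
    "numeric": _pd_numeric,
    "categorical": _pd_categorical,
}

def compute_partial_dependence(feature, grid, kind):
    try:
        handler = _DISPATCH[kind]
    except KeyError:
        raise ValueError("unsupported feature kind: %r" % kind)
    return handler(feature, grid)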
