survshap's Introduction

SurvSHAP(t)

This repository contains data and code for the article:

M. Krzyziński, M. Spytek, H. Baniecki, P. Biecek. SurvSHAP(t): Time-dependent explanations of machine learning survival models. Knowledge-Based Systems, 262:110234, 2023. https://doi.org/10.1016/j.knosys.2022.110234

@article{survshap,
    title = {SurvSHAP(t): Time-dependent explanations of machine learning survival models},
    author = {Mateusz Krzyziński and Mikołaj Spytek and Hubert Baniecki and Przemysław Biecek},
    journal = {Knowledge-Based Systems},
    volume = {262},
    pages = {110234},
    year = {2023}
}

Implementations

In the survshap_package directory, you will find the code for the survshap Python package, which contains the implementation of the SurvSHAP(t) method. You can now also install it from PyPI:

pip install survshap

NOTE: SurvSHAP(t) and SurvLIME are also implemented in the survex R package, along with many more explanation methods for survival models. survex offers explanations for scikit-survival models loaded into R via the reticulate package.
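SurvSHAP(t) attributes a model's predicted survival function S(t | x) to the features separately at every time point t, so each feature receives a curve of contributions rather than a single number. The block below is a minimal conceptual sketch of that idea, not the survshap package's implementation: it enumerates all feature coalitions exactly and uses an interventional value function that averages predictions over a background sample. The model `toy_survival`, the helper `exact_time_shap`, and the `predict(X, times)` signature are hypothetical and chosen only for illustration.

```python
import itertools
import math

import numpy as np


def toy_survival(X, times):
    """Hypothetical survival model: S(t|x) = exp(-rate(x) * t), with a feature interaction."""
    rate = np.exp(0.8 * X[:, 0] - 0.5 * X[:, 1] + 0.3 * X[:, 0] * X[:, 2])
    return np.exp(-np.outer(rate, times))  # shape (n_samples, n_times)


def exact_time_shap(predict, x, background, times):
    """Exact Shapley values of S(t|x), computed for every time point at once.

    Features outside a coalition keep their background values; features inside
    it are set to the explained observation x. Returns an array (p, n_times).
    """
    p = x.shape[0]

    def value(coalition):
        Z = background.copy()
        if coalition:
            cols = list(coalition)
            Z[:, cols] = x[cols]  # coalition features take the explained values
        return predict(Z, times).mean(axis=0)  # average over background rows

    phi = np.zeros((p, len(times)))
    for j in range(p):
        others = [k for k in range(p) if k != j]
        for size in range(p):
            # classic Shapley weight |S|! (p - |S| - 1)! / p!
            w = math.factorial(size) * math.factorial(p - size - 1) / math.factorial(p)
            for S in itertools.combinations(others, size):
                phi[j] += w * (value(S + (j,)) - value(S))
    return phi


rng = np.random.default_rng(0)
background = rng.normal(size=(100, 3))
x = np.array([1.0, -0.5, 0.2])
times = np.linspace(0.1, 2.0, 5)
phi = exact_time_shap(toy_survival, x, background, times)  # one row (curve) per feature
```

At every t the rows of `phi` sum to S(t | x) minus the mean background prediction. Because the loop visits every coalition, its cost grows exponentially with the number of features, so this exact form is only practical for a handful of features.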

Additional materials

In addition to the package, the repository also contains the materials used for the article (in the paper directory).

other_codes

  • survlime.py is the SurvLIME method implementation
  • survnam directory contains the SurvNAM method implementation (based on the implementation by Jia-Xiang Chengh)
  • data_generation.R is the code for synthetic censored data generation (for Experiments 1 and 2)
  • plots.R is the code for creating Figures from the article

data

  • data directory contains the datasets used in experiments

experiments

  • experiments directory contains Jupyter Notebooks (*.ipynb files) with code of the conducted experiments

plots

  • plots directory contains Figures in .pdf format

results

  • results directory contains results of the conducted experiments stored in .csv files

survshap's People

Contributors

hbaniecki, krzyzinskim, mikolajsp, pierrickpochelu

survshap's Issues

Extension to DeepHit

Hello,

This is great work!

I was wondering if SurvSHAP can be applied to Deep Learning survival Networks such as DeepHit that can produce survival function estimates without being constrained by the "Proportional Hazards (PH)" assumption.

Please let me know. I am keen to try it out.
Ani

Limiting the number of timestamps

By default, limit the number of timestamps by generating a new vector based on the times in the background data.

Possible methods of generation:

  • equally spaced points ("uniform" in survex)
  • based on quantiles ("quantiles")

Possible default length: 51
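The two generation strategies listed above can be sketched as follows. This is a minimal illustration assuming NumPy; `make_timestamps` is a hypothetical helper, not a function from survshap or survex, and its default of 51 points simply mirrors the proposed default length.

```python
import numpy as np


def make_timestamps(background_times, method="uniform", n_points=51):
    """Build a reduced grid of evaluation times from the observed background times."""
    t = np.asarray(background_times, dtype=float)
    if method == "uniform":
        # equally spaced points between the smallest and largest observed time
        return np.linspace(t.min(), t.max(), n_points)
    if method == "quantiles":
        # points at empirical quantiles: the grid is denser where times are frequent
        grid = np.quantile(t, np.linspace(0.0, 1.0, n_points))
        return np.unique(grid)  # drop duplicates caused by tied times
    raise ValueError(f"unknown method: {method!r}")


times = np.array([1, 2, 2, 3, 5, 8, 13, 21])
uniform_grid = make_timestamps(times, "uniform", 5)      # 1, 6, 11, 16, 21
quantile_grid = make_timestamps(times, "quantiles", 5)   # 1, 2, 4, 9.25, 21
```

With quantiles the grid follows the distribution of observed times, so fewer points are spent on sparsely populated late periods; with tied times the quantile grid may end up shorter than `n_points`.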

Example code to plot the SHAP values

Could you please point me to example code or a notebook that plots the SHAP values in Python?

I found R code for the plots in the repository, but I am wondering whether Python sample code is available.

Thanks

SurvSHAP(t) values for cumulative hazard function

Dear authors,

firstly, thank you very much for providing the opportunity to explain survival models with SHAP values!

I would like to calculate SHAP values that attribute the cumulative hazard function (CHF). I assumed that setting predict_cumulative_hazard_function to True for the SurvivalModelExplainer object would compute SHAP values for the CHF, but no matter how I vary the parameters predict_cumulative_hazard_function and predict_survival_function, the resulting SHAP values do not seem to change.

In the code I saw the option to set function_type to chf for the PredictSurvSHAP object, but even though SurvivalModelExplainer seems to call PredictSurvSHAP, I found no way of setting function_type via SurvivalModelExplainer.

Could you please tell me how I can produce SHAP values for the CHF? Am I right that the current explanations are based on attributions to the survival function?

Kind Regards,
Mariia
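For a continuous-time model the CHF and the survival function are linked by H(t) = -log S(t). Because -log is nonlinear, SHAP values for H(t) cannot be obtained by transforming SHAP values computed for S(t); the explanation has to be recomputed with a prediction function that returns the CHF. The sketch below assumes NumPy and a hypothetical `predict(X, times)` interface; `survival_to_chf` and `as_chf_predictor` are illustrative helpers, not survshap functions. Some models estimate S and H separately (e.g. with different estimators), in which case the identity holds only approximately and the model's own CHF prediction is preferable.

```python
import numpy as np


def survival_to_chf(surv, eps=1e-12):
    """Convert survival probabilities S(t) to cumulative hazard H(t) = -log S(t).

    eps guards against log(0) when a survival curve reaches zero.
    """
    surv = np.clip(np.asarray(surv, dtype=float), eps, 1.0)
    return -np.log(surv)


def as_chf_predictor(predict_survival):
    """Wrap a function returning S(t|x) so that it returns H(t|x) instead."""
    def predict_chf(X, times):
        return survival_to_chf(predict_survival(X, times))
    return predict_chf


# usage with an exponential model S(t|x) = exp(-x0 * t), whose CHF is x0 * t
predict_chf = as_chf_predictor(lambda X, t: np.exp(-np.outer(X[:, 0], t)))
chf = predict_chf(np.array([[0.5], [2.0]]), np.array([1.0, 3.0]))
```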

SurvSHAP values for predictions made on right-censored data

Hi.

Thanks for your great work on SurvSHAP. I wanted to ask whether it is possible to get SurvSHAP values for predictions made on right-censored data. For example, suppose a subject has survived for 5 months so far, and we want SurvSHAP values for the prediction at 7 months (e.g., using the Cox proportional hazards model). I could see that a correction to the survival function has to be made in such cases (sebp/scikit-survival#128), and I was unsure whether this affects the SurvSHAP values for such a prediction. If so, how can this be done?
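For a subject known to have survived up to time s, the quantity of interest is the conditional survival P(T > t | T > s) = S(t) / S(s) for t ≥ s, which is presumably the kind of correction referred to above. A minimal sketch, assuming NumPy and a survival curve tabulated at increasing times as a right-continuous step function; `conditional_survival` is a hypothetical helper. Explaining the conditional prediction would mean explaining the ratio S(t | x) / S(s | x) as a function of x, which in general differs from explaining S(t | x) alone.

```python
import numpy as np


def conditional_survival(times, surv, s, t):
    """P(T > t | T > s) = S(t) / S(s) for t >= s, from a step survival curve.

    times: increasing time points; surv: S at those points (right-continuous steps).
    """
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)

    def S(u):
        idx = np.searchsorted(times, u, side="right") - 1
        return 1.0 if idx < 0 else float(surv[idx])  # S(u) = 1 before the first time

    if t < s:
        raise ValueError("t must be >= s")
    s_s = S(s)
    if s_s == 0.0:
        raise ValueError("S(s) is zero; conditional survival is undefined")
    return S(t) / s_s


# survived 5 months, probability of surviving to 7 months: 0.45 / 0.6, i.e. about 0.75
p_7_given_5 = conditional_survival([1, 3, 5, 7, 9], [0.9, 0.8, 0.6, 0.45, 0.3], s=5, t=7)
```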

Additivity check failed in TreeExplainer! (with check_additivity=False)

Hi, I am trying to use the survshap library to extract SHAP values from a random survival forest model. I passed the training dataset to SurvivalModelExplainer. The following code illustrates the setup:

from sksurv.ensemble import RandomSurvivalForest
from survshap import SurvivalModelExplainer, ModelSurvSHAP

# X_train, y_train and random_state are assumed to be defined earlier in the script
rsf = RandomSurvivalForest(
    n_estimators=8, n_jobs=-1, random_state=random_state
)

rsf.fit(X_train, y_train)

explainer = SurvivalModelExplainer(model=rsf, data=X_train, y=y_train)
model_survshap = ModelSurvSHAP(calculation_method="treeshap")
model_survshap.fit(explainer=explainer)

My X_train contains some one-hot encoded features and I'm getting the following error:

ExplainerError: Additivity check failed in TreeExplainer! Please ensure the data matrix you passed to the explainer is the same shape that the model was trained on. If your data shape is correct then please report this on GitHub. This check failed because for one of the samples the sum of the SHAP values was 0.666667, while the model output was 0.000000. If this difference is acceptable you can set check_additivity=False to disable this check.

I passed check_additivity=False which didn't change the outcome.

Expected Future Lifetime

I apologize if I am extrapolating beyond what your package was intended to do. I understand that a large part of SurvSHAP(t) is about explanatory variable analysis. Your paper indicates that SurvSHAP(t) can return an individual's unique survival function given their covariates, distinct from the assumptions of a Cox model. Can this survival function be used to calculate an expected future lifetime?
https://en.wikipedia.org/wiki/Survival_analysis#Quantities_derived_from_the_survival_distribution
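Given a survival curve predicted by the model for one individual, the expected remaining lifetime after time t0 is the integral of S(u) over [t0, ∞) divided by S(t0). Because a fitted curve is only known up to the last observed time, a practical sketch truncates the integral at a horizon tau (a restricted mean), which underestimates the true value whenever S(tau) > 0. The helper `expected_remaining_lifetime` below is hypothetical and assumes NumPy and a right-continuous step curve; it computes a quantity derived from the model's prediction, not from SurvSHAP(t) values.

```python
import numpy as np


def expected_remaining_lifetime(times, surv, t0=0.0, tau=None):
    """Restricted mean residual life: integral of S(u) over [t0, tau] divided by S(t0)."""
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    tau = times[-1] if tau is None else float(tau)
    if tau <= t0:
        raise ValueError("tau must be greater than t0")

    def S(u):
        # right-continuous step function with S(u) = 1 before the first time point
        idx = np.searchsorted(times, u, side="right") - 1
        return np.where(idx < 0, 1.0, surv[np.clip(idx, 0, None)])

    s0 = float(S(t0))
    if s0 == 0.0:
        raise ValueError("S(t0) is zero; the conditional expectation is undefined")
    # S is constant between consecutive knots, so the integral is a sum of rectangles
    knots = np.concatenate(([t0], times[(times > t0) & (times < tau)], [tau]))
    area = float(np.sum(S(knots[:-1]) * np.diff(knots)))
    return area / s0


# events at 1, 2, 3 with probabilities 0.2, 0.3, 0.5: mean lifetime 2.3
mean_life = expected_remaining_lifetime([1, 2, 3], [0.8, 0.5, 0.0])
```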

Wrapper for `shap.TreeExplainer` for SurvSHAP

  • implement shap.TreeExplainer for local explanations (PredictSurvSHAP)
  • implement shap.TreeExplainer for global explanations (ModelSurvSHAP)
  • check the other explainer types from shap library

compatible with time-varying models?

Hello,

I was looking into your work and wondering: can the survshap function accommodate longitudinal (panel) models, e.g., time-varying Cox (Lifelines), Dynamic DeepHit, and Recurrent Survival Machines with competing risks? Thanks!

SurvSHAP model freezes for new data

Hi,
I am trying to run SurvSHAP on my test dataset of shape 200x8, and the model ends up freezing without any error messages.

  1. I was able to reproduce the results of the paper by running https://github.com/MI2DataLab/survshap/blob/main/exp3.ipynb

  2. My dataset is very similar to the dataset used in exp3. Using my dataset, I am able to compute the integrated Brier scores of the RSF pipeline.

  3. However, when I try to use SurvivalModelExplainer and ModelSurvSHAP to generate explanations, the progress bar does not advance. I have tested this over the course of 10 hours, and there is no progress and no error message.

  4. As mentioned earlier, my dataset is similar to the default dataset used in exp3.ipynb, which runs just fine.
    Can you please help out with this?

Error with the "sampling" calculation method

import numpy as np
import pandas as pd
from survshap import SurvivalModelExplainer, ModelSurvSHAP
import time

nb_features=7
nb_events=200

np_X=np.random.rand(nb_events, nb_features)
np_time=np.random.rand(nb_events, 1)
np_is_living=np_X[:,0] < np_time[:,0]

y=np.empty(nb_events, dtype=[('event', '?'), ('time', '<f16')])
y['event']=np_is_living.reshape(-1)
y['time']=np_time.reshape(-1)
X=pd.DataFrame(np_X,columns=['f'+str(i) for i in range(1,nb_features+1)])

from sksurv.ensemble import RandomSurvivalForest
rsf=RandomSurvivalForest(random_state=42)
st=time.time()
rsf.fit(X,y)
print(f"score:{rsf.score(X,y)} fit time:{time.time()-st}")
print(f"predict: {rsf.predict(X)}")


exp_rsf=SurvivalModelExplainer(rsf,X,y)
ms_rsf=ModelSurvSHAP(random_state=42, calculation_method="sampling")
st=time.time()
ms_rsf.fit(exp_rsf)
print(f"Interpretation time:{time.time()-st}")

produces:

    raise TypeError(f"Could not convert {x} to numeric") from err
TypeError: Could not convert f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1f1 to numeric

Process finished with exit code 1

Scalability over larger datasets

Hi,

Is your method scalable to larger datasets? I tried running it on a dataset of shape (10000, 8) and got a much longer estimated run time than expected. This should not be the case, since your own test dataset has shape (300, 8) and the time per iteration there is low. Are you retraining the model to compute the SHAP values for each example? It is not clear to me why the time per iteration has increased so much when the number of features is the same.

draft for v1.0 [PyPI release]

Potentially, version 1.0 of survshap could be published as a separate package on PyPI.

What needs to be done?

  • adding unit tests
  • adding missing and refactoring existing documentation
  • adding a notebook (vignette), separate from the experiments, with a use case
  • #17

Memory issue

Hi,

First, congratulations on this project.

It seems that the memory and time complexity is O(2^n), where n is the number of features in the dataset (p in the line below):

simplified_inputs = [list(z) for z in itertools.product(range(2), repeat=p)]

What do you propose when n is large (e.g., 300)? Do you plan to implement an approximation method (e.g., based on Monte Carlo sampling)?

Regards,

Inquiry Regarding Legend Addition in plot_mean_abs_shap_values() Function

Dear Dr. Krzyziński and the MI2 Data Lab,

I extend my gratitude to you for developing the 'SurvSHAP' tool, which has proven instrumental in effectively ranking and visualizing the interpretability of time-to-event models. I am very interested in this package and would like to apply it in my further experiments. I am writing to seek guidance on a specific aspect.

I would appreciate your help in understanding how to add a legend for the lines in the plot produced by the plot_mean_abs_shap_values() function. I aim to label each line with the name of its predictive feature, similar to the presentation in your publication (Figure 11, left).

Thank you very much in advance.

Best regards,

Nhu, NT.

example data is missing

Thank you so much for your contributions on black box survival analysis.
It looks like "data/exp3_real_data.csv" is missing from the repo. Could you please confirm? Thanks!

🆘 How to specify timestamps

Hello,

We're currently trying to estimate SHAP values for a random survival forest object (n=8000, 8 predictors) and are getting tremendous running times. We were wondering whether we could reduce the number of timestamps (e.g., 2, 5 and 10 years instead of 100 equally spaced times) and hence reduce the running time. However, we're not sure how to pass them to the function (we tried with R and it failed).

from survshap import SurvivalModelExplainer, ModelSurvSHAP
rsf_exp = SurvivalModelExplainer(rsf, X[1:300], y[1:300])

exp3_survshap_global_rsf = ModelSurvSHAP(random_state=42)
exp3_survshap_global_rsf.fit(rsf_exp)

Any help would be greatly appreciated, thank you for this amazing package.
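How survshap exposes custom evaluation times is not shown in this thread, so the following is only a model-agnostic sketch of the underlying computation: reading tabulated survival curves off at a few chosen horizons (e.g., 2, 5 and 10 years) instead of a dense grid. It assumes NumPy and curves stored as a matrix with one row per individual; `survival_at_horizons` is a hypothetical helper.

```python
import numpy as np


def survival_at_horizons(event_times, surv_matrix, horizons):
    """Evaluate right-continuous step survival curves at chosen horizons.

    event_times: (T,) increasing times at which the curves are tabulated
    surv_matrix: (n, T) survival probabilities, one row per individual
    horizons:    (H,) times of interest
    """
    event_times = np.asarray(event_times, dtype=float)
    surv_matrix = np.asarray(surv_matrix, dtype=float)
    idx = np.searchsorted(event_times, horizons, side="right") - 1
    out = np.ones((surv_matrix.shape[0], len(idx)))  # S(t) = 1 before the first time
    valid = idx >= 0
    out[:, valid] = surv_matrix[:, idx[valid]]
    return out


curves = np.array([[0.9, 0.7, 0.5, 0.2],
                   [0.95, 0.9, 0.8, 0.6]])
at_2_5_10 = survival_at_horizons([1, 3, 6, 12], curves, [2, 5, 10])  # shape (2, 3)
```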