mayer79 / hstats
Friedman's H-statistics
Home Page: https://mayer79.github.io/hstats/
License: GNU General Public License v2.0
Awesome package! Thanks for all the work, @mayer79. Great to see such a solid implementation of H statistics in R. Well documented and compatible with many modelling packages.
Would it be feasible to add out-of-the-box support for H2O models? Because of the way you've set up the code, the pred_fun argument already makes it possible to use H2O models in the various functions. See the code example below, where I illustrate this for binomial, regression, and multinomial H2O models.
I believe overwriting the generics for hstats(), partial_dep(), ice(), and perm_importance() would require two main things:
1. Setting the pred_fun argument to:
# Class "H2OBinomialModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[3]]
}
# Class "H2ORegressionModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[1]]
}
# Class "H2OMulticlassModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[, -1]
}
2. [...] as.data.frame()
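The column choices in the three pred_fun variants above follow from the layout of the prediction frame: the predicted label comes first, followed by one probability column per class (for regression there is only the prediction). A minimal sketch with mock prediction frames, so that no H2O cluster is needed; the helper extract_pred() and the mock column names are assumptions made for illustration, not part of hstats or h2o:

```r
# Mock of as.data.frame(h2o.predict(...)) for a binomial model:
# predicted label first, then one probability column per class (assumed layout).
binomial_pred <- data.frame(
  predict = c("0", "1", "1"),
  p0 = c(0.8, 0.3, 0.1),
  p1 = c(0.2, 0.7, 0.9)
)

# Regression models return a single numeric prediction column.
regression_pred <- data.frame(predict = c(10.5, 12.1, 9.8))

# Multiclass models return the label plus one probability column per class.
multiclass_pred <- data.frame(
  predict = c("a", "b", "c"),
  pa = c(0.7, 0.2, 0.1),
  pb = c(0.2, 0.6, 0.2),
  pc = c(0.1, 0.2, 0.7)
)

# Hypothetical helper: pick the numeric prediction columns by model class name.
extract_pred <- function(model_class, pred_df) {
  switch(
    model_class,
    H2OBinomialModel = pred_df[[3]],          # probability of the second class
    H2ORegressionModel = pred_df[[1]],        # the prediction itself
    H2OMulticlassModel = as.matrix(pred_df[, -1, drop = FALSE]),  # all class probabilities
    stop("Unsupported model class: ", model_class)
  )
}

extract_pred("H2OBinomialModel", binomial_pred)
extract_pred("H2ORegressionModel", regression_pred)
extract_pred("H2OMulticlassModel", multiclass_pred)
```

Unlike the multiclass pred_fun above, which returns a data frame, this sketch converts the probabilities to a matrix; either form keeps one column per class.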
Further suggestions / potential improvements:
- Would it make sense to turn pred_fun itself into a generic? Then you'd only have to overwrite that one for the ranger, Learner, explainer, H2OBinomialModel, H2ORegressionModel, and H2OMulticlassModel classes instead of all of hstats(), partial_dep(), ice(), and perm_importance().
- For H2O models, every pred_fun call involves converting the data with as.h2o() and then calling h2o.predict(). Doing this only once is much faster than calling pred_fun multiple times. Hence, particularly for H2O models, further speed improvements are possible by first combining the data and then calling pred_fun instead of the other way around. [Try running the examples below with h2o.show_progress() to see a progress bar whenever as.h2o() and h2o.predict() get called]
- partial_dep(..., BY = ...) could be faster by avoiding the for loop over the BY argument.
- hstats() could be faster by avoiding the for loop over the one-way, two-way and three-way effects. [This would be hard to refactor though]
- perm_importance() could be faster by avoiding the for loop over the v argument. [This might lead to memory issues when stacking too many data frames if v, m_rep and/or n_max is large, so an optional argument for this would potentially be better]
- v = NULL in hstats() and perm_importance() could be set to object@allparameters$x. That way, no unnecessary computations are performed for columns in X that are not used as features.
- y in perm_importance() can be set to object@allparameters$y.
Code example for H2O models:
library(hstats)
library(h2o)
h2o.init()
h2o.no_progress()
# H2O Binomial ----
# Import the prostate dataset into H2O:
prostate <- h2o.importFile("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv")
prostate_df <- as.data.frame(prostate)
# Set the predictors and response; set the factors:
prostate$CAPSULE <- as.factor(prostate$CAPSULE)
prostate$RACE <- as.factor(prostate$RACE)
prostate$DPROS <- as.factor(prostate$DPROS)
prostate$DCAPS <- as.factor(prostate$DCAPS)
prostate$GLEASON <- as.factor(prostate$GLEASON)
predictors <- c("AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON")
response <- "CAPSULE"
# Build and train the model:
pros_gbm <- h2o.gbm(
  x = predictors,
  y = response,
  nfolds = 5,
  seed = 1111,
  keep_cross_validation_predictions = TRUE,
  training_frame = prostate
)
# Class "H2OBinomialModel"
# The third column of the prediction frame holds the probability of class "1"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[3]]
}
# H stats
s <- hstats(pros_gbm, X = prostate_df, pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#>
#> H^2 (normalized)
#> [1] 0.2818805
# H statistics: overall level of interaction & pairwise
plot(s)
# Unnormalized pairwise statistics
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))
# Stratified PD
plot(partial_dep(pros_gbm, v = "PSA", X = prostate_df, BY = "GLEASON", pred_fun = pred_fun), show_points = FALSE)
# Two-dimensional PDP
pd <- partial_dep(pros_gbm, v = c("PSA", "GLEASON"), X = prostate_df, pred_fun = pred_fun, grid_size = 1000)
plot(pd)
# Centered ICE plot with colors
ic <- ice(pros_gbm, v = "PSA", X = prostate_df, BY = "GLEASON", pred_fun = pred_fun)
plot(ic, center = TRUE)
# Variable importance based on permutation
plot(perm_importance(pros_gbm, X = prostate_df, y = pros_gbm@allparameters$y, pred_fun = pred_fun, v = pros_gbm@allparameters$x))
# Variable importance based on PD function
plot(pd_importance(s))
# H2O Regression ----
# Run GLM of VOL ~ CAPSULE + AGE + RACE + PSA + GLEASON
prostate_path <- system.file("extdata", "prostate.csv", package = "h2o")
prostate <- h2o.importFile(path = prostate_path)
predictors <- setdiff(colnames(prostate), c("ID", "DPROS", "DCAPS", "VOL"))
pros_glm <- h2o.glm(
  y = "VOL", x = predictors, training_frame = prostate, family = "tweedie",
  nfolds = 0, alpha = 0.1, lambda_search = FALSE,
  interactions = c("AGE", "GLEASON")
)
# Class "H2ORegressionModel"
# The first (and only) column of the prediction frame holds the prediction
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[1]]
}
# H stats
s <- hstats(pros_glm, X = as.data.frame(prostate), pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#>
#> H^2 (normalized)
#> [1] 0.0007039798
# Pairwise statistics
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))
# PD plot
pd <- partial_dep(pros_glm, v = "AGE", X = as.data.frame(prostate), BY = "GLEASON", pred_fun = pred_fun, grid_size = 500)
plot(pd)
# H2O Multinomial ----
# import the cars dataset
cars <- h2o.importFile("https://s3.amazonaws.com/h2o-public-test-data/smalldata/junit/cars_20mpg.csv")
# set the predictor names and the response column name
predictors <- c("displacement", "power", "weight", "acceleration", "year")
response <- "cylinders"
cars[, response] <- as.factor(cars[response])
# split into train and validation sets
cars_splits <- h2o.splitFrame(data = cars, ratios = 0.8, seed = 1234)
train <- cars_splits[[1]]
valid <- cars_splits[[2]]
# build and train the model:
cars_gbm <- h2o.gbm(
  x = predictors,
  y = response,
  training_frame = train,
  validation_frame = valid,
  distribution = "multinomial",
  seed = 1234
)
# Class "H2OMulticlassModel"
# Drop the first column (predicted label) and keep one probability column per class
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[, -1]
}
# H stats
s <- hstats(cars_gbm, X = as.data.frame(cars), pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#>
#> H^2 (normalized)
#> p3 p4 p5 p6 p8
#> 0.74177545 0.06181569 0.95886855 0.07111337 0.00996367
# Pairwise statistics
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))
# PD plot
pd <- partial_dep(cars_gbm, v = "displacement", X = as.data.frame(cars), pred_fun = pred_fun, grid_size = 500)
plot(pd)
Created on 2023-10-28 with reprex v2.0.2
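The suggestion above to first combine the data and then call pred_fun once can be sketched in base R. This is only an illustration of the idea, not how hstats computes partial dependence internally: lm() stands in for a model with expensive prediction calls (such as an H2O model), and the call counter n_calls is an assumption added to make the difference visible.

```r
# Toy model and data; lm() stands in for any model with an expensive predict step.
fit <- lm(mpg ~ wt + hp, data = mtcars)
X <- mtcars[c("wt", "hp")]
grid <- seq(min(X$wt), max(X$wt), length.out = 5)

# Count how often the (potentially expensive) prediction function runs.
n_calls <- 0
pred_fun <- function(object, newdata) {
  n_calls <<- n_calls + 1
  unname(predict(object, newdata))
}

# Variant A: one pred_fun call per grid value.
pd_loop <- vapply(grid, function(g) {
  X_g <- X
  X_g$wt <- g
  mean(pred_fun(fit, X_g))
}, numeric(1))
calls_loop <- n_calls

# Variant B: stack all modified copies of X, predict once, average per grid value.
n_calls <- 0
X_stacked <- X[rep(seq_len(nrow(X)), times = length(grid)), , drop = FALSE]
X_stacked$wt <- rep(grid, each = nrow(X))
pred <- pred_fun(fit, X_stacked)
grid_id <- rep(seq_along(grid), each = nrow(X))
pd_stacked <- as.numeric(tapply(pred, grid_id, mean))
calls_stacked <- n_calls

all.equal(pd_loop, pd_stacked)  # both variants give the same partial dependence
c(loop = calls_loop, stacked = calls_stacked)
```

The stacked variant trades memory for fewer prediction calls: the stacked data has nrow(X) times length(grid) rows, which is the memory concern raised above for perm_importance().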
library(hstats)
library(ranger)
set.seed(1)
fit <- ranger(Species ~ ., data = iris)
H <- hstats(fit, X = iris[-5])
plot(H)
This currently produces a 0 in the denominator of one pairwise interaction in one component, which leads to an NA value in the normalized statistic. How to fix?
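One possible way to handle such a case is to define the ratio as 0 wherever the denominator is numerically zero. The sketch below uses made-up numbers and a hypothetical helper safe_ratio(); it is not the fix adopted in hstats, and whether 0 or NA is the better result for a zero denominator is a design decision.

```r
# Hypothetical helper: elementwise division that returns 0 wherever the
# denominator is (numerically) zero, instead of NaN or Inf.
safe_ratio <- function(num, denom, eps = 1e-10) {
  out <- num / denom
  out[abs(denom) < eps] <- 0
  out
}

# Made-up numerators and denominators for two feature pairs and two components.
dn <- list(c("A:B", "A:C"), c("setosa", "versicolor"))
num <- matrix(c(0.02, 0, 0.01, 0.03), nrow = 2, dimnames = dn)
denom <- matrix(c(0.5, 0, 0.2, 0.6), nrow = 2, dimnames = dn)

num / denom             # plain division: the 0 / 0 entry is NaN
safe_ratio(num, denom)  # the same entry is set to 0
```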
I have five network class categories and a number of features. The classification model was trained using boosted trees (xgboost engine) in tidymodels. I am using the shapviz package to interpret the XGBoost model (global, local, and interaction effects of features for each network class). I would like to use the 'hstats' package to quantify how much interaction effect (two-way, three-way, etc.) the features account for in each network class. Thanks
This was proposed in #91 by @RoelVerbelen. Very good visualization to detect interactions.
Unrelated to H-statistics, but would be useful to have the basic model-agnostic feature importance measure in the package. It would support multivariate/classification outputs. It would provide an alternative way to select features for which H-statistics are to be calculated.
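Assuming the measure meant here is permutation importance, the idea can be sketched in base R: shuffle one feature at a time and record how much the loss increases. The function perm_imp_sketch(), its arguments, and the choice of mean squared error are assumptions for illustration and do not reflect the package's implementation.

```r
set.seed(1)

# Toy regression model; any model with a predict() method would work here.
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)
X <- mtcars[c("wt", "hp", "qsec")]
y <- mtcars$mpg

# Loss function: mean squared error.
mse <- function(actual, predicted) mean((actual - predicted)^2)

# Hypothetical sketch: average loss increase after shuffling each feature.
perm_imp_sketch <- function(object, X, y, v = colnames(X), n_rep = 5) {
  base_loss <- mse(y, predict(object, X))
  vapply(v, function(feature) {
    losses <- replicate(n_rep, {
      X_perm <- X
      X_perm[[feature]] <- sample(X_perm[[feature]])  # break the link to y
      mse(y, predict(object, X_perm))
    })
    mean(losses) - base_loss
  }, numeric(1))
}

imp <- perm_imp_sketch(fit, X, y)
sort(imp, decreasing = TRUE)
```

Features with the largest loss increase would be natural candidates for the more expensive H-statistic calculations.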
library(hstats)
library(tibble)
## Dobson (1990) Page 93: Randomized Controlled Trial :
counts <- c(18,17,15,20,10,20,25,13,12)
outcome <- gl(3,1,9)
treatment <- gl(3,3)
df <- data.frame(treatment, outcome, counts) # showing data
glm.D93 <- glm(counts ~ outcome + treatment, family = poisson())
# PD with categorical BY
# Works
partial_dep(glm.D93, v = "outcome", X = df, BY = "treatment")
#> Partial dependence object (9 rows). Extract via $data. Top rows:
#>
#> treatment outcome y
#> 1 1 1 3.044522
#> 2 1 2 2.590267
#> 3 1 3 2.751535
# Does not work
partial_dep(glm.D93, v = "outcome", X = tibble(df), BY = "treatment")
#> Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 3, 0
Created on 2023-10-28 with reprex v2.0.2
The root cause seems to be that subsetting a tibble behaves differently from subsetting a data frame. The subsetting is done in the code for prepare_by().
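The difference can be reproduced in base R without loading tibble: a data frame drops a single selected column to a vector by default, whereas a tibble keeps it as a one-column table, which corresponds to drop = FALSE below. A minimal sketch; that [[ ]] would be a suitable fix inside prepare_by() is an assumption, not a description of the actual patch.

```r
df <- data.frame(treatment = gl(3, 3), outcome = gl(3, 1, 9))

# Data frame default: selecting one column with [, j] drops to a vector.
by_vec <- df[, "treatment"]
is.factor(by_vec)

# What a tibble returns for the same call: still a one-column table.
by_tbl_like <- df[, "treatment", drop = FALSE]
is.data.frame(by_tbl_like)

# [[ ]] extracts the column as a vector for data frames and tibbles alike.
by_safe <- df[["treatment"]]
identical(by_vec, by_safe)
```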
Add function average_loss() to compute performance metrics.