
hstats's People

Contributors

mayer79


hstats's Issues

Support for H2O models

Awesome package! Thanks for all the work, @mayer79. It's great to see such a solid implementation of H-statistics in R. The package is well documented and compatible with many modelling packages.

Would it be feasible to add out-of-the-box support for H2O models? Because of the way you've set up the code, it is already possible to use the pred_fun argument for H2O models in the various functions. See the code example below, where I illustrate this for binomial, regression, and multinomial H2O models.

I believe that providing H2O methods for the generics hstats(), partial_dep(), ice(), and perm_importance() would require two main things:

  1. Setting the pred_fun argument to:
# Class "H2OBinomialModel" 
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[3]]
}

# Class "H2ORegressionModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[1]]
}

# Class "H2OMultinomialModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[, -1]
}
  2. Ensuring that the data passed as X (potentially an H2OFrame) is converted to a data.frame using as.data.frame()
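The two points above could be combined in a small helper that picks the prediction function from the model class and converts an H2OFrame to a data.frame. This is a minimal sketch, not part of hstats. The names `h2o_pred_fun()` and `prepare_h2o_X()` are hypothetical. The column positions assume the usual `h2o.predict()` layout: a `predict` column followed by one probability column per class. H2O itself assigns the class `H2OMultinomialModel` to multinomial models.

```r
# Hypothetical helpers (not part of hstats) sketching built-in H2O support.

# Pick a prediction function based on the H2O model class.
h2o_pred_fun <- function(object) {
  # Convert newdata to an H2OFrame, predict, and pull the result back into R
  predict_df <- function(object, newdata) {
    as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))
  }
  if (inherits(object, "H2OBinomialModel")) {
    # Columns: predict, then one probability per level; keep the second level
    function(object, newdata) predict_df(object, newdata)[[3L]]
  } else if (inherits(object, "H2ORegressionModel")) {
    # A single "predict" column
    function(object, newdata) predict_df(object, newdata)[[1L]]
  } else if (inherits(object, "H2OMultinomialModel")) {
    # Drop the predicted label, keep one probability column per class
    function(object, newdata) predict_df(object, newdata)[, -1L]
  } else {
    stop("Unsupported model class: ", paste(class(object), collapse = ", "))
  }
}

# Make sure X is an ordinary data.frame before any hstats function uses it.
prepare_h2o_X <- function(X) {
  if (inherits(X, "H2OFrame")) as.data.frame(X) else X
}
```

With these helpers, a call would look like `hstats(model, X = prepare_h2o_X(frame), pred_fun = h2o_pred_fun(model))`. Here `model` and `frame` stand for any fitted H2O model and its data.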

Further suggestions / potential improvements:

  • Would it be possible (or better) to set up a generic for pred_fun itself? Then you would only need methods of that one generic for the ranger, Learner, explainer, H2OBinomialModel, H2ORegressionModel, and H2OMultinomialModel classes, instead of methods for all of hstats(), partial_dep(), ice(), and perm_importance().
  • The prediction function for H2O models has to convert the data to an H2OFrame with as.h2o() and then call h2o.predict(). Doing this once on one combined dataset is much faster than calling pred_fun many times on small pieces. Especially for H2O models, further speed-ups are therefore possible by first combining the data and then calling pred_fun once, instead of the other way around. [Run the examples below after calling h2o.show_progress() to see a progress bar whenever as.h2o() and h2o.predict() are called.]
    • partial_dep(..., BY = ...) could be faster by avoiding the for loop over the BY argument.
    • hstats() could be faster by avoiding the for loop over the one-way, two-way and three-way effects. [This would be hard to refactor though]
    • perm_importance() could be faster by avoiding the for loop over the v argument. [This might cause memory issues when stacking many data frames if v, m_rep, and/or n_max are large, so an optional argument controlling this behaviour might be better.]
  • For H2O models, the default for v (when v = NULL) in hstats() and perm_importance() could be set to object@allparameters$x. That way, no computations are wasted on columns of X that are not used as features.
  • For H2O models, the default for y in perm_importance() could be set to object@allparameters$y.
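The "combine first, predict once" idea can be sketched for a one-dimensional partial dependence. Stack one copy of X per grid value, call pred_fun once on the stacked data, then average within blocks. `pd_batched()` is a hypothetical helper and does not describe how hstats currently works. With H2O it would mean a single as.h2o() and h2o.predict() round trip. The cost is a stacked dataset with nrow(X) * length(grid) rows, which is the memory trade-off mentioned above.

```r
# Hypothetical sketch: partial dependence of feature v using ONE pred_fun call.
pd_batched <- function(object, v, X, grid, pred_fun) {
  n <- nrow(X)
  # Stack one copy of X per grid value, with column v set to that value
  X_stacked <- X[rep(seq_len(n), times = length(grid)), , drop = FALSE]
  X_stacked[[v]] <- rep(grid, each = n)
  # Single prediction call on all copies at once
  pred <- pred_fun(object, X_stacked)
  # Average the predictions within each block of n rows (one block per grid value)
  block <- rep(seq_along(grid), each = n)
  data.frame(grid = grid, pd = as.numeric(tapply(pred, block, mean)))
}

# Usage with an ordinary linear model
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length, data = iris)
pd_batched(fit, "Sepal.Width", iris, grid = c(2, 3, 4),
           pred_fun = function(m, d) predict(m, d))
```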

Code example for H2O models:

library(hstats)
library(h2o)
h2o.init()
h2o.no_progress()

# H2O Binomial ----

# Import the prostate dataset into H2O:
prostate <- h2o.importFile("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv")
prostate_df <- as.data.frame(prostate)

# Set the predictors and response; set the factors:
prostate$CAPSULE <- as.factor(prostate$CAPSULE)
prostate$RACE <- as.factor(prostate$RACE)
prostate$DPROS <- as.factor(prostate$DPROS)
prostate$DCAPS <- as.factor(prostate$DCAPS)
prostate$GLEASON <- as.factor(prostate$GLEASON)
predictors <- c("AGE", "RACE", "DPROS", "DCAPS", "PSA", "VOL", "GLEASON")
response <- "CAPSULE"

# Build and train the model:
pros_gbm <- h2o.gbm(x = predictors,
                    y = response,
                    nfolds = 5,
                    seed = 1111,
                    keep_cross_validation_predictions = TRUE,
                    training_frame = prostate)

# Class "H2OBinomialModel" 
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[3]]
}

# H stats
s <- hstats(pros_gbm, X = prostate_df, pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#> 
#> H^2 (normalized)
#> [1] 0.2818805

# H statistics: overall level of interaction & pairwise
plot(s)

# Unnormalized pairwise statistics
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))

# Stratified PD
plot(partial_dep(pros_gbm, v = "PSA", X = prostate_df, BY = "GLEASON", pred_fun = pred_fun), show_points = FALSE)

# Two-dimensional PDP
pd <- partial_dep(pros_gbm, v = c("PSA", "GLEASON"), X = prostate_df, pred_fun = pred_fun, grid_size = 1000)
plot(pd)

# Centered ICE plot with colors
ic <- ice(pros_gbm, v = "PSA", X = prostate_df, BY = "GLEASON", pred_fun = pred_fun)
plot(ic, center = TRUE)

# Variable importance based on permutation
plot(perm_importance(pros_gbm, X = prostate_df, y = pros_gbm@allparameters$y, pred_fun = pred_fun, v = pros_gbm@allparameters$x))

# Variable importance based on PD function
plot(pd_importance(s))

# H2O Regression ----

# Run GLM of VOL ~ CAPSULE + AGE + RACE + PSA + GLEASON
prostate_path <- system.file("extdata", "prostate.csv", package = "h2o")
prostate <- h2o.importFile(path = prostate_path)
predictors <- setdiff(colnames(prostate), c("ID", "DPROS", "DCAPS", "VOL"))
pros_glm <- h2o.glm(y = "VOL", x = predictors, training_frame = prostate, family = "tweedie",
                    nfolds = 0, alpha = 0.1, lambda_search = FALSE,
                    interactions = c("AGE", "GLEASON"))

# Class "H2ORegressionModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[[1]]
}

# H stats
s <- hstats(pros_glm, X = as.data.frame(prostate), pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#> 
#> H^2 (normalized)
#> [1] 0.0007039798

# Pairwise statistics 
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))

# PD plot 
pd <- partial_dep(pros_glm, v = "AGE", X = as.data.frame(prostate), BY = "GLEASON", pred_fun = pred_fun, grid_size = 500)
plot(pd)

# H2O Multinomial ----

# import the cars dataset
cars <- h2o.importFile("https://s3.amazonaws.com/h2o-public-test-data/smalldata/junit/cars_20mpg.csv")

# set the predictor names and the response column name
predictors <- c("displacement", "power", "weight", "acceleration", "year")
response <- "cylinders"
cars[, response] <- as.factor(cars[response])

# split into train and validation sets
cars_splits <- h2o.splitFrame(data =  cars, ratios = 0.8, seed = 1234)
train <- cars_splits[[1]]
valid <- cars_splits[[2]]

# build and train the model:
cars_gbm <- h2o.gbm(x = predictors,
                    y = response,
                    training_frame = train,
                    validation_frame = valid,
                    distribution = "multinomial",
                    seed = 1234)

# Class "H2OMultinomialModel"
pred_fun <- function(object, newdata) {
  as.data.frame(h2o::h2o.predict(object, h2o::as.h2o(newdata)))[, -1]
} 

# H stats
s <- hstats(cars_gbm, X = as.data.frame(cars), pred_fun = pred_fun)
s
#> 'hstats' object. Use plot() or summary() for details.
#> 
#> H^2 (normalized)
#>         p3         p4         p5         p6         p8 
#> 0.74177545 0.06181569 0.95886855 0.07111337 0.00996367

# Pairwise statistics 
plot(h2_pairwise(s, normalize = FALSE, squared = FALSE))

# PD plot 
pd <- partial_dep(cars_gbm, v = "displacement", X = as.data.frame(cars), pred_fun = pred_fun, grid_size = 500)
plot(pd)

Created on 2023-10-28 with reprex v2.0.2

Multi-output models with 0 in numerator

library(hstats)
library(ranger)

set.seed(1)

fit <- ranger(Species ~ ., data = iris)
H <- hstats(fit, X = iris[-5])
plot(H)

This code currently produces a 0 in the denominator of one pairwise interaction statistic for one of the output components (classes). As a result, the normalized statistic is NA. How should this be fixed?
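The normalized pairwise statistic divides by the sum of squared values of the centered two-way partial dependence. It therefore becomes 0/0 when that partial dependence is flat for an output component. One possible convention is to report 0 in that case, since no variation in the pair's joint effect also means no interaction to attribute. This is an assumption, not necessarily the fix hstats adopted. The helper `safe_divide()` is hypothetical.

```r
# Hypothetical helper: element-wise division that returns 0 wherever the
# denominator is exactly 0 (assumed convention: no variation, no interaction).
safe_divide <- function(num, denom) {
  out <- num / denom      # 0/0 gives NaN here
  out[denom == 0] <- 0    # replace those entries by 0; dims are preserved
  out
}

safe_divide(c(1, 0, 2), c(2, 0, 4))
```

Floating-point sums of squares may be tiny rather than exactly 0. A tolerance-based check, such as `denom < 1e-12 * max(denom)`, would be a variant worth considering.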

Interaction effect for multiple features across MultiClass

I have 5 network class categories and a number of features. The classification model was trained with boosted trees (xgboost engine) in tidymodels. I am using the shapviz package to interpret the XGBoost model: global and local effects, and interaction effects of features, for each network class. I would like to use the hstats package to quantify how much of each network class's prediction is explained by interaction effects (two-way, three-way, etc.) among the features. Thanks!
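To illustrate what a per-class interaction measure is, here is a brute-force sketch of Friedman's pairwise statistic. It is computed separately for each column of a matrix-valued prediction function, one column per class: H^2_jk = sum_i [F_jk(x_ij, x_ik) - F_j(x_ij) - F_k(x_ik)]^2 / sum_i F_jk(x_ij, x_ik)^2, where the F are centered partial dependence functions evaluated at the data rows. `centered_pd()` and `h2_pair_per_class()` are hypothetical helpers, not the hstats implementation. They need O(nrow(X)^2) predictions, so they are only practical on small data. The toy "model" below is made up for illustration.

```r
# Hypothetical helper: centered partial dependence of the features in v,
# evaluated at each row of X, with one column per model output (class).
centered_pd <- function(object, X, v, pred_fun) {
  n <- nrow(X)
  pd <- do.call(rbind, lapply(seq_len(n), function(i) {
    X_mod <- X
    # Fix the features in v at row i's values in every row, then average
    X_mod[v] <- X[rep(i, n), v, drop = FALSE]
    colMeans(as.matrix(pred_fun(object, X_mod)))
  }))
  # Center each output column so it has mean zero over the data
  sweep(pd, 2, colMeans(pd))
}

# Pairwise H^2 of features j and k, one value per output column
h2_pair_per_class <- function(object, X, j, k, pred_fun) {
  F_jk <- centered_pd(object, X, c(j, k), pred_fun)
  F_j <- centered_pd(object, X, j, pred_fun)
  F_k <- centered_pd(object, X, k, pred_fun)
  colSums((F_jk - F_j - F_k)^2) / colSums(F_jk^2)
}

# Toy two-class "model": class A is a pure x1-x2 interaction, class B is additive
toy_pred <- function(object, X) cbind(A = X$x1 * X$x2, B = X$x1 + X$x2)
X <- expand.grid(x1 = c(-1, 0, 1), x2 = c(-1, 0, 1))
h2_pair_per_class(NULL, X, "x1", "x2", toy_pred)  # A is 1, B is 0
```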

Add permutation importance

This is unrelated to H-statistics, but it would be useful to have the basic model-agnostic feature importance measure, permutation importance, in the package. It should support multivariate/classification outputs. It would also provide an alternative way to select the features for which H-statistics are calculated.
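A minimal sketch of permutation importance that handles single and multi-output predictions: for each feature, shuffle its column m_rep times and record the average increase in loss per output. `perm_importance_sketch()` is a hypothetical name and does not reflect the signature of any hstats function. The loss is assumed to return one value per observation (a vector) or per observation and output (a matrix).

```r
# Hypothetical sketch of permutation importance (not hstats' API).
# loss(y, pred) returns per-observation losses: a vector or an n x K matrix.
perm_importance_sketch <- function(object, X, y, pred_fun, loss, m_rep = 4L) {
  # Average loss per output column for a given dataset
  avg_loss <- function(X_) colMeans(as.matrix(loss(y, pred_fun(object, X_))))
  base <- avg_loss(X)
  imp <- lapply(colnames(X), function(v) {
    reps <- replicate(m_rep, {
      X_perm <- X
      X_perm[[v]] <- sample(X_perm[[v]])  # break the link between v and y
      avg_loss(X_perm)
    })
    # reps is a vector (one output) or an outputs x m_rep matrix
    rowMeans(matrix(reps, ncol = m_rep)) - base
  })
  out <- do.call(rbind, imp)  # one row per feature, one column per output
  rownames(out) <- colnames(X)
  out
}

# Usage: squared-error importance for a linear model
set.seed(1)
fit <- lm(Sepal.Length ~ Sepal.Width + Petal.Length + Petal.Width, data = iris)
perm_importance_sketch(fit, iris[2:4], iris$Sepal.Length,
                       pred_fun = function(m, d) predict(m, d),
                       loss = function(y, p) (y - p)^2)
```

For a multi-output model, pass a response matrix as y. The result then has one importance column per output.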

partial_dep() with BY argument: error for tibble X argument

library(hstats)
library(tibble)

## Dobson (1990) Page 93: Randomized Controlled Trial :
counts <- c(18,17,15,20,10,20,25,13,12)
outcome <- gl(3,1,9)
treatment <- gl(3,3)
df <- data.frame(treatment, outcome, counts) # showing data
glm.D93 <- glm(counts ~ outcome + treatment, family = poisson())

# PD with categorical BY

# Works
partial_dep(glm.D93, v = "outcome", X = df, BY = "treatment")
#> Partial dependence object (9 rows). Extract via $data. Top rows:
#> 
#>   treatment outcome        y
#> 1         1       1 3.044522
#> 2         1       2 2.590267
#> 3         1       3 2.751535

# Does not work
partial_dep(glm.D93, v = "outcome", X = tibble(df), BY = "treatment")
#> Error in data.frame(..., check.names = FALSE): arguments imply differing number of rows: 3, 0

Created on 2023-10-28 with reprex v2.0.2

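The root cause seems to be that subsetting a tibble behaves differently from subsetting a data frame (see here). For example, X[, "treatment"] returns a vector for a data frame but a one-column tibble for a tibble.

That subsetting happens here, in the code of prepare_by().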
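A sketch of one possible fix: extract the BY column with double-bracket indexing, which returns a plain vector for both data frames and tibbles. This assumes that prepare_by() currently uses single-bracket `X[, BY]` indexing. Converting X with as.data.frame() on entry would be an alternative. `get_column()` is a hypothetical helper.

```r
# Hypothetical helper: X[[v]] returns a plain vector for data.frames and tibbles,
# whereas X[, v] only drops to a vector for data.frames.
get_column <- function(X, v) X[[v]]

df <- data.frame(treatment = gl(3, 3),
                 counts = c(18, 17, 15, 20, 10, 20, 25, 13, 12))
print(is.factor(df[, "treatment"]))            # TRUE: data.frame drops to a vector
print(is.factor(get_column(df, "treatment")))  # TRUE

if (requireNamespace("tibble", quietly = TRUE)) {
  tb <- tibble::as_tibble(df)
  print(class(tb[, "treatment"]))              # still a tibble, not a vector
  print(is.factor(get_column(tb, "treatment"))) # TRUE
}
```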
