
Comments (15)

maksymiuks commented on June 18, 2024

I've merged the changes to master and will push them to CRAN next week. Let me know if I can do anything else to ease the integration between the DALEX and tidymodels ecosystems.

from dalextra.

maksymiuks commented on June 18, 2024

Thank you so much for this request as well as for including DALEXtra in your manual!

I understand the need and, frankly speaking, when creating those methods in DALEXtra we were not able to extract the raw model and raw data. Thanks to your examples I'll be able to extend the functionality.

I'll ping you in this thread once I finish my work.

Also, please let me know if you need anything while preparing this chapter.


maksymiuks commented on June 18, 2024

@topepo Thanks!

This is exactly what's happening. I made it possible to pass a parsnip model directly (previously it resulted in an error) to get a detailed overview of potentially new variables introduced to the model (like the date break-down in the example). Passing a workflow works exactly as it did previously; I didn't change that behavior.

Does that answer your concerns?


topepo commented on June 18, 2024

Let us know if there is any way that we can help. There are some pull_*() functions in workflows that make things a little easier.
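For instance, something along these lines (a sketch, assuming a fitted workflow object named `lm_wflow`; these were the workflows function names at the time, later superseded by the `extract_*()` family):

```r
# Sketch: extract the pieces DALEX needs from a fitted workflow.
# Assumes a fitted workflow `lm_wflow`; the pull_* functions come from
# the workflows package (later renamed to extract_*).
lm_fit <- lm_wflow %>%
  pull_workflow_fit()             # the underlying parsnip model_fit

prepped_rec <- lm_wflow %>%
  pull_workflow_prepped_recipe()  # the prepped recipe, ready to bake()
```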


gofford commented on June 18, 2024

I've been looking at this, and I'm actually using a similar process to the one posted by @topepo.

A small issue I'm seeing, however, is that the role of the predictor/feature does not seem to be respected. In the example above the role of date is changed from predictor to id (and so it isn't used in prediction), but DALEX includes both date and the derived features in the feature importance, as it does not respect the role. E.g., at the bottom here:

[screenshot: variable importance plot including date alongside the derived date features]

Not a big deal in this example, but in my real work I tend to move highly correlated features into different roles so that they are retained in the data but not used for prediction. The variable importance charts from DALEX are therefore difficult to interpret, because the (known and excluded) highly correlated features dominate despite not being used in the model itself.

My current workaround is to exclude the features from feature_data manually, but it would be good if this were handled automatically.
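For reference, the workaround is just a manual column drop before building the explainer (a sketch, assuming the `feature_data` tibble from the example above):

```r
# Sketch of the manual workaround: drop the id-role column by hand
# before passing the data to DALEX (assumes feature_data as built above).
feature_data_pred <- feature_data %>%
  select(-date)
```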


topepo commented on June 18, 2024

Could you post a reprex on that? It looks like it's using the Chicago data and I can help troubleshoot that.


gofford commented on June 18, 2024

@topepo here you go. It's basically the example you posted above. Any thoughts?

library(tidymodels)
#> Registered S3 method overwritten by 'tune':
#>   method                   from   
#>   required_pkgs.model_spec parsnip
#> Warning: package 'parsnip' was built under R version 4.0.5
library(DALEX)
#> Warning: package 'DALEX' was built under R version 4.0.5
#> Welcome to DALEX (version: 2.2.1).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> Additional features will be available after installation of: ggpubr.
#> Use 'install_dependencies()' to get all suggested dependencies
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain
library(DALEXtra)
#> Anaconda not found on your computer. Conda related functionality such as create_env.R and condaenv and yml parameters from explain_scikitlearn will not be available

data("Chicago")

rec <- recipe(ridership ~ date + Clark_Lake + California, data = Chicago) %>% 
  step_date(date) %>% 
  step_holiday(date) %>% 
  update_role(date, new_role = "id") %>% 
  step_dummy(all_nominal_predictors())

lm_spec <- linear_reg() %>% set_engine("lm")

lm_wflow <- 
  workflow() %>% 
  add_model(lm_spec) %>% 
  add_recipe(rec) %>%
  fit(data = Chicago)

lm_fit <- lm_wflow %>% 
  pull_workflow_fit()

feature_data <- 
  lm_wflow %>% 
  pull_workflow_prepped_recipe() %>% 
  bake(new_data = Chicago)

vip_features <- 
  ## DALEXtra::explain_tidymodels takes a workflow object instead of extracted model. 
  ## use DALEX::explain instead
  explain(
    lm_fit, 
    data = feature_data %>% select(-ridership), 
    y = feature_data$ridership
  ) %>% 
  model_parts()
#> Preparation of a new explainer is initiated
#>   -> model label       :  model_fit  ( default )
#>   -> data              :  5698  rows  24  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  5698  values 
#>   -> predict function  :  yhat.model_fit  will be used ( default )
#>   -> predicted values  :  No value for predict function target column. ( default )
#>   -> model_info        :  package parsnip , ver. 0.1.6 , task regression ( default ) 
#>   -> predicted values  :  numerical, min =  -10.0413 , mean =  13.61933 , max =  21.75192  
#>   -> residual function :  difference between y and yhat ( default )
#>   -> residuals         :  numerical, min =  -16.28978 , mean =  6.959834e-13 , max =  10.9823  
#>  A new explainer has been created!

plot(vip_features)

Created on 2021-06-11 by the reprex package (v2.0.0)

Session info
sessionInfo()
#> R version 4.0.4 (2021-02-15)
#> Platform: x86_64-pc-linux-gnu (64-bit)
#> Running under: Ubuntu 20.04.1 LTS
#> 
#> Matrix products: default
#> BLAS/LAPACK: /usr/lib/x86_64-linux-gnu/openblas-pthread/libopenblasp-r0.3.8.so
#> 
#> locale:
#>  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
#>  [3] LC_TIME=en_US.UTF-8        LC_COLLATE=en_US.UTF-8    
#>  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=C             
#>  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
#>  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
#> [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#>  [1] DALEXtra_2.1.1     DALEX_2.2.1        yardstick_0.0.8    workflowsets_0.0.2
#>  [5] workflows_0.2.2    tune_0.1.5         tidyr_1.1.3        tibble_3.1.2      
#>  [9] rsample_0.1.0      recipes_0.1.16     purrr_0.3.4        parsnip_0.1.6     
#> [13] modeldata_0.1.0    infer_0.5.4        ggplot2_3.3.3      dplyr_1.0.6       
#> [17] dials_0.0.9        scales_1.1.1       broom_0.7.6        tidymodels_0.1.3  
#> 
#> loaded via a namespace (and not attached):
#>  [1] fs_1.5.0           lubridate_1.7.10   httr_1.4.2         DiceDesign_1.9    
#>  [5] tools_4.0.4        backports_1.2.1    utf8_1.2.1         R6_2.5.0          
#>  [9] rpart_4.1-15       DBI_1.1.1          colorspace_2.0-1   nnet_7.3-15       
#> [13] withr_2.4.2        tidyselect_1.1.1   curl_4.3.1         compiler_4.0.4    
#> [17] cli_2.5.0          xml2_1.3.2         labeling_0.4.2     rappdirs_0.3.3    
#> [21] stringr_1.4.0      digest_0.6.27      ingredients_2.2.0  rmarkdown_2.8     
#> [25] pkgconfig_2.0.3    htmltools_0.5.1.1  parallelly_1.25.0  styler_1.4.1      
#> [29] lhs_1.1.1          highr_0.9          rlang_0.4.11       rstudioapi_0.13   
#> [33] farver_2.1.0       generics_0.1.0     jsonlite_1.7.2     magrittr_2.0.1    
#> [37] Matrix_1.3-2       Rcpp_1.0.6         munsell_0.5.0      fansi_0.5.0       
#> [41] GPfit_1.0-8        reticulate_1.20    lifecycle_1.0.0    furrr_0.2.2       
#> [45] stringi_1.6.2      pROC_1.17.0.1      yaml_2.2.1         MASS_7.3-53       
#> [49] plyr_1.8.6         grid_4.0.4         parallel_4.0.4     listenv_0.8.0     
#> [53] crayon_1.4.1       lattice_0.20-41    splines_4.0.4      knitr_1.33        
#> [57] pillar_1.6.1       codetools_0.2-18   reprex_2.0.0       glue_1.4.2        
#> [61] evaluate_0.14      vctrs_0.3.8        png_0.1-7          foreach_1.5.1     
#> [65] gtable_0.3.0       future_1.21.0      assertthat_0.2.1   xfun_0.23         
#> [69] gower_0.2.2        mime_0.10          prodlim_2019.11.13 class_7.3-18      
#> [73] survival_3.2-7     timeDate_3043.102  iterators_1.0.13   hardhat_0.1.5     
#> [77] lava_1.6.9         globals_0.14.0     ellipsis_0.3.2     ipred_0.9-11


topepo commented on June 18, 2024

The issue is that bake(), by default, returns all of the existing columns. You can get rid of the other stuff using

feature_data <- 
    lm_wflow %>% 
    pull_workflow_prepped_recipe() %>% 
    bake(new_data = Chicago, all_predictors(), all_outcomes())

and date won't be in the analysis or plot.


srgillott commented on June 18, 2024

I am having the same issue regarding the inclusion of features that had their roles updated. I have tried the workaround @topepo suggested. I have also manually removed the features from the dataset, but then I get this message when calling model_parts:

Error in glubort():
! The following required columns are missing: 'date'.
Backtrace:

  1. DALEX::model_parts(explainer_rf, loss_function = loss_root_mean_square)
  2. ingredients:::feature_importance.explainer(...)
  3. ingredients:::feature_importance.default(...)
  4. base::replicate(B, loss_after_permutation())
  5. base::sapply(...)
  6. base::lapply(X = X, FUN = FUN, ...)
  7. ingredients FUN(X[[i]], ...)
  8. ingredients loss_after_permutation()
  9. DALEXtra:::yhat.workflow(x, sampled_data)
  10. workflows:::predict.workflow(X.model, newdata)
  11. workflows:::forge_predictors(new_data, workflow)
  12. hardhat:::forge.data.frame(new_data, blueprint = mold$blueprint)
  13. blueprint$forge$clean(...)
  14. hardhat:::forge_recipe_default_clean_extras(blueprint, new_data)
  15. hardhat:::map(blueprint$extra_role_ptypes, shrink, data = new_data)
  16. base::lapply(.x, .f, ...)
  17. hardhat FUN(X[[i]], ...)
  18. hardhat::validate_column_names(data, cols)
  19. hardhat:::glubort("The following required columns are missing: {missing_names}.")
    Run rlang::last_trace() to see the full context.


juliasilge commented on June 18, 2024

@srgillott Can you create a reprex (a minimal reproducible example) for this? We definitely want to support using DALEXtra with tidymodels, and would like to see how to help. The goal of a reprex is to make it easier for us to recreate your problem so that we can understand it and/or fix it.

If you've never heard of a reprex before, you may want to start with the tidyverse.org help page. You may already have reprex installed (it comes with the tidyverse package), but if not you can install it with:

install.packages("reprex")

Thanks! 🙌


srgillott commented on June 18, 2024

@juliasilge I couldn't get reprex to work, but hopefully this is enough information. I posted a new issue but saw it was already being discussed, so I closed it. Here is the example, which I should have included in the first instance.

When I try to run explain_tidymodels I have to include all the features that are in my dataset, otherwise I get an error. I have run update_role in the original recipe so the feature isn't included in the analysis. However, when running explain_tidymodels the feature is included in the VIP plot/results. I tried to remove the feature from the data before running explain_tidymodels, but this produces an error. I got around it by removing the feature from the very beginning. Ideally a feature whose role has been updated would not be included in the VIP.

Using the example from https://modeloriented.github.io/DALEXtra/reference/explain_tidymodels.html

library("DALEXtra")
library("tidymodels")

data <- titanic_imputed
data$survived <- as.factor(data$survived)
rec <- recipe(survived ~ ., data = data) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)


model_fitted <- wflow %>%
  fit(data = data)

explain_tidymodels(model_fitted, data = titanic_imputed, y = titanic_imputed$survived)

The above works but parch is still in the model results.

If I remove the offending feature then it returns a warning.

ex_data <- data %>%
  select(-parch)

explain_tidymodels(model_fitted, data = ex_data, y = titanic_imputed$survived)

I am having this issue when trying to plot a VIP from a random forest regression as well.


juliasilge commented on June 18, 2024

Take a look at this article for more on how to use reprex.

What you need to do is make sure you bake() with the specific variables you need (predictors + outcomes), like Max showed.

library("DALEXtra")
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
library("tidymodels")

data <- titanic_imputed
data$survived <- as.factor(data$survived)
rec <- recipe(survived ~ ., data = data) %>%
  update_role(parch, new_role = "test_role") %>%
  step_normalize(fare)

model <- decision_tree(tree_depth = 25) %>%
  set_engine("rpart") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(model)

model_fitted <- wflow %>%
  fit(data = data)

explain_tidymodels(
  model_fitted, 
  data = rec %>% prep() %>% bake(new_data = NULL, all_predictors(), all_outcomes()), 
  y = titanic_imputed$survived
)
#> Preparation of a new explainer is initiated
#>   -> model label       :  workflow  ( default )
#>   -> data              :  2207  rows  7  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  2207  values 
#>   -> predict function  :  yhat.workflow  will be used ( default )
#>   -> predicted values  :  No value for predict function target column. ( default )
#>   -> model_info        :  package tidymodels , ver. 0.1.4 , task classification ( default ) 
#>   -> predicted values  :  the predict_function returns an error when executed ( WARNING ) 
#>   -> residual function :  difference between y and yhat ( default )
#>   -> residuals         :  the residual_function returns an error when executed ( WARNING ) 
#>  A new explainer has been created!
#> Model label:  workflow 
#> Model class:  workflow 
#> Data head  :
#>   gender age class    embarked         fare sibsp survived
#> 1   male  42   3rd Southampton -0.297340822     0        0
#> 2   male  13   3rd Southampton  0.001347135     0        0

Created on 2022-03-03 by the reprex package (v2.0.1)


maksymiuks commented on June 18, 2024

Hi @topepo, hi @juliasilge

I apologize that it took me so long, but I think I'm on the right track. Here is a development branch, https://github.com/ModelOriented/DALEXtra/tree/65-feature-level-data-for-tidymodels-workflows, with what I believe to be exactly what you asked for. I've tested it against the example @topepo provided in the first post, which is:

library(tidymodels)
library(DALEXtra)

tidymodels_prefer()
data("Chicago")

rec <- recipe(ridership ~ date + Clark_Lake + California, data = Chicago) %>% 
  step_date(date) %>% 
  step_holiday(date) %>% 
  update_role(date, new_role = "id") %>% 
  step_dummy(all_nominal_predictors())

lm_spec <- linear_reg() %>% set_engine("lm")

lm_wflow <- 
  workflow() %>% 
  add_model(lm_spec) %>% 
  add_recipe(rec) %>%
  fit(data = Chicago)

lm_fit <-
  lm_wflow %>% 
  pull_workflow_fit() # <- parsnip model_fit object

feature_data <- 
  lm_wflow %>% 
  pull_workflow_prepped_recipe() %>% 
  bake(new_data = Chicago)

vip_features <- 
  explain_tidymodels(
    lm_fit, 
    data = feature_data %>% select(-ridership), 
    y = feature_data$ridership
  ) %>% 
  model_parts()

and it seems to work; the date column is also properly broken down. Could you please test it on your own and let me know whether there is room for improvement?


topepo commented on June 18, 2024

Thanks for doing this. It looks good to me. I've made a few changes below to avoid functions that are now deprecated. Also, it might be better to get the data out of the workflow mold, since that excludes columns with roles other than outcomes and predictors.

# remotes::install_github("ModelOriented/DALEXtra@65-feature-level-data-for-tidymodels-workflows")
library(tidymodels)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.0).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#> 
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain

tidymodels_prefer()
data("Chicago")

rec <- recipe(ridership ~ date + Clark_Lake + California, data = Chicago) %>% 
  step_date(date) %>% 
  step_holiday(date) %>% 
  update_role(date, new_role = "id") %>% 
  step_dummy(all_nominal_predictors())

lm_spec <- linear_reg() %>% set_engine("lm")

lm_wflow <- 
  workflow() %>% 
  add_model(lm_spec) %>% 
  add_recipe(rec) %>%
  fit(data = Chicago)

lm_fit <-
  lm_wflow %>% 
  extract_fit_parsnip() # <- parsnip model_fit object

feature_data <- 
  lm_wflow %>% 
  extract_mold() %>% 
  pluck("predictors") 

outcome_data <- 
  lm_wflow %>% 
  extract_mold() %>% 
  pluck("outcomes") %>% 
  pluck(1)    # <- it is a 1-column df, make it a vector

vip_features <- 
  explain_tidymodels(
    lm_fit, 
    data = feature_data, 
    y = outcome_data
  ) %>% 
  model_parts()
#> Preparation of a new explainer is initiated
#>   -> model label       :  model_fit  ( default )
#>   -> data              :  5698  rows  23  cols 
#>   -> data              :  tibble converted into a data.frame 
#>   -> target variable   :  5698  values 
#>   -> predict function  :  yhat.model_fit  will be used ( default )
#>   -> predicted values  :  No value for predict function target column. ( default )
#>   -> model_info        :  package parsnip , ver. 0.2.1.9000 , task regression ( default ) 
#>   -> predicted values  :  numerical, min =  -10.0413 , mean =  13.61933 , max =  21.75192  
#>   -> residual function :  difference between y and yhat ( default )
#>   -> residuals         :  numerical, min =  -16.28978 , mean =  4.159391e-12 , max =  10.9823  
#>  A new explainer has been created!

Created on 2022-03-31 by the reprex package (v2.0.1)


topepo commented on June 18, 2024

Ok, quick question about the approach ☝️ ... does this preclude us from getting predictor-level results?

In other words, we'd like to be able to do separate calls to get

  • summarizations at the predictor level (e.g. for predictors date, Clark_Lake, and California) as well as
  • feature-level results (e.g. date_dow_Mon, etc)

The code above looks the same as the original API for getting predictor-level results.

EDIT - ok, it looks like the difference is the object that we provide (the parsnip fit vs the workflow). Is that right?
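In code, my reading of the two entry points would be (a sketch, using the objects from the reprex above):

```r
# Sketch of the two entry points, as I understand them:

# 1) whole workflow + raw data -> predictor-level results
#    (date, Clark_Lake, California)
explain_tidymodels(lm_wflow,
                   data = Chicago %>% select(-ridership),
                   y = Chicago$ridership)

# 2) extracted parsnip fit + molded data -> feature-level results
#    (date_dow_Mon, date_month_Feb, ...)
explain_tidymodels(lm_fit,
                   data = feature_data,
                   y = outcome_data)
```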

