bonstats / tidytreatment
Tidy methods for Bayesian treatment effect models
License: Other
Thanks for continuing development! Do you know whether there is any way to use survival (time-to-event) outcomes with any of the BART engines that {tidytreatment} supports? We are currently evaluating the {grf}, {personalized}, and {StratifiedMedicine} packages for use on our clinical data, and support for survival endpoints in {tidytreatment} would be great for the comparison. Thanks for any hints you may have!
Hi!
Thanks for this package - it looks really handy. I know it's still in early stages, so perhaps that is the reason why I'm getting this error.
I'm trying to get the treatment effects from a classification bartMachine model (the error also occurs with a regression model).
I've created a data frame called dat and followed your vignette to calculate the propensity score:
(Note: this doesn't format nicely on GitHub, but if you copy and paste it into R you will get the data frame with the propensity score included.)
dat <- tibble::tribble(
~x1, ~x2, ~c1, ~c2, ~c3, ~c4, ~c5, ~c6, ~z, ~y, ~prop_score,
-0.619983928438955, -0.704383293245453, 2L, 1L, 5L, 2L, 4L, 4L, 1, 1L, 0.542908575444251,
-0.626634394521585, -1.29479197186313, 2L, 1L, 3L, 3L, 4L, 2L, 0, 1L, 0.472973074828456,
-0.211822807502002, 0.509778990213741, 1L, 1L, 3L, 2L, 1L, 2L, 0, 1L, 0.534461492086001,
-0.901743029895167, -1.23626678676882, 1L, 1L, 3L, 3L, 2L, 1L, 0, 1L, 0.469981801468016,
0.108343378149696, 0.500197708935189, 2L, 0L, 5L, 4L, 1L, 4L, 0, 1L, 0.46428639428549,
0.711996308261449, -0.601781192572445, 2L, 1L, 4L, 3L, 1L, 1L, 1, 1L, 0.527992738829824,
-1.38204158657383, -1.23812835972309, 1L, 0L, 5L, 2L, 2L, 1L, 1, 1L, 0.500471458430563,
0.721034867054809, 0.509705714053607, 2L, 1L, 4L, 2L, 2L, 1L, 0, 1L, 0.542532509472265,
0.707471997522852, 0.565327806243676, 2L, 1L, 3L, 3L, 1L, 1L, 1, 0L, 0.523616539070879,
0.70342326385171, 0.498471660056034, 1L, 1L, 1L, 3L, 1L, 4L, 1, 0L, 0.390898999903596,
0.709016830182189, 0.405428236852683, 1L, 1L, 3L, 3L, 1L, 4L, 1, 1L, 0.407917511977301,
0.719187348357036, 0.569524495448596, 1L, 0L, 5L, 4L, 1L, 3L, 0, 1L, 0.484834964845595,
0.719743462605666, 0.503915016980274, 1L, 1L, 5L, 4L, 2L, 3L, 1, 1L, 0.476156586501844,
0.708793635755663, -0.446566834486588, 1L, 1L, 5L, 4L, 1L, 4L, 0, 1L, 0.524972723679686,
0.72673036391176, 0.508745711467959, 1L, 1L, 1L, 3L, 1L, 4L, 0, 1L, 0.6294890530542,
0.717929677790862, -0.502116337087818, 2L, 1L, 3L, 4L, 1L, 3L, 0, 0L, 0.496876401125496,
0.723212745676559, 0.510802082592574, 1L, 1L, 1L, 3L, 1L, 3L, 1, 1L, 0.600318725870474,
0.121425965803968, 0.296052357185999, 1L, 1L, 2L, 4L, 1L, 4L, 0, 0L, 0.545459684037481,
0.722287187034774, 0.558292991175618, 1L, 1L, 2L, 3L, 1L, 3L, 0, 1L, 0.537113083141673,
0.716249243063561, -0.662962588231552, 2L, 1L, 3L, 1L, 2L, 3L, 1, 0L, 0.502049711413861,
0.717828654765672, -1.22899310053262, 2L, 1L, 3L, 1L, 1L, 2L, 0, 1L, 0.483631686895801,
0.706389311802865, 0.565629138649073, 1L, 1L, 3L, 4L, 1L, 2L, 1, 0L, 0.536503059865542,
-0.203417031998133, -0.770222158480375, 2L, 1L, 3L, 2L, 2L, 3L, 0, 1L, 0.484053590181911,
0.711436137741298, -1.2301795104596, 2L, 0L, 4L, 3L, 1L, 2L, 1, 0L, 0.524062945928105,
0.724140796755138, 0.505182463040808, 1L, 1L, 4L, 4L, 1L, 4L, 0, 1L, 0.502258493897457,
0.716268318989523, -0.512667654828064, 1L, 1L, 4L, 1L, 2L, 3L, 0, 0L, 0.589050987302026,
-0.210806733608954, -1.45495789372891, 2L, 1L, 4L, 3L, 2L, 1L, 1, 1L, 0.496656909399398,
0.719827976889018, 0.562633419916495, 1L, 1L, 1L, 2L, 1L, 4L, 0, 0L, 0.587927508222399,
0.715648299615523, 0.238069956615491, 2L, 1L, 2L, 3L, 2L, 4L, 1, 0L, 0.518915001679177,
0.735343066119825, 0.559735033474916, 1L, 1L, 2L, 4L, 1L, 3L, 1, 1L, 0.53282840822453,
0.723612845213415, -1.23197380567581, 1L, 1L, 5L, 4L, 1L, 2L, 0, 0L, 0.516384808056013,
0.111768030844326, -0.471829269894991, 2L, 1L, 3L, 2L, 2L, 1L, 0, 1L, 0.490294929493957,
-1.37244386883096, -1.23542008622328, 1L, 1L, 5L, 2L, 2L, 1L, 0, 1L, 0.519759047557842,
0.725708390097975, -1.4187950549084, 1L, 1L, 4L, 3L, 1L, 3L, 0, 1L, 0.506215512190674,
-0.203895625066123, -0.707663908453269, 2L, 1L, 4L, 3L, 1L, 2L, 1, 1L, 0.531299969705177,
0.720558921143588, -0.507896987328083, 2L, 1L, 5L, 3L, 2L, 1L, 1, 1L, 0.57566280672683,
0.71560806106389, 0.505113232185129, 2L, 1L, 2L, 4L, 1L, 4L, 0, 0L, 0.515829756245115,
0.112391830248744, 0.507387374908298, 1L, 1L, 3L, 4L, 1L, 1L, 1, 1L, 0.500827148021631,
0.730676501651419, 0.561127755107954, 2L, 0L, 2L, 4L, 1L, 4L, 1, 0L, 0.546148981837634,
0.103926715042258, 0.408184629561685, 2L, 1L, 1L, 4L, 1L, 3L, 0, 0L, 0.381803787913373,
0.707640791779242, 0.564613846564169, 2L, 1L, 4L, 4L, 1L, 3L, 1, 1L, 0.4466918335167,
-0.619983733513069, -1.34729960729906, 2L, 1L, 5L, 3L, 1L, 2L, 1, 1L, 0.507999212628759,
0.731751828434166, 0.570879206514913, 2L, 0L, 4L, 4L, 1L, 2L, 1, 1L, 0.49503576582592,
-0.906234718467424, -0.464784881069768, 2L, 1L, 4L, 3L, 2L, 2L, 0, 1L, 0.543727828195565,
0.724195955009647, 0.519071824873624, 2L, 1L, 4L, 2L, 2L, 4L, 0, 0L, 0.511732005235294,
-1.36261767399284, -0.597298931815627, 1L, 1L, 3L, 3L, 2L, 2L, 1, 1L, 0.48168744557023,
0.718465384803859, 0.561570383543805, 1L, 1L, 4L, 3L, 2L, 2L, 1, 1L, 0.444191762248526,
0.722609922246391, -1.24064443844372, 1L, 1L, 3L, 4L, 1L, 3L, 0, 1L, 0.453540383935195,
-0.626390809541331, 0.200706298006921, 2L, 1L, 3L, 2L, 2L, 3L, 0, 1L, 0.496724791477404,
0.72080545663214, 0.509088972581472, 2L, 1L, 2L, 3L, 1L, 4L, 0, 0L, 0.602933394597977,
0.725824895374323, 0.570465404411146, 1L, 1L, 4L, 4L, 1L, 3L, 1, 0L, 0.49442028949831,
-0.622466825029192, 0.457245654530033, 2L, 1L, 1L, 3L, 4L, 2L, 1, 0L, 0.362605952613939,
0.116647951498765, 0.299363240289861, 1L, 1L, 4L, 4L, 2L, 2L, 1, 0L, 0.549513210484241,
0.722665069475875, 0.506555312675104, 1L, 1L, 3L, 4L, 1L, 4L, 1, 0L, 0.481206097532479,
0.718578588309193, -0.600093475766179, 1L, 1L, 5L, 3L, 2L, 3L, 0, 1L, 0.537925251015914,
0.725199649049049, -0.436045678081683, 2L, 1L, 4L, 4L, 1L, 2L, 1, 1L, 0.53854007572209,
0.109924118084875, -0.515552543971096, 2L, 1L, 2L, 2L, 3L, 2L, 1, 0L, 0.540454680707154,
-0.629026458623944, -0.693398805838134, 1L, 1L, 5L, 3L, 1L, 3L, 0, 1L, 0.575609798586739,
0.720875462565464, -1.23199800070166, 2L, 1L, 1L, 4L, 1L, 3L, 1, 1L, 0.35477001964168,
0.720589423513613, 0.507715514920805, 2L, 1L, 2L, 4L, 1L, 4L, 0, 0L, 0.572874158281384,
-0.619909290626461, -1.23575360710521, 1L, 1L, 4L, 4L, 2L, 2L, 0, 1L, 0.516137793254588,
0.733118514241284, -0.401101963210555, 1L, 0L, 5L, 3L, 1L, 4L, 0, 1L, 0.52694365070584,
0.109681139709487, 0.253943880615392, 2L, 1L, 2L, 4L, 1L, 2L, 1, 1L, 0.523559866215833,
0.712225907332477, -0.656323340422975, 1L, 0L, 4L, 4L, 1L, 3L, 0, 1L, 0.545722198624145,
-0.61889936012751, -1.26243642978487, 2L, 1L, 5L, 3L, 1L, 1L, 0, 1L, 0.507690724444577,
0.713970410991397, 0.411081182749688, 2L, 1L, 4L, 3L, 1L, 2L, 0, 1L, 0.468726934695203,
0.111487497649392, 0.408112049324159, 2L, 1L, 4L, 3L, 2L, 2L, 1, 1L, 0.461339321470336,
-1.36675566300839, 0.558908094710503, 2L, 1L, 3L, 3L, 4L, 4L, 1, 0L, 0.482849394107573,
-0.61442067164354, 0.209336574406927, 1L, 1L, 3L, 2L, 1L, 3L, 0, 1L, 0.496461374043943,
0.722188924504669, 0.502200183110036, 1L, 1L, 5L, 3L, 1L, 3L, 0, 1L, 0.481887715622555,
0.725123520192068, 0.554979283191126, 1L, 1L, 5L, 3L, 1L, 3L, 0, 1L, 0.472118605257476,
-0.216783501151248, -0.764927126370616, 1L, 1L, 3L, 1L, 3L, 3L, 1, 0L, 0.476628185701138,
0.718328936832771, 0.565338885010626, 1L, 1L, 5L, 4L, 1L, 2L, 1, 0L, 0.435179355921364,
-0.21117063974132, 0.500494141597518, 2L, 1L, 5L, 4L, 1L, 2L, 0, 1L, 0.468645476521097,
-0.206724659814808, 0.574497114796512, 2L, 1L, 4L, 3L, 3L, 1L, 0, 1L, 0.501432198543294,
-1.36932037985284, 0.567997232201149, 2L, 1L, 1L, 3L, 4L, 1L, 1, 1L, 0.599436746922836,
0.103966399059827, 0.504659922136247, 1L, 0L, 4L, 4L, 1L, 3L, 1, 1L, 0.483589668724281,
0.727686345407945, 0.567986494849604, 1L, 0L, 3L, 3L, 1L, 3L, 0, 1L, 0.557420216204061,
0.716519648044397, 0.562063890488908, 1L, 1L, 5L, 4L, 1L, 2L, 0, 1L, 0.428216720236038,
0.726464317138348, -0.403663594344434, 1L, 1L, 4L, 4L, 1L, 4L, 0, 1L, 0.54682416705555,
0.724022157955659, 0.400649301941948, 2L, 1L, 3L, 2L, 2L, 1L, 0, 1L, 0.477434214942826,
-0.215075694683002, -0.514529209915271, 2L, 1L, 3L, 3L, 5L, 1L, 0, 0L, 0.529332088250662,
0.725958151563208, 0.28972809288225, 2L, 0L, 1L, 4L, 1L, 2L, 0, 1L, 0.44612359356437,
0.710721498938538, 0.506107609222488, 1L, 1L, 5L, 3L, 1L, 2L, 1, 0L, 0.510874236024449,
-1.37732485889494, 0.453978511184122, 2L, 1L, 3L, 2L, 2L, 1L, 1, 1L, 0.372521351932443,
0.715464511195532, 0.564518579400541, 1L, 1L, 4L, 3L, 1L, 3L, 1, 1L, 0.4466918335167,
0.728048482425467, -0.510827085283348, 2L, 0L, 4L, 3L, 1L, 2L, 0, 1L, 0.585580676804591,
0.717609720508345, 0.505631545659016, 1L, 1L, 3L, 3L, 1L, 2L, 1, 1L, 0.466635347329185,
-0.618705658028534, -1.27862760188277, 1L, 1L, 4L, 2L, 3L, 2L, 1, 1L, 0.512945806663173,
-0.625618829322605, 0.561396621884238, 2L, 1L, 4L, 3L, 1L, 2L, 1, 1L, 0.441711034943067,
0.708812343450583, 0.572875444566286, 1L, 0L, 1L, 2L, 1L, 3L, 1, 0L, 0.595732714429437,
0.72424286539536, 0.415448738622642, 1L, 1L, 4L, 4L, 1L, 3L, 0, 0L, 0.473707014647853,
0.707444790047399, 0.55620764090138, 2L, 1L, 3L, 4L, 1L, 2L, 1, 1L, 0.499117321588344,
0.728219716757736, 0.401105645225583, 1L, 1L, 3L, 4L, 1L, 3L, 0, 1L, 0.463864060774488,
-1.36383562682315, -1.53411062215581, 2L, 1L, 5L, 3L, 1L, 1L, 0, 1L, 0.510263169175049,
0.717364967554017, 0.501866846558604, 1L, 1L, 3L, 3L, 1L, 4L, 0, 0L, 0.429866002810877,
0.1149139984673, 0.448119671702872, 2L, 1L, 4L, 3L, 1L, 3L, 1, 1L, 0.457492818734489,
0.118315042985935, 0.456710336927965, 2L, 1L, 2L, 3L, 1L, 2L, 1, 1L, 0.423369972581007,
-0.912383449456444, -0.705311104570593, 2L, 1L, 1L, 2L, 3L, 2L, 1, 1L, 0.371203388587355,
0.70840841613744, -0.723831830767927, 2L, 1L, 4L, 3L, 3L, 3L, 1, 1L, 0.524183552730355
) %>% as.data.frame()
The code is as follows:
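library(dplyr)
library(bartMachine)
library(tidytreatment)

set_bart_machine_num_cores(2)

# Exploratory outcome model, used only for variable selection (excludes y and z)
var_select_bart <- bartMachine(
  X = select(dat, -y, -z),
  y = dat$y,
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)
# Keep the covariates selected by permutation importance with 5-fold CV
var_select <- var_selection_by_permute_cv(var_select_bart, k_folds = 5)

# Propensity score model: classification BART of the treatment z on the selected covariates
prop_bart <- bartMachine(
  X = select(dat, all_of(var_select$important_vars_cv)),
  y = as.factor(dat$z),
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)
dat$prop_score <- prop_bart$p_hat_train

# Prior inclusion weights: 1 for every predictor, 2 for the treatment z
prior_incl_prob <- setNames(rep(1, times = ncol(dat) - 1), colnames(dat)[colnames(dat) != "y"])
prior_incl_prob["z"] <- 2

# Outcome model with the treatment and propensity score as predictors
bartM <- build_bart_machine(
  X = select(dat, -y),
  y = dat$y,
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE,
  cov_prior_vec = prior_incl_prob
)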
But when I run the following code to extract the treatment effects, I get this error:
posterior_treat_eff <- treatment_effects(bartM, treatment = "z")
Error in fitted_with_counter_factual_draws(model, newdata, treatment, : is_01_integer_vector(newdata[, treatment]) | is.logical(newdata[, .... is not TRUE
Any reasons why? I'm honestly stumped: your example with simulate_su_hill_data() works just fine, and AFAIK my setup is identical, including using the integer class for categorical variables.
Session info below:
> sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin19.5.0 (64-bit)
Running under: macOS Catalina 10.15.7
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /usr/local/Cellar/openblas/0.3.10_1/lib/libopenblasp-r0.3.10.dylib
locale:
[1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] tidybayes_2.3.1 forcats_0.5.0 stringr_1.4.0 dplyr_1.0.2 purrr_0.3.4 readr_1.3.1 tidyr_1.1.2
[8] tibble_3.0.4 ggplot2_3.3.2 tidyverse_1.3.0 tidytreatment_0.1.0 bartMachine_1.2.5.1 missForest_1.4 itertools_0.1-3
[15] iterators_1.0.12 foreach_1.5.0 randomForest_4.6-14 bartMachineJARs_1.1 rJava_0.9-13
loaded via a namespace (and not attached):
[1] readxl_1.3.1 backports_1.1.10 plyr_1.8.6 igraph_1.2.5 splines_4.0.2 svUnit_1.0.3
[7] crosstalk_1.1.0.1 listenv_0.8.0 TH.data_1.0-10 rstantools_2.1.1 inline_0.3.15 digest_0.6.27
[13] htmltools_0.5.0 rsconnect_0.8.16 fansi_0.4.1 magrittr_1.5 globals_0.12.5 brms_2.13.5
[19] modelr_0.1.8 RcppParallel_5.0.2 matrixStats_0.56.0 MCMCpack_1.4-9 xts_0.12.1 sandwich_2.5-1
[25] prettyunits_1.1.1 colorspace_2.0-0 rvest_0.3.6 blob_1.2.1 ggdist_2.3.0 haven_2.3.1
[31] xfun_0.16 callr_3.5.1 crayon_1.3.4 jsonlite_1.7.1 survival_3.1-12 zoo_1.8-8
[37] glue_1.4.2 gtable_0.3.0 emmeans_1.5.1 MatrixModels_0.4-1 V8_3.2.0 distributional_0.2.1
[43] pkgbuild_1.1.0 rstan_2.21.2 datapasta_3.1.0 future.apply_1.6.0 abind_1.4-5 SparseM_1.78
[49] scales_1.1.1 mvtnorm_1.1-1 DBI_1.1.0 miniUI_0.1.1.1 Rcpp_1.0.5 xtable_1.8-4
[55] tmvnsim_1.0-2 stats4_4.0.2 StanHeaders_2.21.0-6 DT_0.15 httr_1.4.2 htmlwidgets_1.5.1
[61] threejs_0.3.3 arrayhelpers_1.1-0 lavaan_0.6-7 ellipsis_0.3.1 farver_2.0.3 pkgconfig_2.0.3
[67] loo_2.3.1 dbplyr_1.4.4 utf8_1.1.4 here_0.1 tidyselect_1.1.0 rlang_0.4.8
[73] reshape2_1.4.4 later_1.1.0.1 cellranger_1.1.0 munsell_0.5.0 tools_4.0.2 cli_2.1.0
[79] generics_0.1.0 broom_0.7.0 ggridges_0.5.2 evaluate_0.14 fastmap_1.0.1 yaml_2.2.1
[85] blavaan_0.3-10 mcmc_0.9-7 fs_1.5.0 processx_3.4.4 knitr_1.29 bcf_1.3
[91] future_1.18.0 nlme_3.1-148 mime_0.9 quantreg_5.70 xml2_1.3.2 nonnest2_0.5-5
[97] compiler_4.0.2 bayesplot_1.7.2 shinythemes_1.1.2 rstudioapi_0.13 curl_4.3 reprex_0.3.0
[103] pbivnorm_0.6.0 stringi_1.4.6 ps_1.4.0 Brobdingnag_1.2-6 lattice_0.20-41 Matrix_1.2-18
[109] markdown_1.1 shinyjs_2.0.0 vctrs_0.3.4 CompQuadForm_1.4.3 pillar_1.4.6 lifecycle_0.2.0
[115] bridgesampling_1.0-0 estimability_1.3 conquer_1.0.2 httpuv_1.5.4 R6_2.5.0 promises_1.1.1
[121] gridExtra_2.3 codetools_0.2-16 colourpicker_1.1.0 MASS_7.3-51.6 gtools_3.8.2 assertthat_0.2.1
[127] rprojroot_2.0.2 withr_2.3.0 shinystan_2.5.0 mnormt_2.0.2 multcomp_1.4-14 hms_0.5.3
[133] parallel_4.0.2 grid_4.0.2 coda_0.19-4 rmarkdown_2.5 shiny_1.5.0 lubridate_1.7.9
[139] base64enc_0.1-3 dygraphs_1.1.1.6
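A likely cause, offered as a hedged suggestion rather than a confirmed diagnosis: in the tribble above every categorical covariate is written with an L suffix (integer), but z is written as plain 1 and 0, which R stores as double. The assertion in the error accepts a 0/1 integer vector or a logical vector, so a double z could fail it even though its values are only 0 and 1. The sketch below uses a hypothetical helper, is_01_integer(), that mimics that kind of check; it is not tidytreatment's internal is_01_integer_vector().

```r
# Hypothetical helper mimicking the kind of assertion in the error message:
# TRUE only when x is an integer vector whose values are all 0 or 1.
is_01_integer <- function(x) {
  is.integer(x) && all(x %in% c(0L, 1L))
}

# tribble() stores values written as 1 / 0 (no L suffix) as double
z_double <- c(1, 0, 0, 1)
typeof(z_double)         # "double"
is_01_integer(z_double)  # FALSE: the values are 0/1 but the type is double

# Converting to integer keeps the values and satisfies the check
z_int <- as.integer(z_double)
is_01_integer(z_int)     # TRUE
```

If this is the cause, running dat$z <- as.integer(dat$z) before build_bart_machine() and refitting bartM should satisfy the check. This assumes treatment_effects() reads z from the training data stored inside the model.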
We need tests of the package functions so that breaking changes in future updates are caught. This is for our own sanity, but also so that others can contribute with more confidence.
Unit testing is needed for treatment_effects(...), avg_treatment_effects(...), tidy_ate(...) and tidy_att(...).
Currently this is just a list; a print method would be nice.
A bit of a stretch goal, but it would be really good to have for applied workflows.
The following filter occurs after the treatment effects are calculated on all observations (tidytreatment/R/treatment-effects-posterior.R, lines 125 to 133 in 657704e).
These could be implemented:
Hey, very nice package! Are you continuing development? Please do; it already looks super useful! Perhaps adding bartCause would be an option as well?
Could potentially use Sobol indices: https://arxiv.org/abs/2005.13622
There is an implementation of this method here: https://bitbucket.org/mpratola/openbt/src/master/
https://bitbucket.org/mpratola/openbt/src/master/src/sobol.cpp
This method does not require refitting the BART models.
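For illustration only, here is a minimal sketch of first-order Sobol indices computed with a generic Monte Carlo pick-freeze estimator (Saltelli style). This is a swapped-in technique, not the tree-based computation in the linked paper or the openbt sobol.cpp code, but it shares the no-refitting property because it only calls a prediction function. The names sobol_first_order(), unif_sampler() and toy_f() are hypothetical, and the sketch assumes independent inputs that can be sampled; with a BART fit, f would wrap the model's predict() method.

```r
# First-order Sobol indices via the pick-freeze estimator.
# f:       function taking an n x d matrix and returning n predictions
# sampler: function(n) returning an n x d matrix of independent inputs
sobol_first_order <- function(f, sampler, n = 1e4) {
  A <- sampler(n)
  B <- sampler(n)
  fA <- f(A)
  fB <- f(B)
  total_var <- var(c(fA, fB))
  d <- ncol(A)
  S <- numeric(d)
  for (i in seq_len(d)) {
    ABi <- A
    ABi[, i] <- B[, i]  # A with column i replaced by column i of B
    # Estimates Var(E[f | x_i]) / Var(f)
    S[i] <- mean(fB * (f(ABi) - fA)) / total_var
  }
  setNames(S, colnames(A))
}

# Toy check: f(x) = x1 + 2 * x2 with x1, x2 ~ U(0, 1) independent.
# Var(x1) = 1/12 and Var(2 * x2) = 4/12, so the true indices are S1 = 0.2, S2 = 0.8.
unif_sampler <- function(n) {
  m <- matrix(runif(2 * n), ncol = 2)
  colnames(m) <- c("x1", "x2")
  m
}
toy_f <- function(X) X[, 1] + 2 * X[, 2]

set.seed(1)
S_hat <- sobol_first_order(toy_f, unif_sampler, n = 1e5)
print(S_hat)
```

Applying the estimator separately to each posterior draw of a BART fit would give a posterior distribution for each index, at the cost of many prediction calls.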
Something to do with the bartMachine model and serialization.
Printing these objects frequently crashes RStudio, so a print method for wbart and similar classes would be good.
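A minimal sketch of an S3 print method, assuming wbart objects from the BART package are plain lists with class "wbart" and that BART does not already register print.wbart. It prints one line per component (name, class, dimensions) instead of every posterior draw. The object below is fake, built only to exercise the method; yhat.train and sigma are used as illustrative component names.

```r
# Summarise a wbart-like list instead of dumping every posterior draw
print.wbart <- function(x, ...) {
  cat("<wbart fit>\n")
  for (nm in names(x)) {
    el <- x[[nm]]
    # Matrices report their dimensions, everything else its length
    shape <- if (!is.null(dim(el))) paste(dim(el), collapse = " x ") else length(el)
    cat(sprintf("  $%-12s %-10s %s\n", nm, class(el)[1], shape))
  }
  invisible(x)  # print methods conventionally return their input invisibly
}

# Fake object, for illustration only
fake_fit <- structure(
  list(yhat.train = matrix(rnorm(20), nrow = 4), sigma = rnorm(10)),
  class = "wbart"
)
print(fake_fit)
```

The same pattern (summarise names and shapes, return invisible(x)) could be reused for other list-based classes such as the list returned by the package functions.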