
tidytreatment's People

Contributors

bonstats, mjskay

Forkers

mjskay denisashah

tidytreatment's Issues

Using survival (time-to-event) outcomes

Thanks for continuing development! Would you perhaps know whether there is any way to use survival (time-to-event) outcomes with any of the BART engines that {tidytreatment} supports? We are currently evaluating the {grf}, {personalized}, and {StratifiedMedicine} packages for use on our clinical data, and having support for survival endpoints in {tidytreatment} would be great for the comparison. Thanks for any hints you may have!
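
For context, a time-to-event BART model can already be fitted outside {tidytreatment} with the {BART} package's surv.bart(); the open question is whether tidytreatment's tidy posterior methods could sit on top of such a fit. A minimal sketch, assuming an illustrative covariate matrix, follow-up times, and event indicator (none of these objects come from this issue):

library(BART)

# Illustrative inputs: a covariate matrix including the treatment indicator,
# observed follow-up times, and a 0/1 event indicator (1 = event, 0 = censored).
surv_fit <- surv.bart(
  x.train = as.matrix(covariates),
  times   = followup_time,
  delta   = event_indicator
)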

is_01_integer_vector(newdata[, treatment]) is not TRUE

Hi!
Thanks for this package - it looks really handy. I know it's still in its early stages, so perhaps that is why I'm getting this error.

I'm trying to get the treatment effects from a classification bartMachine model (the same error also occurs with a regression model).

I've created a data frame called dat and have followed your vignette, calculating the propensity score:
(Note: this doesn't format nicely on GitHub, but if you copy and paste it into R, you will get the data frame with the propensity score included.)

library(tidyverse)     # provides %>% and the select() calls used below
library(bartMachine)
library(tidytreatment)

dat <- tibble::tribble(
                 ~x1,                ~x2, ~c1, ~c2, ~c3, ~c4, ~c5, ~c6, ~z, ~y,       ~prop_score,
  -0.619983928438955, -0.704383293245453,  2L,  1L,  5L,  2L,  4L,  4L,  1, 1L, 0.542908575444251,
  -0.626634394521585,  -1.29479197186313,  2L,  1L,  3L,  3L,  4L,  2L,  0, 1L, 0.472973074828456,
  -0.211822807502002,  0.509778990213741,  1L,  1L,  3L,  2L,  1L,  2L,  0, 1L, 0.534461492086001,
  -0.901743029895167,  -1.23626678676882,  1L,  1L,  3L,  3L,  2L,  1L,  0, 1L, 0.469981801468016,
   0.108343378149696,  0.500197708935189,  2L,  0L,  5L,  4L,  1L,  4L,  0, 1L,  0.46428639428549,
   0.711996308261449, -0.601781192572445,  2L,  1L,  4L,  3L,  1L,  1L,  1, 1L, 0.527992738829824,
   -1.38204158657383,  -1.23812835972309,  1L,  0L,  5L,  2L,  2L,  1L,  1, 1L, 0.500471458430563,
   0.721034867054809,  0.509705714053607,  2L,  1L,  4L,  2L,  2L,  1L,  0, 1L, 0.542532509472265,
   0.707471997522852,  0.565327806243676,  2L,  1L,  3L,  3L,  1L,  1L,  1, 0L, 0.523616539070879,
    0.70342326385171,  0.498471660056034,  1L,  1L,  1L,  3L,  1L,  4L,  1, 0L, 0.390898999903596,
   0.709016830182189,  0.405428236852683,  1L,  1L,  3L,  3L,  1L,  4L,  1, 1L, 0.407917511977301,
   0.719187348357036,  0.569524495448596,  1L,  0L,  5L,  4L,  1L,  3L,  0, 1L, 0.484834964845595,
   0.719743462605666,  0.503915016980274,  1L,  1L,  5L,  4L,  2L,  3L,  1, 1L, 0.476156586501844,
   0.708793635755663, -0.446566834486588,  1L,  1L,  5L,  4L,  1L,  4L,  0, 1L, 0.524972723679686,
    0.72673036391176,  0.508745711467959,  1L,  1L,  1L,  3L,  1L,  4L,  0, 1L,   0.6294890530542,
   0.717929677790862, -0.502116337087818,  2L,  1L,  3L,  4L,  1L,  3L,  0, 0L, 0.496876401125496,
   0.723212745676559,  0.510802082592574,  1L,  1L,  1L,  3L,  1L,  3L,  1, 1L, 0.600318725870474,
   0.121425965803968,  0.296052357185999,  1L,  1L,  2L,  4L,  1L,  4L,  0, 0L, 0.545459684037481,
   0.722287187034774,  0.558292991175618,  1L,  1L,  2L,  3L,  1L,  3L,  0, 1L, 0.537113083141673,
   0.716249243063561, -0.662962588231552,  2L,  1L,  3L,  1L,  2L,  3L,  1, 0L, 0.502049711413861,
   0.717828654765672,  -1.22899310053262,  2L,  1L,  3L,  1L,  1L,  2L,  0, 1L, 0.483631686895801,
   0.706389311802865,  0.565629138649073,  1L,  1L,  3L,  4L,  1L,  2L,  1, 0L, 0.536503059865542,
  -0.203417031998133, -0.770222158480375,  2L,  1L,  3L,  2L,  2L,  3L,  0, 1L, 0.484053590181911,
   0.711436137741298,   -1.2301795104596,  2L,  0L,  4L,  3L,  1L,  2L,  1, 0L, 0.524062945928105,
   0.724140796755138,  0.505182463040808,  1L,  1L,  4L,  4L,  1L,  4L,  0, 1L, 0.502258493897457,
   0.716268318989523, -0.512667654828064,  1L,  1L,  4L,  1L,  2L,  3L,  0, 0L, 0.589050987302026,
  -0.210806733608954,  -1.45495789372891,  2L,  1L,  4L,  3L,  2L,  1L,  1, 1L, 0.496656909399398,
   0.719827976889018,  0.562633419916495,  1L,  1L,  1L,  2L,  1L,  4L,  0, 0L, 0.587927508222399,
   0.715648299615523,  0.238069956615491,  2L,  1L,  2L,  3L,  2L,  4L,  1, 0L, 0.518915001679177,
   0.735343066119825,  0.559735033474916,  1L,  1L,  2L,  4L,  1L,  3L,  1, 1L,  0.53282840822453,
   0.723612845213415,  -1.23197380567581,  1L,  1L,  5L,  4L,  1L,  2L,  0, 0L, 0.516384808056013,
   0.111768030844326, -0.471829269894991,  2L,  1L,  3L,  2L,  2L,  1L,  0, 1L, 0.490294929493957,
   -1.37244386883096,  -1.23542008622328,  1L,  1L,  5L,  2L,  2L,  1L,  0, 1L, 0.519759047557842,
   0.725708390097975,   -1.4187950549084,  1L,  1L,  4L,  3L,  1L,  3L,  0, 1L, 0.506215512190674,
  -0.203895625066123, -0.707663908453269,  2L,  1L,  4L,  3L,  1L,  2L,  1, 1L, 0.531299969705177,
   0.720558921143588, -0.507896987328083,  2L,  1L,  5L,  3L,  2L,  1L,  1, 1L,  0.57566280672683,
    0.71560806106389,  0.505113232185129,  2L,  1L,  2L,  4L,  1L,  4L,  0, 0L, 0.515829756245115,
   0.112391830248744,  0.507387374908298,  1L,  1L,  3L,  4L,  1L,  1L,  1, 1L, 0.500827148021631,
   0.730676501651419,  0.561127755107954,  2L,  0L,  2L,  4L,  1L,  4L,  1, 0L, 0.546148981837634,
   0.103926715042258,  0.408184629561685,  2L,  1L,  1L,  4L,  1L,  3L,  0, 0L, 0.381803787913373,
   0.707640791779242,  0.564613846564169,  2L,  1L,  4L,  4L,  1L,  3L,  1, 1L,   0.4466918335167,
  -0.619983733513069,  -1.34729960729906,  2L,  1L,  5L,  3L,  1L,  2L,  1, 1L, 0.507999212628759,
   0.731751828434166,  0.570879206514913,  2L,  0L,  4L,  4L,  1L,  2L,  1, 1L,  0.49503576582592,
  -0.906234718467424, -0.464784881069768,  2L,  1L,  4L,  3L,  2L,  2L,  0, 1L, 0.543727828195565,
   0.724195955009647,  0.519071824873624,  2L,  1L,  4L,  2L,  2L,  4L,  0, 0L, 0.511732005235294,
   -1.36261767399284, -0.597298931815627,  1L,  1L,  3L,  3L,  2L,  2L,  1, 1L,  0.48168744557023,
   0.718465384803859,  0.561570383543805,  1L,  1L,  4L,  3L,  2L,  2L,  1, 1L, 0.444191762248526,
   0.722609922246391,  -1.24064443844372,  1L,  1L,  3L,  4L,  1L,  3L,  0, 1L, 0.453540383935195,
  -0.626390809541331,  0.200706298006921,  2L,  1L,  3L,  2L,  2L,  3L,  0, 1L, 0.496724791477404,
    0.72080545663214,  0.509088972581472,  2L,  1L,  2L,  3L,  1L,  4L,  0, 0L, 0.602933394597977,
   0.725824895374323,  0.570465404411146,  1L,  1L,  4L,  4L,  1L,  3L,  1, 0L,  0.49442028949831,
  -0.622466825029192,  0.457245654530033,  2L,  1L,  1L,  3L,  4L,  2L,  1, 0L, 0.362605952613939,
   0.116647951498765,  0.299363240289861,  1L,  1L,  4L,  4L,  2L,  2L,  1, 0L, 0.549513210484241,
   0.722665069475875,  0.506555312675104,  1L,  1L,  3L,  4L,  1L,  4L,  1, 0L, 0.481206097532479,
   0.718578588309193, -0.600093475766179,  1L,  1L,  5L,  3L,  2L,  3L,  0, 1L, 0.537925251015914,
   0.725199649049049, -0.436045678081683,  2L,  1L,  4L,  4L,  1L,  2L,  1, 1L,  0.53854007572209,
   0.109924118084875, -0.515552543971096,  2L,  1L,  2L,  2L,  3L,  2L,  1, 0L, 0.540454680707154,
  -0.629026458623944, -0.693398805838134,  1L,  1L,  5L,  3L,  1L,  3L,  0, 1L, 0.575609798586739,
   0.720875462565464,  -1.23199800070166,  2L,  1L,  1L,  4L,  1L,  3L,  1, 1L,  0.35477001964168,
   0.720589423513613,  0.507715514920805,  2L,  1L,  2L,  4L,  1L,  4L,  0, 0L, 0.572874158281384,
  -0.619909290626461,  -1.23575360710521,  1L,  1L,  4L,  4L,  2L,  2L,  0, 1L, 0.516137793254588,
   0.733118514241284, -0.401101963210555,  1L,  0L,  5L,  3L,  1L,  4L,  0, 1L,  0.52694365070584,
   0.109681139709487,  0.253943880615392,  2L,  1L,  2L,  4L,  1L,  2L,  1, 1L, 0.523559866215833,
   0.712225907332477, -0.656323340422975,  1L,  0L,  4L,  4L,  1L,  3L,  0, 1L, 0.545722198624145,
   -0.61889936012751,  -1.26243642978487,  2L,  1L,  5L,  3L,  1L,  1L,  0, 1L, 0.507690724444577,
   0.713970410991397,  0.411081182749688,  2L,  1L,  4L,  3L,  1L,  2L,  0, 1L, 0.468726934695203,
   0.111487497649392,  0.408112049324159,  2L,  1L,  4L,  3L,  2L,  2L,  1, 1L, 0.461339321470336,
   -1.36675566300839,  0.558908094710503,  2L,  1L,  3L,  3L,  4L,  4L,  1, 0L, 0.482849394107573,
   -0.61442067164354,  0.209336574406927,  1L,  1L,  3L,  2L,  1L,  3L,  0, 1L, 0.496461374043943,
   0.722188924504669,  0.502200183110036,  1L,  1L,  5L,  3L,  1L,  3L,  0, 1L, 0.481887715622555,
   0.725123520192068,  0.554979283191126,  1L,  1L,  5L,  3L,  1L,  3L,  0, 1L, 0.472118605257476,
  -0.216783501151248, -0.764927126370616,  1L,  1L,  3L,  1L,  3L,  3L,  1, 0L, 0.476628185701138,
   0.718328936832771,  0.565338885010626,  1L,  1L,  5L,  4L,  1L,  2L,  1, 0L, 0.435179355921364,
   -0.21117063974132,  0.500494141597518,  2L,  1L,  5L,  4L,  1L,  2L,  0, 1L, 0.468645476521097,
  -0.206724659814808,  0.574497114796512,  2L,  1L,  4L,  3L,  3L,  1L,  0, 1L, 0.501432198543294,
   -1.36932037985284,  0.567997232201149,  2L,  1L,  1L,  3L,  4L,  1L,  1, 1L, 0.599436746922836,
   0.103966399059827,  0.504659922136247,  1L,  0L,  4L,  4L,  1L,  3L,  1, 1L, 0.483589668724281,
   0.727686345407945,  0.567986494849604,  1L,  0L,  3L,  3L,  1L,  3L,  0, 1L, 0.557420216204061,
   0.716519648044397,  0.562063890488908,  1L,  1L,  5L,  4L,  1L,  2L,  0, 1L, 0.428216720236038,
   0.726464317138348, -0.403663594344434,  1L,  1L,  4L,  4L,  1L,  4L,  0, 1L,  0.54682416705555,
   0.724022157955659,  0.400649301941948,  2L,  1L,  3L,  2L,  2L,  1L,  0, 1L, 0.477434214942826,
  -0.215075694683002, -0.514529209915271,  2L,  1L,  3L,  3L,  5L,  1L,  0, 0L, 0.529332088250662,
   0.725958151563208,   0.28972809288225,  2L,  0L,  1L,  4L,  1L,  2L,  0, 1L,  0.44612359356437,
   0.710721498938538,  0.506107609222488,  1L,  1L,  5L,  3L,  1L,  2L,  1, 0L, 0.510874236024449,
   -1.37732485889494,  0.453978511184122,  2L,  1L,  3L,  2L,  2L,  1L,  1, 1L, 0.372521351932443,
   0.715464511195532,  0.564518579400541,  1L,  1L,  4L,  3L,  1L,  3L,  1, 1L,   0.4466918335167,
   0.728048482425467, -0.510827085283348,  2L,  0L,  4L,  3L,  1L,  2L,  0, 1L, 0.585580676804591,
   0.717609720508345,  0.505631545659016,  1L,  1L,  3L,  3L,  1L,  2L,  1, 1L, 0.466635347329185,
  -0.618705658028534,  -1.27862760188277,  1L,  1L,  4L,  2L,  3L,  2L,  1, 1L, 0.512945806663173,
  -0.625618829322605,  0.561396621884238,  2L,  1L,  4L,  3L,  1L,  2L,  1, 1L, 0.441711034943067,
   0.708812343450583,  0.572875444566286,  1L,  0L,  1L,  2L,  1L,  3L,  1, 0L, 0.595732714429437,
    0.72424286539536,  0.415448738622642,  1L,  1L,  4L,  4L,  1L,  3L,  0, 0L, 0.473707014647853,
   0.707444790047399,   0.55620764090138,  2L,  1L,  3L,  4L,  1L,  2L,  1, 1L, 0.499117321588344,
   0.728219716757736,  0.401105645225583,  1L,  1L,  3L,  4L,  1L,  3L,  0, 1L, 0.463864060774488,
   -1.36383562682315,  -1.53411062215581,  2L,  1L,  5L,  3L,  1L,  1L,  0, 1L, 0.510263169175049,
   0.717364967554017,  0.501866846558604,  1L,  1L,  3L,  3L,  1L,  4L,  0, 0L, 0.429866002810877,
     0.1149139984673,  0.448119671702872,  2L,  1L,  4L,  3L,  1L,  3L,  1, 1L, 0.457492818734489,
   0.118315042985935,  0.456710336927965,  2L,  1L,  2L,  3L,  1L,  2L,  1, 1L, 0.423369972581007,
  -0.912383449456444, -0.705311104570593,  2L,  1L,  1L,  2L,  3L,  2L,  1, 1L, 0.371203388587355,
    0.70840841613744, -0.723831830767927,  2L,  1L,  4L,  3L,  3L,  3L,  1, 1L, 0.524183552730355
  ) %>% as.data.frame()

The code is as follows:

set_bart_machine_num_cores(2)

# 1. Outcome model used only for variable selection
var_select_bart <- bartMachine(
  X = select(dat, -y, -z),
  y = select(dat, y)[[1]],
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)

# 2. Cross-validated variable selection by permutation
var_select <- bartMachine::var_selection_by_permute_cv(var_select_bart, k_folds = 5)

# 3. Propensity-score model for the treatment z, using the selected variables
prop_bart <- bartMachine(
  X = select(dat, var_select$important_vars_cv),
  y = as.factor(select(dat, z)[[1]]),
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE
)

# 4. Attach the fitted treatment probabilities as the propensity score
dat$prop_score <- prop_bart$p_hat_train

# 5. Prior inclusion probabilities: up-weight the treatment variable z
prior_incl_prob <- setNames(rep(1, times = ncol(dat) - 1), colnames(dat)[colnames(dat) != 'y'])
prior_incl_prob['z'] <- 2

# 6. Final outcome model including the treatment and the propensity score
bartM <- build_bart_machine(
  X = select(dat, -y),
  y = select(dat, y)[[1]],
  num_burn_in = 2000,
  num_iterations_after_burn_in = 5000,
  serialize = TRUE,
  verbose = FALSE,
  cov_prior_vec = prior_incl_prob
)

But when I run the following code to extract the treatment effects, I get this error:

posterior_treat_eff <- treatment_effects(bartM, treatment = "z")
Error in fitted_with_counter_factual_draws(model, newdata, treatment, : is_01_integer_vector(newdata[, treatment]) | is.logical(newdata[, .... is not TRUE

Any idea why? I'm honestly stumped - your example with simulate_su_hill_data() works just fine. As far as I can tell, my setup is identical, including using the integer class for the categorical variables.
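
One thing I have not ruled out: the error text suggests the treatment column must be a 0/1 integer (or logical) vector, and in the tribble above the z column is entered without the L suffix, so it may be stored as a double rather than an integer. A quick check along these lines (just a guess on my part):

class(dat$z)                # "numeric" rather than "integer" would trip the check
dat$z <- as.integer(dat$z)  # convert to a 0/1 integer before refitting and re-running treatment_effects()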

Session info below:

> sessionInfo()
R version 4.0.2 (2020-06-22)
Platform: x86_64-apple-darwin19.5.0 (64-bit)
Running under: macOS Catalina 10.15.7

Matrix products: default
BLAS:   /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /usr/local/Cellar/openblas/0.3.10_1/lib/libopenblasp-r0.3.10.dylib

locale:
[1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] tidybayes_2.3.1     forcats_0.5.0       stringr_1.4.0       dplyr_1.0.2         purrr_0.3.4         readr_1.3.1         tidyr_1.1.2        
 [8] tibble_3.0.4        ggplot2_3.3.2       tidyverse_1.3.0     tidytreatment_0.1.0 bartMachine_1.2.5.1 missForest_1.4      itertools_0.1-3    
[15] iterators_1.0.12    foreach_1.5.0       randomForest_4.6-14 bartMachineJARs_1.1 rJava_0.9-13       

loaded via a namespace (and not attached):
  [1] readxl_1.3.1         backports_1.1.10     plyr_1.8.6           igraph_1.2.5         splines_4.0.2        svUnit_1.0.3        
  [7] crosstalk_1.1.0.1    listenv_0.8.0        TH.data_1.0-10       rstantools_2.1.1     inline_0.3.15        digest_0.6.27       
 [13] htmltools_0.5.0      rsconnect_0.8.16     fansi_0.4.1          magrittr_1.5         globals_0.12.5       brms_2.13.5         
 [19] modelr_0.1.8         RcppParallel_5.0.2   matrixStats_0.56.0   MCMCpack_1.4-9       xts_0.12.1           sandwich_2.5-1      
 [25] prettyunits_1.1.1    colorspace_2.0-0     rvest_0.3.6          blob_1.2.1           ggdist_2.3.0         haven_2.3.1         
 [31] xfun_0.16            callr_3.5.1          crayon_1.3.4         jsonlite_1.7.1       survival_3.1-12      zoo_1.8-8           
 [37] glue_1.4.2           gtable_0.3.0         emmeans_1.5.1        MatrixModels_0.4-1   V8_3.2.0             distributional_0.2.1
 [43] pkgbuild_1.1.0       rstan_2.21.2         datapasta_3.1.0      future.apply_1.6.0   abind_1.4-5          SparseM_1.78        
 [49] scales_1.1.1         mvtnorm_1.1-1        DBI_1.1.0            miniUI_0.1.1.1       Rcpp_1.0.5           xtable_1.8-4        
 [55] tmvnsim_1.0-2        stats4_4.0.2         StanHeaders_2.21.0-6 DT_0.15              httr_1.4.2           htmlwidgets_1.5.1   
 [61] threejs_0.3.3        arrayhelpers_1.1-0   lavaan_0.6-7         ellipsis_0.3.1       farver_2.0.3         pkgconfig_2.0.3     
 [67] loo_2.3.1            dbplyr_1.4.4         utf8_1.1.4           here_0.1             tidyselect_1.1.0     rlang_0.4.8         
 [73] reshape2_1.4.4       later_1.1.0.1        cellranger_1.1.0     munsell_0.5.0        tools_4.0.2          cli_2.1.0           
 [79] generics_0.1.0       broom_0.7.0          ggridges_0.5.2       evaluate_0.14        fastmap_1.0.1        yaml_2.2.1          
 [85] blavaan_0.3-10       mcmc_0.9-7           fs_1.5.0             processx_3.4.4       knitr_1.29           bcf_1.3             
 [91] future_1.18.0        nlme_3.1-148         mime_0.9             quantreg_5.70        xml2_1.3.2           nonnest2_0.5-5      
 [97] compiler_4.0.2       bayesplot_1.7.2      shinythemes_1.1.2    rstudioapi_0.13      curl_4.3             reprex_0.3.0        
[103] pbivnorm_0.6.0       stringi_1.4.6        ps_1.4.0             Brobdingnag_1.2-6    lattice_0.20-41      Matrix_1.2-18       
[109] markdown_1.1         shinyjs_2.0.0        vctrs_0.3.4          CompQuadForm_1.4.3   pillar_1.4.6         lifecycle_0.2.0     
[115] bridgesampling_1.0-0 estimability_1.3     conquer_1.0.2        httpuv_1.5.4         R6_2.5.0             promises_1.1.1      
[121] gridExtra_2.3        codetools_0.2-16     colourpicker_1.1.0   MASS_7.3-51.6        gtools_3.8.2         assertthat_0.2.1    
[127] rprojroot_2.0.2      withr_2.3.0          shinystan_2.5.0      mnormt_2.0.2         multcomp_1.4-14      hms_0.5.3           
[133] parallel_4.0.2       grid_4.0.2           coda_0.19-4          rmarkdown_2.5        shiny_1.5.0          lubridate_1.7.9     
[139] base64enc_0.1-3      dygraphs_1.1.1.6    

Add unit testing

We need tests covering the package's functions so that future breaking changes or updates are caught. This is for our own sanity, but also so that others can contribute with more confidence.

Unit testing is needed for the following (a rough sketch of one such test is given after the list):

  • treatment_effects(...)
  • avg_treatment_effects(...), tidy_ate(...), tidy_att(...)
  • common support functions
  • covariate importance
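
As a starting point, a rough testthat sketch for treatment_effects() might look like the following; the simulated-data call and expected output columns are assumptions based on the vignette, not a fixed specification:

library(testthat)
library(tidytreatment)
library(bartMachine)

test_that("treatment_effects() returns tidy posterior draws", {
  # Small simulated data set; arguments follow the vignette and may need adjusting
  sim <- simulate_su_hill_data(n = 50)
  dat <- sim$data
  fit <- bartMachine(
    X = dat[, setdiff(names(dat), "y")],
    y = dat$y,
    num_burn_in = 100,
    num_iterations_after_burn_in = 200,
    verbose = FALSE
  )
  te <- treatment_effects(fit, treatment = "z")
  expect_s3_class(te, "data.frame")
  expect_gt(nrow(te), 0)
})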

Great package!

Hey, very nice package! Are you continuing development? Please do - it looks super useful already! Perhaps adding support for {bartCause} would be an option as well?
