stan4bart's People

Contributors

bgoodri, vdorie

stan4bart's Issues

Timeline for binary model?

Hi there, really interesting work, and I am definitely excited to try it out. Is there a timeline for implementing the binary model?

Saving Model to a *.RData object

Hi,

How do I save the model to an *.RData object so that I can make predictions after reloading it?
In the dbarts package you have to touch the pointer via invisible(model$fit$state) to store it as an internal R object. Is there something similar in your package?
sampler.bart and range.bart are used for prediction, as mentioned, but invisible(model$sampler.bart) does not have the desired effect.

Best, and thanks again
Louis

Additional families

Hello, I was wondering whether you expect to eventually support additional families, such as Poisson, Gamma, and negative binomial, in this package or in dbarts, or whether you see a path for users to do so, such as a dbarts interface that can be used in Stan code.

Predict function not working

Here is my training model:

fit40 <- stan4bart(
  formula = sales ~
    hdummys + tv_ads + dig_ads + prt_ads +  # linear component ("fixef")
    (1 | dmaseqid) +                        # multilevel ("ranef"); dmaseqid is a factor variable
    bart(. - region - coupons - hdummys - tv_ads - dig_ads - prt_ads), # use BART for the other variables
  verbose = -1,  # suppress ALL output
  data = train,  # 8400 rows
  # low numbers for illustration; only 1 chain
  chains = 1, iter = 100,
  bart_args = list(n.trees = 5, keepTrees = TRUE))

This runs without a problem. Then I use the predict function as follows:

predict(fit40, newdata = test,  # test data has 2520 rows
        type = c("ev", "ppd", "indiv.fixef", "indiv.ranef", "indiv.bart"),
        combine_chains = FALSE,  # only 1 chain, no need to combine
        sample_new_levels = TRUE)

I get the following warning and error:

Warning message in validateXTest(newdata, attr(data@x, "term.labels"), ncol(data@x), :
“column names of 'test' does not equal that of 'x': 'dmaseqid.1, dmaseqid.2, dmaseqid.3, dmaseqid.4, dmaseqid.5, dmaseqid.6, dmaseqid.7, dmaseqid.8, dmaseqid.9, dmaseqid.10, dmaseqid.11, dmaseqid.12, dmaseqid.13, dmaseqid.14, dmaseqid.15, dmaseqid.16, dmaseqid.17, dmaseqid.18, dmaseqid.19, dmaseqid.20, dmaseqid.21, dmaseqid.22, dmaseqid.23, dmaseqid.24, dmaseqid.25, dmaseqid.26, dmaseqid.27, dmaseqid.28, dmaseqid.29, dmaseqid.30, dmaseqid.31, dmaseqid.32, dmaseqid.33, dmaseqid.34, dmaseqid.35, dmaseqid.36, dmaseqid.37, dmaseqid.38, dmaseqid.39, dmaseqid.40, dmaseqid.41, dmaseqid.42, dmaseqid.43, dmaseqid.44, dmaseqid.45, dmaseqid.46, dmaseqid.47, dmaseqid.48, dmaseqid.49, dmaseqid.50, dmaseqid.51, dmaseqid.52, dmaseqid.53, dmaseqid.54, dmaseqid.55, dmaseqid.56, dmaseqid.57, dmaseqid.58, dmaseqid.59, dmaseqid.60, dmaseqid.61, dmaseqid.62, dmaseqid.63, dmaseqid.64, dmaseqid.65, dmaseqid.66, dmaseqid.67, dmaseqid.68, dmaseqid.69, dmaseqid.70, dmaseqid.71, dmaseqid.72, dmaseqid.73, dmaseqid.74, dmaseqid.75, dmaseqid.76, dmaseqid.77, dmaseqid.78, dmaseqid.79, dmaseqid.80, dmaseqid.81, dmaseqid.82, dmaseqid.83, dmaseqid.84, dmaseqid.85, dmaseqid.86, dmaseqid.87, dmaseqid.88, dmaseqid.89, dmaseqid.90, dmaseqid.91, dmaseqid.92, dmaseqid.93, dmaseqid.94, dmaseqid.95, dmaseqid.96, dmaseqid.97, dmaseqid.98, dmaseqid.99, dmaseqid.100, dmaseqid.101, dmaseqid.102, dmaseqid.103, dmaseqid.104, dmaseqid.105, dmaseqid.106, dmaseqid.107, dmaseqid.108, dmaseqid.109, dmaseqid.110, dmaseqid.111, dmaseqid.112, dmaseqid.113, dmaseqid.114, dmaseqid.115, dmaseqid.116, dmaseqid.117, dmaseqid.118, dmaseqid.119, dmaseqid.120, dmaseqid.121, dmaseqid.122, dmaseqid.123, dmaseqid.124, dmaseqid.125, dmaseqid.126, dmaseqid.127, dmaseqid.128, dmaseqid.129, dmaseqid.130, dmaseqid.131, dmaseqid.132, dmaseqid.133, dmaseqid.134, dmaseqid.135, dmaseqid.136, dmaseqid.137, dmaseqid.138, dmaseqid.139, dmaseqid.140, dmaseqid.141, dmaseqid.142, dmaseqid.143, dmaseqid.144, dmaseqid.145, dmaseqid.146, 
dmaseqid.147, dmaseqid.148, dmaseqid.149, dmaseqid.150, dmaseqid.151, dmaseqid.152, dmaseqid.153, dmaseqid.154, dmaseqid.155, dmaseqid.156, dmaseqid.157, dmaseqid.158, dmaseqid.159, dmaseqid.160, dmaseqid.161, dmaseqid.162, dmaseqid.163, dmaseqid.164, dmaseqid.165, dmaseqid.166, dmaseqid.167, dmaseqid.168, dmaseqid.169, dmaseqid.170, dmaseqid.171, dmaseqid.172, dmaseqid.173, dmaseqid.174, dmaseqid.175, dmaseqid.176, dmaseqid.177, dmaseqid.178, dmaseqid.179, dmaseqid.180, dmaseqid.181, dmaseqid.182, dmaseqid.183, dmaseqid.184, dmaseqid.185, dmaseqid.186, dmaseqid.187, dmaseqid.188, dmaseqid.189, dmaseqid.190, dmaseqid.191, dmaseqid.192, dmaseqid.193, dmaseqid.194, dmaseqid.195, dmaseqid.196, dmaseqid.197, dmaseqid.198, dmaseqid.199, dmaseqid.200, dmaseqid.201, dmaseqid.202, dmaseqid.203, dmaseqid.204, dmaseqid.205, dmaseqid.206, dmaseqid.207, dmaseqid.208, dmaseqid.209, dmaseqid.210, hdummys, tv_ads, dig_ads, prt_ads, region, coupons'; match will be made by position”

Error in dimnames(indiv.bart) <- list(observation = NULL, sample = NULL, : length of 'dimnames' [3] must match that of 'dims' [2]
Traceback:

1. predict(fit40, newdata = test, type = c("ev", "ppd", "indiv.fixef", 
 .     "indiv.ranef", "indiv.bart"), combine_chains = FALSE, sample_new_levels = TRUE)
2. predict.stan4bartFit(fit40, newdata = test, type = c("ev", "ppd", 
 .     "indiv.fixef", "indiv.ranef", "indiv.bart"), combine_chains = FALSE, 
 .     sample_new_levels = TRUE)

Is there a solution? My train and test data frames have exactly the same columns; only the number of rows differs. I read here that using a single chain avoids the error about the number of dimensions associated with the bart component.

Var list changing size

I am trying to fit a stan4bart model after the latest update but am getting the following error:

fit <- stan4bart(dx ~ bart(. - dx - history) + history, dat_new[dat_new$train, ],
                 cores = 1, seed = 0,
                 verbose = 1,
                 bart_args = list(keepTrees = TRUE),
                 test = dat_new[!dat_new$train, ])
'varlist' has changed (from nvar=9) to new 10 after EncodeVars() -- should no longer happen!

configure does not pick compiler from R settings

UPD: The only real issue is the choice of compiler. configure seems to ignore both the R settings and the environment, but when I explicitly force CC= and CXX= by passing them to configure, everything works and the tests pass.

  1. It seems that configure does not use the compiler from the R settings but picks the OS default. This results in wrong settings (this is Rosetta: the physical CPU is Intel, but build and host are ppc):
** using staged installation
checking for g++... g++
checking whether the C++ compiler works... yes
checking for C++ compiler default output file name... a.out
checking for suffix of executables... 
checking whether we are cross compiling... no
checking for suffix of object files... o
checking whether the compiler supports GNU C++... yes
checking whether g++ accepts -g... yes
checking for g++ option to enable C++11 features... none needed
checking how to run the C++ preprocessor... g++ -E
checking whether the compiler supports GNU C++... (cached) yes
checking whether g++ accepts -g... (cached) yes
checking for g++ option to enable C++11 features... (cached) none needed
checking build system type... x86_64-apple-darwin10.8.0
checking host system type... x86_64-apple-darwin10.8.0
checking for gcc... gcc
checking whether the compiler supports GNU C... yes
checking whether gcc accepts -g... yes
checking for gcc option to enable C11 features... unsupported
checking for gcc option to enable C99 features... -std=gnu99
checking for x86 cpuid  output... unknown
checking for x86-AVX xgetbv  output... unknown
checking for x86 cpuid 0x00000000 output... d:756e6547:6c65746e:49656e69
checking for x86 cpuid 0x00000001 output... 106a5:1060800:ffba220b:1f8bfbff
checking whether SSE2 is supported by the processor... yes
checking whether SSE2 is supported by the processor and OS... yes
checking whether C++ compiler accepts -msse2... yes
checking for stdio.h... yes
checking for stdlib.h... yes
checking for string.h... yes
checking for inttypes.h... yes
checking for stdint.h... yes
checking for strings.h... yes
checking for sys/stat.h... yes
checking for sys/types.h... yes
checking for unistd.h... yes
checking for malloc.h... no
checking size of size_t... 8
checking alignment of void*... 8
checking for size_t... yes
checking for working alloca.h... yes
checking for alloca... yes
checking for working posix_memalign... yes
checking for ffs... yes
configure: creating ./config.status
config.status: creating src/Makevars
config.status: creating src/config.h
  2. The package pulls in Intel intrinsics headers on PowerPC:
/opt/local/bin/g++-mp-12 -std=gnu++17 -I"/opt/local/Library/Frameworks/R.framework/Resources/include" -DNDEBUG -I"include" -I"include/sundials" -I"../inst/include" -DBOOST_DISABLE_ASSERTS -DEIGEN_NO_DEBUG -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/BH/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/Rcpp/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppParallel/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/dbarts/include' -isystem/opt/local/include/LegacySupport -I/opt/local/include  -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppParallel/include' -D_REENTRANT -DSTAN_THREADS   -fPIC  -pipe -Os -arch ppc  -c bart_util.cpp -o bart_util.o
/opt/local/bin/g++-mp-12 -std=gnu++17 -I"/opt/local/Library/Frameworks/R.framework/Resources/include" -DNDEBUG -I"include" -I"include/sundials" -I"../inst/include" -DBOOST_DISABLE_ASSERTS -DEIGEN_NO_DEBUG -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/BH/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/Rcpp/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppParallel/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/dbarts/include' -isystem/opt/local/include/LegacySupport -I/opt/local/include  -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppParallel/include' -D_REENTRANT -DSTAN_THREADS   -fPIC  -pipe -Os -arch ppc  -c init.cpp -o init.o
/opt/local/bin/gcc-mp-12 -I"/opt/local/Library/Frameworks/R.framework/Resources/include" -DNDEBUG -I"include" -I"include/sundials" -I"../inst/include" -DBOOST_DISABLE_ASSERTS -DEIGEN_NO_DEBUG -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/BH/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/Rcpp/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppEigen/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/RcppParallel/include' -I'/opt/local/Library/Frameworks/R.framework/Versions/4.2/Resources/library/dbarts/include' -isystem/opt/local/include/LegacySupport -I/opt/local/include   -fPIC  -pipe -Os -arch ppc  -c misc_adaptiveRadixTree.c -o misc_adaptiveRadixTree.o
In file included from include/misc/intrinsic.h:11,
                 from misc_adaptiveRadixTree.c:51:
/opt/local/lib/gcc12/gcc/powerpc-apple-darwin10/12.2.0/include/emmintrin.h:56:2: error: #error "Please read comment above.  Use -DNO_WARN_X86_INTRINSICS to disable this error."
   56 | #error "Please read comment above.  Use -DNO_WARN_X86_INTRINSICS to disable this error."
      |  ^~~~~
  3. The source code seems to unconditionally use 64-bit settings:
In file included from /opt/local/Library/Frameworks/R.framework/Resources/include/R.h:70,
                 from include/ext/R.h:24,
                 from include/ext/io.h:4,
                 from misc_adaptiveRadixTree.c:53:
/opt/local/Library/Frameworks/R.framework/Resources/include/Rconfig.h: At top level:
/opt/local/Library/Frameworks/R.framework/Resources/include/Rconfig.h:26: warning: "SIZEOF_SIZE_T" redefined
   26 | #define SIZEOF_SIZE_T 4
      | 
In file included from misc_adaptiveRadixTree.c:38:
config.h:74: note: this is the location of the previous definition
   74 | #define SIZEOF_SIZE_T 8
      |
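A minimal sketch of the workaround from the UPD note, assuming that passing CC= and CXX= directly to configure is what fixes the build: install.packages() forwards its configure.args argument to the package's configure script. The compiler paths are the MacPorts ones that appear in the build log above; for any other setup they are assumptions to adjust. The install call is guarded by interactive() so the snippet can be sourced without starting a download.

```r
# Compilers taken from the MacPorts build log above; adjust for your toolchain.
cc  <- "/opt/local/bin/gcc-mp-12"
cxx <- "/opt/local/bin/g++-mp-12"

# Passed as arguments to ./configure (not as environment variables),
# matching "adding these to configure" in the UPD note.
configure_args <- sprintf("CC=%s CXX=%s", cc, cxx)

if (interactive()) {
  install.packages("stan4bart", type = "source",
                   configure.args = configure_args)
}
```

This only addresses the compiler choice; the intrinsics and SIZEOF_SIZE_T symptoms listed above are what the reporter saw before forcing the compilers.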

Prediction function is not working

I trained a model and want to predict on a data frame and on a data frame derived from a raster. Neither works.

model.stan <- stan4bart(
  formula = pa ~ bart(. - road_distance - urban_distance - water_distance) +
    (1 | road_distance) + (1 | urban_distance) + (1 | water_distance),
  verbose = 0,
  data = PA_train,
  # test = PA_test,
  weights = my_weights,
  chains = 1,
  iter = 10,
  bart_args = list(n.trees = 10, keepTrees = TRUE)
)
model.stan.predict <- predict(model.stan,
                              newdata = as.data.frame(EU_environment_low, xy = TRUE),
                              type = "ev")

I get the following error:

Error in dimnames(indiv.bart) <- list(observation = NULL, sample = NULL,  : 
  length of 'dimnames' [3] must match that of 'dims' [2]

When I use type = "indiv.ranef", it works, but that is not what I need.

Hopefully there is a solution.

Best
Loubert

Setting Custom Prior

I see that bart_args allows us to specify dbarts control elements. How would one go about changing the tree prior (power, base)?

Multilevel Bayesian Causal Forest

Hey,
I am interested in estimating a multilevel Bayesian Causal Forest as used by Yeager et al. (2019). While the co-authors Jared Murray and Carlos Carvalho have discussed the extension of BCF to multilevel models in several talks and in the supplemental material to the paper, they do not state which software they used to estimate their model. Does this package make it possible to estimate such a model?
I really appreciate the work on this package, and thank you in advance for your help.

Feature Request: predict specific chains/iterations

Hi @vdorie,

While working with the package, I noticed that the predict function generates predictions for my new data for every chain and every post-burn-in iteration. Wouldn't it be nicer, and much faster, to be able to select a specific chain (or all chains) directly and, for example, only the last 200 iterations (or some other range) to be predicted?

Something like this:

model: chains = 6, iterations-burnin = 2500

predict.stan4bartFit(
        object = model,
        newdata = newdata_df,
        type = "ev",
        combine_chains = FALSE,
        pred_chains = c(2, 4, 6),
        pred_iter = 2300:2500
)

As I am quite busy at the moment, I have not been able to look at your code, and I don't know how easily this could be realized, but from my "outside view" it might not be that difficult. Might it be "just" a matter of controlling how many iterations/chains are extracted from the model object and processed in the function?

All the best
Louis
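Until something like the proposed pred_chains/pred_iter arguments exists (they are part of this request, not of the package), one stopgap is to trim the array returned by predict(..., combine_chains = FALSE) after the fact. This saves no computation, since every draw is still predicted; it only reduces what is kept. The sketch assumes the result is laid out as observation x sample x chain, as the dimnames in the error messages above suggest. subset_draws() is a hypothetical helper, and a random toy array stands in for real predictions.

```r
# Hypothetical helper: keep only the requested chains and post-burn-in
# iterations from an array assumed to be laid out as observation x sample x chain.
subset_draws <- function(pred, chains = NULL, iters = NULL) {
  stopifnot(length(dim(pred)) == 3L)
  if (is.null(chains)) chains <- seq_len(dim(pred)[3L])
  if (is.null(iters))  iters  <- seq_len(dim(pred)[2L])
  pred[, iters, chains, drop = FALSE]
}

# Toy stand-in for a prediction array: 4 observations,
# 2500 post-burn-in draws per chain, 6 chains.
set.seed(1)
pred <- array(rnorm(4 * 2500 * 6), dim = c(4, 2500, 6))

# Chains 2, 4, 6 and the last 200 iterations.
sub <- subset_draws(pred, chains = c(2, 4, 6), iters = 2301:2500)
dim(sub)  # c(4, 200, 3)
```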

Returning Trees

Hi there! What is the best way to return the fitted tree structure (decision splits, split variables/values, conditional means, etc.) from each of the posterior draws?

I have been trying to access the sampler in C, but to no avail so far. It would be nice if there were a C call that extracted the tree structure into the fitted object, or some other straightforward way of extracting such an object, as in dbarts. Thanks!
