Comments (3)

njtierney commented on August 26, 2024

Related TF2 error: Failure (test_inference.R:305): mcmc supports rwmh sampler with uniform proposals

devtools::load_all(".")
#> ℹ Loading greta
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> 
#> ✔ Initialising python and checking dependencies ... done!
#> Loaded Tensorflow version 2.10.0
source(here::here("tests", "testthat", "helpers.R"))
set.seed(5)
x <- uniform(0, 1)
m <- model(x)
expect_ok(draws <- mcmc(m,
  sampler = rwmh("uniform"),
  n_samples = 100, warmup = 100,
  verbose = FALSE
))
#> Error: `expr` threw an unexpected error.
#> Message: greta hit a tensorflow error:
#> Error in py_call_impl(callable, dots$args, dots$keywords): RuntimeError:
#> Evaluation error: ValueError: Cannot reshape a tensor with 0 elements to shape
#> [1] (1 elements) for '{{node Reshape}} = Reshape[T=DT_DOUBLE,
#> Tshape=DT_INT32](Mul, Reshape/shape)' with input shapes: [0], [1] and with
#> input tensors computed as partial shapes: input[1] = [1]. .
#> Class:   simpleError/error/condition

Created on 2022-12-09 with reprex v2.0.2

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.2.1 (2022-06-23)
#>  os       macOS Monterey 12.3.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Perth
#>  date     2022-12-09
#>  pandoc   2.19.2 @ /Applications/RStudio.app/Contents/Resources/app/quarto/bin/tools/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  !  package     * version    date (UTC) lib source
#>     abind         1.4-5      2016-07-21 [1] CRAN (R 4.2.0)
#>     backports     1.4.1      2021-12-13 [1] CRAN (R 4.2.0)
#>     base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.2.0)
#>     brio          1.1.3      2021-11-30 [1] CRAN (R 4.2.0)
#>     cachem        1.0.6      2021-08-19 [1] CRAN (R 4.2.0)
#>     callr         3.7.3      2022-11-02 [1] CRAN (R 4.2.0)
#>     cli           3.4.1      2022-09-23 [1] CRAN (R 4.2.0)
#>     coda          0.19-4     2020-09-30 [1] CRAN (R 4.2.0)
#>     codetools     0.2-18     2020-11-04 [1] CRAN (R 4.2.1)
#>     crayon        1.5.2      2022-09-29 [1] CRAN (R 4.2.0)
#>     desc          1.4.2      2022-09-08 [1] CRAN (R 4.2.0)
#>     devtools      2.4.5      2022-10-11 [1] CRAN (R 4.2.0)
#>     digest        0.6.30     2022-10-18 [1] CRAN (R 4.2.0)
#>     ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.2.0)
#>     evaluate      0.18       2022-11-07 [1] CRAN (R 4.2.0)
#>     fansi         1.0.3      2022-03-24 [1] CRAN (R 4.2.0)
#>     fastmap       1.1.0      2021-01-25 [1] CRAN (R 4.2.0)
#>     fs            1.5.2      2021-12-08 [1] CRAN (R 4.2.0)
#>     future        1.29.0     2022-11-06 [1] CRAN (R 4.2.0)
#>     globals       0.16.2     2022-11-21 [1] CRAN (R 4.2.1)
#>     glue          1.6.2      2022-02-24 [1] CRAN (R 4.2.0)
#>  VP greta       * 0.4.2.9000 2022-09-08 [?] CRAN (R 4.2.0) (on disk 0.4.3)
#>     here          1.0.1      2020-12-13 [1] CRAN (R 4.2.0)
#>     highr         0.9        2021-04-16 [1] CRAN (R 4.2.0)
#>     hms           1.1.2      2022-08-19 [1] CRAN (R 4.2.0)
#>     htmltools     0.5.3      2022-07-18 [1] CRAN (R 4.2.0)
#>     htmlwidgets   1.5.4      2021-09-08 [1] CRAN (R 4.2.0)
#>     httpuv        1.6.6      2022-09-08 [1] CRAN (R 4.2.0)
#>     jsonlite      1.8.3      2022-10-21 [1] CRAN (R 4.2.0)
#>     knitr         1.41       2022-11-18 [1] CRAN (R 4.2.0)
#>     later         1.3.0      2021-08-18 [1] CRAN (R 4.2.0)
#>     lattice       0.20-45    2021-09-22 [1] CRAN (R 4.2.1)
#>     lifecycle     1.0.3      2022-10-07 [1] CRAN (R 4.2.0)
#>     listenv       0.8.0      2019-12-05 [1] CRAN (R 4.2.0)
#>     magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.2.0)
#>     Matrix        1.5-3      2022-11-11 [1] CRAN (R 4.2.0)
#>     memoise       2.0.1      2021-11-26 [1] CRAN (R 4.2.0)
#>     mime          0.12       2021-09-28 [1] CRAN (R 4.2.0)
#>     miniUI        0.1.1.1    2018-05-18 [1] CRAN (R 4.2.0)
#>     parallelly    1.32.1     2022-07-21 [1] CRAN (R 4.2.0)
#>     pillar        1.8.1      2022-08-19 [1] CRAN (R 4.2.0)
#>     pkgbuild      1.4.0      2022-11-27 [1] CRAN (R 4.2.1)
#>     pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.2.0)
#>     pkgload       1.3.2      2022-11-16 [1] CRAN (R 4.2.0)
#>     png           0.1-8      2022-11-29 [1] CRAN (R 4.2.0)
#>     prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.2.0)
#>     processx      3.8.0      2022-10-26 [1] CRAN (R 4.2.0)
#>     profvis       0.3.7      2020-11-02 [1] CRAN (R 4.2.0)
#>     progress      1.2.2      2019-05-16 [1] CRAN (R 4.2.0)
#>     promises      1.2.0.1    2021-02-11 [1] CRAN (R 4.2.0)
#>     ps            1.7.2      2022-10-26 [1] CRAN (R 4.2.0)
#>     purrr         0.3.5      2022-10-06 [1] CRAN (R 4.2.0)
#>     R.cache       0.16.0     2022-07-21 [1] CRAN (R 4.2.0)
#>     R.methodsS3   1.8.2      2022-06-13 [1] CRAN (R 4.2.0)
#>     R.oo          1.25.0     2022-06-12 [1] CRAN (R 4.2.0)
#>     R.utils       2.12.2     2022-11-11 [1] CRAN (R 4.2.0)
#>     R6            2.5.1      2021-08-19 [1] CRAN (R 4.2.0)
#>     Rcpp          1.0.9      2022-07-08 [1] CRAN (R 4.2.0)
#>     remotes       2.4.2      2021-11-30 [1] CRAN (R 4.2.0)
#>     reprex        2.0.2      2022-08-17 [1] CRAN (R 4.2.0)
#>     reticulate    1.26       2022-08-31 [1] CRAN (R 4.2.0)
#>     rlang         1.0.6      2022-09-24 [1] CRAN (R 4.2.0)
#>     rmarkdown     2.18       2022-11-09 [1] CRAN (R 4.2.0)
#>     rprojroot     2.0.3      2022-04-02 [1] CRAN (R 4.2.0)
#>     rstudioapi    0.14       2022-08-22 [1] CRAN (R 4.2.0)
#>     sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.2.0)
#>     shiny         1.7.3      2022-10-25 [1] CRAN (R 4.2.0)
#>     stringi       1.7.8      2022-07-11 [1] CRAN (R 4.2.0)
#>     stringr       1.5.0      2022-12-02 [1] CRAN (R 4.2.0)
#>     styler        1.8.1      2022-11-07 [1] CRAN (R 4.2.0)
#>     tensorflow    2.9.0      2022-05-21 [1] CRAN (R 4.2.0)
#>     testthat    * 3.1.5      2022-10-08 [1] CRAN (R 4.2.0)
#>     tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.2.0)
#>     tfruns        1.5.1      2022-09-05 [1] CRAN (R 4.2.0)
#>     urlchecker    1.0.1      2021-11-30 [1] CRAN (R 4.2.0)
#>     usethis       2.1.6      2022-05-25 [1] CRAN (R 4.2.0)
#>     utf8          1.2.2      2021-07-24 [1] CRAN (R 4.2.0)
#>     vctrs         0.5.1      2022-11-16 [1] CRAN (R 4.2.0)
#>     whisker       0.4        2019-08-28 [1] CRAN (R 4.2.0)
#>     withr         2.5.0      2022-03-03 [1] CRAN (R 4.2.0)
#>     xfun          0.35       2022-11-16 [1] CRAN (R 4.2.0)
#>     xtable        1.8-4      2019-04-21 [1] CRAN (R 4.2.0)
#>     yaml          2.3.6      2022-10-18 [1] CRAN (R 4.2.0)
#>     yesno         0.1.2      2020-07-10 [1] CRAN (R 4.2.0)
#> 
#>  [1] /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library
#> 
#>  V ── Loaded and on-disk version mismatch.
#>  P ── Loaded and on-disk path mismatch.
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.8.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.8.13 | packaged by conda-forge | (default, Mar 25 2022, 06:05:16)  [Clang 12.0.1 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/numpy
#>  numpy_version:  1.23.2
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.8/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

njtierney commented on August 26, 2024

Current progress is in the local file tf2-test-inference-mcmcm-rwmh-normal-uniform.qmd.

The file contains the following:

We get an error with the rwmh sampler, like so:

x <- normal(0, 1)
m <- model(x)

# this errors
draws <- mcmc(m,
              sampler = rwmh("normal"),
              n_samples = 100, 
              warmup = 100,
              verbose = FALSE)

The error is:

Error: greta hit a tensorflow error:
Error in py_call_impl(callable, dots$args, dots$keywords):
RuntimeError: Evaluation error:
ValueError: Cannot reshape a tensor with 0 elements to shape [1] (1 elements) for '{{node Reshape}} = Reshape[T=DT_DOUBLE, Tshape=DT_INT32](Mul, Reshape/shape)' with input shapes: [0], [1] and with input tensors computed as partial shapes: input[1] = [1]. .

The minimal code that fails is:

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

So my thought is that something has the wrong shape or dimension here, related to the rwmh sampler?
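For reference, TensorFlow's Reshape requires the input to hold exactly as many elements as the target shape, and the error says the input had shape [0] while the target was [1]. The base-R sketch below is only an analogue of that element-count check, not greta or TensorFlow code; reshape_to() is a hypothetical helper.

```r
# Hypothetical helper mimicking the element-count check that
# TensorFlow's Reshape op performs; this is not greta code.
reshape_to <- function(x, shape) {
  n_needed <- prod(shape)
  if (length(x) != n_needed) {
    stop(sprintf(
      "Cannot reshape %d elements to shape [%s] (%d elements)",
      length(x), paste(shape, collapse = ", "), n_needed
    ))
  }
  array(x, dim = shape)
}

# A length-2 vector fits shape [2] ...
ok <- reshape_to(c(0.1, 1.0), 2)

# ... but an empty vector cannot fill shape [1], which is the same
# situation as the "[0] -> [1]" reshape in the greta error.
msg <- tryCatch(reshape_to(numeric(0), 1), error = conditionMessage)
```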

We don't get this error with the default sampler, hmc.

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = hmc())

Let's go on a journey. First, step into mcmc():

debugonce(mcmc)
x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

OK, so the error occurs here, in sample_carefully:

result <- cleanly(
  self$tf_evaluate_sample_batch(
    free_state = tensorflow::as_tensor(
      free_state,
      dtype = tf_float()
    ),
    sampler_burst_length = tensorflow::as_tensor(sampler_burst_length),
    sampler_thin = tensorflow::as_tensor(sampler_thin),
    sampler_param_vec = tensorflow::as_tensor(
      sampler_param_vec,
      dtype = tf_float(),
      shape = length(sampler_param_vec)
    )
  )
)

I'll place a browser() call there and take a look:

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = rwmh())

OK, and because the code uses tf_function(), that needs to be undone there somehow.

Breaking this call down, the values of its inputs free_state, sampler_burst_length, sampler_thin, and sampler_param_vec are:

free_state
sampler_burst_length
sampler_thin
sampler_param_vec
Browse[2]> free_state
     all_forward_variable_1
[1,]             0.04096190
[2,]             0.01090323
[3,]            -0.15063224
[4,]             0.08087937
Browse[2]> sampler_burst_length
[1] 3
Browse[2]> sampler_thin
[1] 1
Browse[2]> sampler_param_vec
rwmh_epsilon rwmh_diag_sd 
         0.1          1.0 

And what does the same part look like with hmc()?

x <- normal(0, 1)
m <- model(x)
draws <- mcmc(m, sampler = hmc())
free_state
sampler_burst_length
sampler_thin
sampler_param_vec
Browse[2]> free_state
     all_forward_variable_1
[1,]             0.13124380
[2,]            -0.12720201
[3,]            -0.15659195
[4,]             0.08557503
Browse[2]> sampler_burst_length
[1] 3
Browse[2]> sampler_thin
[1] 1
Browse[2]> sampler_param_vec
      hmc_l hmc_epsilon hmc_diag_sd 
        6.0         0.1         1.0 

OK, so it isn't going awry at that point...

Overall, tf_evaluate_sample_batch is a tf_function(), defined like so:

self$tf_evaluate_sample_batch <- tensorflow::tf_function(
  f = self$define_tf_draws,
  input_signature = list(
    # free state
    tf$TensorSpec(shape = list(NULL, self$n_free),
                  dtype = tf_float()),
    # sampler_burst_length
    tf$TensorSpec(shape = list(),
                  dtype = tf$int32),
    # sampler_thin
    tf$TensorSpec(shape = list(),
                  dtype = tf$int32),
    # sampler_param_vec
    tf$TensorSpec(shape = list(length(unlist(self$sampler_parameter_values()))),
                  dtype = tf_float())
  )
)
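The last TensorSpec fixes the shape of sampler_param_vec at length(unlist(self$sampler_parameter_values())). As a quick sanity check, the base-R sketch below recomputes that length from the values printed at the browser() prompt for each sampler. The parameter lists are hand-copied stand-ins, and spec_length() and matches_spec() are hypothetical helpers, not greta functions.

```r
# Stand-ins copied from the browser() output above; in greta these
# values come from the sampler object, not from hand-written lists.
rwmh_params <- list(rwmh_epsilon = 0.1, rwmh_diag_sd = 1.0)
hmc_params <- list(hmc_l = 6, hmc_epsilon = 0.1, hmc_diag_sd = 1.0)

# Hypothetical helper: the length the sampler_param_vec TensorSpec
# would declare, i.e. length(unlist(<parameter values>)).
spec_length <- function(params) {
  length(unlist(params))
}

# Hypothetical helper: does a flattened parameter vector have the
# length that the input signature expects?
matches_spec <- function(params, param_vec) {
  length(param_vec) == spec_length(params)
}

# Flattened vectors like the ones handed to as_tensor() in sample_carefully.
rwmh_vec <- unlist(rwmh_params)
hmc_vec <- unlist(hmc_params)
```

Under these assumptions both samplers pass a vector of the declared length (2 for rwmh, 3 for hmc), so the zero-element tensor in the error does not appear to come from sampler_param_vec itself.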

I can't work out how to remove the tf_function() wrapper so that we can debug it.

I tried using self$define_tf_draws as its own function, defining the inputs (free_state, etc.) either as the named variables or as the tensors above, but I think I'm missing something. I'm not sure how to proceed from here.
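One option that may help here, assuming tf_evaluate_sample_batch is an ordinary TensorFlow function under the hood: TensorFlow 2.3 and later provide tf.config.run_functions_eagerly(), which makes tf.function-wrapped code execute eagerly. A browser() placed inside self$define_tf_draws should then see concrete tensor values on every call, rather than the symbolic tensors it sees while the function is being traced. The wrapper below, with_eager_functions(), is a hypothetical convenience helper, not part of greta or the tensorflow R package; it restores the previous setting when it exits.

```r
# Hypothetical helper, not part of greta or the tensorflow R package.
# Temporarily runs every tf.function eagerly (TensorFlow >= 2.3) so that
# browser() calls inside self$define_tf_draws see concrete tensors.
with_eager_functions <- function(code) {
  tf <- tensorflow::tf
  # Remember the current setting so it can be restored afterwards.
  old <- tf$config$functions_run_eagerly()
  tf$config$run_functions_eagerly(TRUE)
  # Restore the previous setting even if `code` throws an error.
  on.exit(tf$config$run_functions_eagerly(old), add = TRUE)
  # `code` is a lazily evaluated promise, so it only runs here,
  # after eager execution has been switched on.
  force(code)
}

# Usage, with greta loaded, the model m from above, and a browser()
# call placed inside self$define_tf_draws:
# with_eager_functions(draws <- mcmc(m, sampler = rwmh()))
```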

hrlai commented on August 26, 2024

Hi @njtierney, I got a similar error from mcmc() today when testing out the TF2 version from #534:

mcmc(m,
     sampler = hmc(10, 15),
     warmup = 4000,
     n_samples = 1000,
     chains = 1,
     one_by_one = TRUE)

gave:

    warmup                                           0/4000 | eta:  ?s          
Error in self$sample_carefully(free_state = self$free_state, sampler_burst_length = as.integer(n_samples),  : 
  object 'n_samples' not found

Sorry, I wasn't able to put together a minimal working example that reproduces the error yet, but I'll keep looking into it. Posting here in case it's something you're tackling at the moment.
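For context, object 'n_samples' not found is the error R raises when a function body refers to a name that is neither one of its arguments nor defined in an enclosing environment. Because R evaluates arguments lazily, the error is reported against the call whose argument could not be evaluated, which fits the Error in self$sample_carefully(...) line above. The sketch below reproduces that class of error with hypothetical functions (run_burst(), sample_one_by_one_buggy(), sample_one_by_one()); it is an illustration only, not a diagnosis of greta's one_by_one code path.

```r
# Hypothetical stand-in for a method that consumes a burst length.
run_burst <- function(burst_length) {
  as.integer(burst_length)
}

# Buggy: refers to `n_samples`, which is not an argument here and is
# not defined in the enclosing (global) environment. The promise is
# only forced inside run_burst(), where the lookup fails.
sample_one_by_one_buggy <- function(n) {
  run_burst(burst_length = n_samples)
}

# Fixed: the value is passed through explicitly as an argument.
sample_one_by_one <- function(n_samples) {
  run_burst(burst_length = n_samples)
}

err <- tryCatch(sample_one_by_one_buggy(100), error = conditionMessage)
res <- sample_one_by_one(100)
```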
