

jessegrabowski commented on July 17, 2024

Hey, thanks for giving the module a try!

I can reproduce the problem on my end, so it's definitely a bug. It looks like a JAX problem: your code runs if I use the default PyMC sampler. It also works if you have more than one regressor. For example, this runs:

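import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy data with a monthly index, a linear trend, and exogenous regressors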
n_samples = 1000
k_exog = 3

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],  # y first, to match the column names
                  index=pd.date_range("2001-01-01", freq="M", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'M'


trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=['x0', 'x1', 'x2'])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, regression_data_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values)
    
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)

    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
    # prior = pm.sample_prior_predictive(samples=10)
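Before fitting, it can help to confirm that the simulated data line up with the intended ground truth. The sketch below is an addition, not part of the original thread: it regenerates the same kind of data with NumPy only and recovers the intercept, trend slope, and regression coefficients by ordinary least squares (`np.linalg.lstsq`). The helper name `ols_check` is hypothetical.

```python
import numpy as np

def ols_check(y, regressors):
    """Regress y on [intercept, time index, regressors] by ordinary least squares.

    Returns (intercept, trend_slope, betas) for comparison with the values
    used to simulate the data. `ols_check` is a hypothetical helper.
    """
    y = np.asarray(y, dtype=float)
    # reshape(len(y), -1) also accepts a single regressor passed as a 1-D vector
    regressors = np.asarray(regressors, dtype=float).reshape(len(y), -1)
    t = np.arange(len(y), dtype=float)
    design = np.column_stack([np.ones_like(t), t, regressors])
    coef, *_ = np.linalg.lstsq(design, y, rcond=None)
    return coef[0], coef[1], coef[2:]

# Regenerate data the same way as the example above (seed 100, k_exog = 3)
n_samples, k_exog = 1000, 3
np.random.seed(100)
trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10

intercept, slope, betas = ols_check(y, regressor_data)
print(round(intercept, 1), round(slope, 2))  # should be near 10 and 0.1
print(np.round(betas - true_betas, 2))       # should be close to zero
```

Passing `df['y']` and `df.drop(columns='y')` instead of the raw arrays applies the same check to the DataFrame, which would catch mislabeled columns.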

Probably the data is being incorrectly squeezed somewhere. I'll look closely and push a fix ASAP. Thanks for finding this bug and opening an issue!
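To see why a squeeze would only bite in the single-regressor case, here is a hypothetical illustration in plain NumPy (it does not reproduce the statespace module's code): `np.squeeze` drops the length-1 column when `k_exog == 1` but leaves a multi-column matrix untouched. The helpers `shapes_after_squeeze` and `restore_2d` are made up for this sketch.

```python
import numpy as np

def shapes_after_squeeze(n_samples, k_exog):
    """Return the exogenous-data shape before and after np.squeeze."""
    regressor_data = np.zeros((n_samples, k_exog))
    return regressor_data.shape, np.squeeze(regressor_data).shape

def restore_2d(x, n_samples):
    """Reshape to (n_samples, -1), turning a squeezed vector back into a column."""
    return np.asarray(x).reshape(n_samples, -1)

# With one regressor the column axis disappears...
print(shapes_after_squeeze(1000, 1))  # ((1000, 1), (1000,))
# ...but with three regressors the array keeps its 2-D shape.
print(shapes_after_squeeze(1000, 3))  # ((1000, 3), (1000, 3))
```

Code that expects a 2-D design matrix therefore sees a different rank only when there is exactly one regressor, which matches the symptom in this issue.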

from pymc-experimental.

ricardoV94 commented on July 17, 2024

CC @jessegrabowski


jessegrabowski commented on July 17, 2024

I finally had some time to look closely at this. It appears to be a bug that arises because broadcastable dimensions are signified by a static shape of 1 in PyTensor. A length-1 dimension whose static shape isn't declared might change later (and might or might not broadcast), so the program considers it dynamic. As a result, JAX rejects the graph, because it doesn't allow dynamic shapes. This is why the model works if you have more than one exogenous variable: the second dimension of the exogenous data isn't 1 anymore, and everything is inferred to be static. Might be related to pymc-devs/pytensor#408, but I'm not sure.
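To make the static-versus-dynamic distinction concrete, here is a deliberately simplified model in plain Python, not PyTensor's actual implementation. A static shape is a tuple whose entries are either a known length or `None` for "unknown until runtime"; a dimension counts as broadcastable only when it is statically known to be 1, and JAX (by assumption in this model) needs every length to be known. The `(None, None)` declaration stands for what we assume a resizable container such as `pm.MutableData` gets when no shape is passed.

```python
from typing import Optional, Tuple

# Toy "static shape": a known length per dimension, or None if only known at runtime
StaticShape = Tuple[Optional[int], ...]

def broadcastable(shape: StaticShape) -> Tuple[bool, ...]:
    """A dimension can broadcast only if it is statically known to be 1."""
    return tuple(dim == 1 for dim in shape)

def is_fully_static(shape: StaticShape) -> bool:
    """True when every length is known before runtime."""
    return all(dim is not None for dim in shape)

n_samples, k_exog = 1000, 1

dynamic = (None, None)          # assumed declaration when no shape is given
explicit = (n_samples, k_exog)  # declaration with an explicit shape (work-around 1)

for label, shape in [("no shape", dynamic), ("explicit shape", explicit)]:
    print(label, broadcastable(shape), is_fully_static(shape))
# no shape (False, False) False
# explicit shape (False, True) True
```

Under these assumptions, the length-1 exogenous column is only treated as broadcastable, and the graph only looks static, once the shape is declared up front.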

For now, I can think of two possible work-arounds:

  1. Explicitly specify the shape of the exogenous data when you create the pm.MutableData, by passing a shape keyword argument.
  2. Use pm.ConstantData instead of pm.MutableData.

Despite my choice of ordering, I think option 2 is preferable.

Here is a working example:

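import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
from pymc_experimental.statespace import structural as st

# Generate dummy data with a monthly index, a linear trend, and an exogenous regressor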
n_samples = 1000
k_exog = 1

np.random.seed(100)

trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10
df = pd.DataFrame(np.c_[y, regressor_data],
                  index=pd.date_range("2001-01-01", freq="ME", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'ME'


trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=[f'x{i}' for i in range(k_exog)])
error = st.MeasurementError(name="error")

mod = trend + error + regressor
ss_mod = mod.build(name="test")
trend_dims, obs_dims, regressor_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    
    # Option 1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values, 
                               dims=['time', 'exog_state'],
                               shape=(n_samples, k_exog)) # <--- Key line

    # Option 2:
    # data_xreg = pm.ConstantData("data_xreg", df.drop(columns='y').values, 
    #                           dims=['time', 'exog_state'])
    
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)

    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5)

    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)

Note that I also specified n_samples in the shape, not just k_exog. If you don't, JAX will complain about dynamic shapes when you try to do post-estimation sampling (ss_mod.sample_conditional_posterior, for example).

I'll let you know when I come up with a more long-term solution.


ricardoV94 commented on July 17, 2024

You only need to specify the shape for broadcastable dims if you intend them to broadcast. You can pass shape=(None, 1) if you still want the other dim to be resizable (but then that dim can't be broadcast with other parameters).
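The trade-off of `shape=(None, 1)` can be sketched with another toy model in plain Python. This is a simplification under stated assumptions, not PyMC's or PyTensor's code: a declared static shape accepts new data only if the rank matches and every known length matches, while `None` entries accept any length. The helper `accepts` is hypothetical.

```python
from typing import Optional, Sequence

def accepts(declared: Sequence[Optional[int]], data_shape: Sequence[int]) -> bool:
    """True if data of `data_shape` fits a variable declared with `declared`.

    Hypothetical helper: None entries accept any length, known entries must match.
    """
    if len(declared) != len(data_shape):
        return False  # the number of dimensions must always match
    return all(d is None or d == s for d, s in zip(declared, data_shape))

# Resizable time axis, statically length-1 (broadcastable) regressor axis
declared = (None, 1)

print(accepts(declared, (1000, 1)))  # True: the original data
print(accepts(declared, (1200, 1)))  # True: a longer series
print(accepts(declared, (1000, 2)))  # False: the regressor axis is fixed at 1
```

In this model the time axis stays swappable, which is what keeps `pm.MutableData` useful, while the regressor axis is pinned so it can be treated as broadcastable.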


jessegrabowski commented on July 17, 2024

Yes, this works as well (with pm.MutableData), but JAX will still raise an error in pm.sample_posterior_predictive, complaining about dynamic slicing. So for now I recommend declaring both dimensions, since conditional forecasting with exogenous time series isn't supported yet anyway.
