Comments (5)
Hey, thanks for giving the module a try!
I can reproduce the problem on my end, so it's definitely a bug. It looks like a JAX problem; I can run your code if I use the default PyMC sampler. Also, if you have more than one regressor it works. For example, this runs:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
# Assumed import path for the structural components referred to as `st` below
from pymc_experimental.statespace import structural as st

# Generate dummy monthly data with a linear trend and exogenous regressors
n_samples = 1000
k_exog = 3
np.random.seed(100)
trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10

# Put y first so the column order matches the column names
df = pd.DataFrame(np.c_[y, regressor_data],
                  index=pd.date_range("2001-01-01", freq="M", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'M'

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=['x0', 'x1', 'x2'])
error = st.MeasurementError(name="error")
mod = trend + error + regressor
ss_mod = mod.build(name="test")

trend_dims, obs_dims, regressor_dims, regression_data_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values)
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5, dims=["observed_state"])
    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
    # prior = pm.sample_prior_predictive(samples=10)
Probably the data is being incorrectly squeezed somewhere. I'll look closely and push a fix ASAP. Thanks for finding this bug and opening an issue!
from pymc-experimental.
I finally had some time to look closely at this. It appears to be a bug that arises because broadcastable dimensions are signified by a shape of 1 in pytensor. This means the program considers them dynamic, since they might change after broadcasting. As a result, JAX gets upset by the graph, because it doesn't allow dynamic shapes. This is why the model works if you have more than one exogenous variable -- the 2nd dimension of the exogenous data isn't 1 anymore, and everything is inferred to be static. This might be related to pymc-devs/pytensor#408, but I'm not sure.
For now, I can think of two possible work-arounds:
- Explicitly specify the shape of the exogenous data when you create the pm.MutableData, by passing a shape keyword argument.
- Use pm.ConstantData instead of pm.MutableData.
Despite my choice of ordering, I think option 2 is preferable.
Here is a working example:
import pandas as pd
import numpy as np
import pymc as pm
import pytensor.tensor as pt
# Assumed import path for the structural components referred to as `st` below
from pymc_experimental.statespace import structural as st

# Generate dummy monthly data with a linear trend and one exogenous regressor
n_samples = 1000
k_exog = 1
np.random.seed(100)
trend_data = np.arange(n_samples) * .1
true_betas = np.random.normal(size=(k_exog,))
regressor_data = np.random.normal(scale=2, size=(n_samples, k_exog))
y = trend_data + regressor_data @ true_betas + np.random.normal(scale=2, size=n_samples) + 10

df = pd.DataFrame(np.c_[y, regressor_data],
                  index=pd.date_range("2001-01-01", freq="ME", periods=n_samples),
                  columns=['y'] + [f'x_{i}' for i in range(k_exog)])
df.index.freq = 'ME'

trend = st.LevelTrendComponent(name="linear_trend", order=2, innovations_order=0)
regressor = st.RegressionComponent(name="xreg", k_exog=k_exog, state_names=[f'x{i}' for i in range(k_exog)])
error = st.MeasurementError(name="error")
mod = trend + error + regressor
ss_mod = mod.build(name="test")

trend_dims, obs_dims, regressor_dims, P0_dims = ss_mod.param_dims.values()
coords = ss_mod.coords

with pm.Model(coords=coords) as model_1:
    # Option 1:
    data_xreg = pm.MutableData("data_xreg", df.drop(columns='y').values,
                               dims=['time', 'exog_state'],
                               shape=(n_samples, k_exog))  # <--- Key line
    # Option 2:
    # data_xreg = pm.ConstantData("data_xreg", df.drop(columns='y').values,
    #                             dims=['time', 'exog_state'])
    P0_diag = pm.Gamma("P0_diag", alpha=2, beta=5, dims=P0_dims[0])
    P0 = pm.Deterministic("P0", pt.diag(P0_diag), dims=P0_dims)
    initial_trend = pm.Normal("initial_trend", dims=trend_dims)
    sigma_error = pm.Gamma("sigma_error", alpha=2, beta=5)
    beta_xreg = pm.Normal("beta_xreg", .2, 1, dims=regressor_dims)
    ss_mod.build_statespace_graph(df[['y']], mode='JAX')
    idata = pm.sample(nuts_sampler='numpyro', target_accept=0.9)
Note that I also specified n_samples in the shape. If you don't, JAX will bark at you about dynamic shapes when you try to do post-estimation sampling (ss_mod.sample_conditional_posterior, for example).
I'll let you know when I come up with a more long-term solution.
from pymc-experimental.
You only need to specify the shape for broadcastable dims if you intend it to broadcast. You can pass shape=(None, 1) if you still want the other dim to be resizeable (but cannot broadcast it with other parameters).
from pymc-experimental.
Yes, this works as well (with pm.MutableData), but JAX will still error on pm.sample_posterior_predictive, complaining about dynamic slicing. So I recommend just declaring both dimensions for now, since conditional forecasting with exogenous timeseries isn't supported yet anyway.
from pymc-experimental.