Comments (5)
Thanks @kazewong! I think this is sufficient for my use case (or at least I can work around the issue with a bit of legwork on my end). This was mostly a suggestion, since the emcee package has such a function, which let me generalize some of my current work with hierarchical Bayesian statistics and reduced how much I had to rewrite or add to some very complex hierarchical Bayesian structures.
While I only suggested a use case where the observations are an external argument, my own work generally passes a large number of custom functions into another overarching function. Being able to pass a series of functions into another function is where this use case becomes quite powerful. That said, I totally understand that this would require a major restructure of the code. I really hope you do get the time to implement it at some point!
from flowmc.
That's a nice solution @dfm! That seems like the best course of action for me for now.
That being said, a perhaps hacky but generalized solution to this issue for implementation into flowMC might be to use something like this:
class _FunctionWrapper(object):
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)
where the logpdf is replaced by a _FunctionWrapper object when initializing the sampler. For instance, in the case of MALA, it could be changed to:
class MALA(LocalSamplerBase):
    def __init__(self, logpdf: Callable, jit: bool, params: dict, args=None, kwargs=None) -> None:
        super().__init__(logpdf, jit, params)
        self.params = params
        self.logpdf = _FunctionWrapper(logpdf, args, kwargs)
where args and kwargs are optional arguments that can be passed through and called as normal. I think this would mean that flowMC would need minimal or no rewriting with this solution, and users could add as many arguments as they wish to their log posterior functions. Whether this will work with JAX, I'll admit I don't know.
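As a quick sanity check of the wrapper idea above (plain Python, no JAX; the toy log_posterior and its arguments here are purely illustrative), the wrapped object behaves as the single-argument callable a sampler expects:

```python
# Toy demonstration of the _FunctionWrapper pattern: extra positional and
# keyword arguments are frozen at construction time, and the sampler only
# ever sees a callable of x. Names below are illustrative, not flowMC API.

class _FunctionWrapper:
    def __init__(self, f, args=None, kwargs=None):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)

def log_posterior(x, observation, scale=1.0):
    # Simple squared-error score against a fixed observation.
    return -sum((xi - oi) ** 2 for xi, oi in zip(x, observation)) / scale

wrapped = _FunctionWrapper(log_posterior, args=[[1.0, 2.0]], kwargs={"scale": 2.0})
print(wrapped([1.0, 2.0]))  # single-argument call; observation and scale are baked in
```

Whether a callable instance like this plays well with jax.jit tracing is the open question raised above; with plain Python it is a straightforward drop-in.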
@bjricketts Thanks for bringing up this issue. The currently recommended way to compare with the observation is to "pre-bake" the data into the function instead of loading it as an argument.
In your code, since you have defined observation in the global scope, I think your log_posterior function will still work without the observation in the argument list. If you want to dynamically generate the data, we would recommend constructing the log_posterior function with another function, something similar to this:
def make_posterior_function(data):
    def log_posterior(x):
        return f(x, data)
    return log_posterior
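For concreteness, here is a minimal sketch of this closure-factory pattern with a toy f standing in for the actual model code (names are illustrative):

```python
# The data is captured once by the closure; the returned function takes
# only the sampling variable x, which is the signature the sampler expects.

def f(x, data):
    # Toy stand-in for the model/likelihood computation.
    return -sum((xi - di) ** 2 for xi, di in zip(x, data))

def make_posterior_function(data):
    def log_posterior(x):
        return f(x, data)
    return log_posterior

log_posterior = make_posterior_function([0.0, 1.0])
print(log_posterior([0.0, 1.0]))  # data is baked in; no extra argument needed
```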
The reason behind this seemingly complicated syntax is mainly to avoid triggering recompilation when the data size changes. Even though the current code does not directly support what you want to do, I think having a version where we can use syntax like log_posterior(x, observation) could be useful. Implementing this will require some thought on how to safeguard against unnecessary recompilation, and it will take some work to restructure the sampler.
There are some community examples out there; here is one of mine on gravitational-wave parameter estimation:
https://github.com/kazewong/jim/blob/5f5c51622121cc3b175a0cf0ee00be0f8040b23f/example/ParameterEstimation/GW150914.py#L148
At some point in the future, we are going to link these community examples in the doc.
Please let me know if this is sufficient for your use case. For now, we do not have the bandwidth to restructure the code to accommodate the suggested syntax, but we may come back to this issue at some point.
@bjricketts — Another tip is that you could update your example above as follows:
import numpy as np
import jax.numpy as jnp

n_dim = 2
data_x = np.arange(1, 10)
m = 2
c = 5
observation = m * data_x + c

def log_posterior(x, observation):
    x_model = np.arange(1, 10)
    m = x[0]
    c = x[1]
    y = m * x_model + c
    gaussian = -((y - observation) ** 2) / (2 * jnp.sqrt(observation) ** 2)
    return jnp.sum(gaussian)

from functools import partial
log_prob_func = partial(log_posterior, observation=observation)
(Where all I've done is added the two final lines and changed an np to a jnp in your function.) Then use the existing implementation (with log_prob_func as the input to the sampler) without any change. This doesn't seem too onerous to me!
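To spell out what partial does here (plain Python, with a toy squared-error function standing in for the JAX log posterior above), the partially-applied function is equivalent to calling the original with the observation supplied:

```python
from functools import partial

def log_posterior(x, observation):
    # Toy stand-in: squared-error score against the fixed observation.
    return -sum((xi - oi) ** 2 for xi, oi in zip(x, observation))

observation = [2.0, 5.0]
log_prob_func = partial(log_posterior, observation=observation)

# log_prob_func now has the single-argument signature the sampler expects,
# and agrees with log_posterior(x, observation) for every x:
print(log_prob_func([2.0, 5.0]))
```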
@bjricketts Some updates on this issue: in order to make sure we do not trigger recompilation every time we call the likelihood function, we do need to overhaul the API a bit, so that the sampler can accept data as its input and be aware of it during compilation.
I made some progress in this direction in #116. Currently the sampler works with MALA in the way you describe. I am going to clean up other parts of the current API, such as the examples, and implement the functionality for the other local samplers such as HMC; then it should be in the released version.