
Comments (5)

bjricketts commented on July 23, 2024

Thanks @kazewong! I think this is sufficient for my use case (or at least I can work around the issue with a bit of legwork on my end). This was mostly just a suggestion, since the emcee package offers such a feature. It let me "generalize" some work I am currently doing with hierarchical Bayesian statistics and reduced how much I have to rewrite when adding to some very complex hierarchical Bayesian structures.

While I only suggested a use case where the observations are the single external argument, my own work generally passes a large number of custom functions into another overarching function. Being able to pass a series of functions into another function is where this use case becomes quite powerful. That being said, I totally understand that this would require a major restructure of the code. I really hope you do get the time at some point to implement it though!

from flowmc.

bjricketts commented on July 23, 2024

That's a nice solution @dfm! That seems like the best course of action for me for now.

That being said, a perhaps hacky but generalized solution to this issue for implementation into flowMC might be to use something like this:

class _FunctionWrapper(object):
    def __init__(self, f, args, kwargs):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)

where the logpdf is replaced by a _FunctionWrapper object when initializing the sampler. For instance, in the case of MALA, it could be changed to:

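class MALA(LocalSamplerBase):

    # args/kwargs default to None rather than []/{} to avoid shared mutable defaults;
    # _FunctionWrapper already falls back to an empty list/dict.
    def __init__(self, logpdf: Callable, jit: bool, params: dict, args=None, kwargs=None) -> None:
        super().__init__(logpdf, jit, params)
        self.params = params
        self.logpdf = _FunctionWrapper(logpdf, args, kwargs)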

where args and kwargs are optional arguments that are passed through and used as normal. I think this would mean that flowMC needs minimal or no rewriting, and users can add as many arguments as they wish to their log posterior functions. Whether this will work with JAX, I'll admit I don't know.
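To make the wrapper idea concrete, here is a minimal pure-Python sketch of how it would be called. The `log_posterior` function, its `sigma` keyword, and the numbers are hypothetical and only for illustration; the sketch says nothing about how the wrapper behaves under JAX tracing or inside flowMC.

```python
class _FunctionWrapper:
    """Bind extra arguments so that wrapper(x) calls f(x, *args, **kwargs)."""

    def __init__(self, f, args=None, kwargs=None):
        self.f = f
        self.args = args or []
        self.kwargs = kwargs or {}

    def __call__(self, x):
        return self.f(x, *self.args, **self.kwargs)


def log_posterior(x, observation, sigma=1.0):
    # Hypothetical Gaussian log-likelihood (up to a constant), for illustration only.
    return -sum((xi - oi) ** 2 for xi, oi in zip(x, observation)) / (2 * sigma**2)


# The observation and sigma are bound once; afterwards only x is supplied.
wrapped = _FunctionWrapper(log_posterior, args=[[1.0, 2.0]], kwargs={"sigma": 2.0})
value = wrapped([1.0, 4.0])
```

For a single call signature like this, the wrapper does much the same job as `functools.partial`; it simply keeps the bound arguments as attributes on an object.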


kazewong commented on July 23, 2024

@bjricketts Thanks for bringing up this issue. The current recommended way to compare with the observation is to "pre-bake" the data into the function instead of passing it as an argument.

In your code, since you have defined observation in the global scope, I think your log_posterior function will still work without the observation in the argument. If you want to dynamically generate the data, we would recommend constructing the log_posterior function with another function, something similar to this:

def make_posterior_function(data):
    def log_posterior(x):
        return f(x,data)
    return log_posterior

The reason behind this seemingly complicated syntax is mainly to avoid triggering recompilation when the data size changes. Even though the current code does not directly support what you want to do, I think having a version where we can use syntax like log_posterior(x,observation) could be useful. Implementing this will require some thought on how to safeguard against unnecessary recompilation, and it will take some work to restructure the sampler.
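As a rough illustration of the recompilation concern, the sketch below counts how often JAX traces a jitted function. It assumes the usual `jax.jit` behavior of caching compiled code per input shape and dtype; the `trace_count` counter and this toy `log_posterior` are hypothetical and not part of flowMC.

```python
import jax
import jax.numpy as jnp

# Counts tracing passes: Python side effects inside a jitted function
# run only while JAX traces it, not on cached calls.
trace_count = {"n": 0}


@jax.jit
def log_posterior(x, data):
    trace_count["n"] += 1
    # Toy linear model evaluated on as many points as there are data values.
    model = x[0] * jnp.arange(data.shape[0]) + x[1]
    return -0.5 * jnp.sum((model - data) ** 2)


x = jnp.array([2.0, 5.0])
log_posterior(x, jnp.ones(9))   # first call: traced and compiled
log_posterior(x, jnp.ones(9))   # same shape and dtype: cached version reused
log_posterior(x, jnp.ones(12))  # different data size: traced again
```

After these three calls the counter reflects two tracing passes, one per distinct data shape, which is the kind of hidden cost a sampler accepting data as an argument would need to guard against.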

There are some community examples out there. Here is one of mine on gravitational-wave parameter estimation:
https://github.com/kazewong/jim/blob/5f5c51622121cc3b175a0cf0ee00be0f8040b23f/example/ParameterEstimation/GW150914.py#L148
At some point in the future, we are going to link these community examples in the docs.

Please let me know if this is sufficient for your use case. For now, we do not have the bandwidth to restructure the code to accommodate the suggested syntax, but we may come back to this issue at some point.


dfm commented on July 23, 2024

@bjricketts — Another tip is that you could update your example above as follows:

import numpy as np
import jax.numpy as jnp

n_dim = 2

data_x = np.arange(1,10)
m = 2
c = 5
observation = m*data_x + c

def log_posterior(x, observation):
    x_model = np.arange(1,10)
    m = x[0]
    c = x[1]
    y = m*x_model+c
    # Per-point Gaussian log-likelihood terms with variance equal to the observation
    gaussians = -((y-observation)**2)/(2*jnp.sqrt(observation)**2)
    return jnp.sum(gaussians)

from functools import partial
log_prob_func = partial(log_posterior, observation=observation)

(Where all I've done is add the two final lines and change an np to a jnp in your function.) Then use the existing implementation (with log_prob_func as the input to the sampler) without any change. This doesn't seem too onerous to me!
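A short sketch of how the partially applied function can be checked on its own, assuming standard `jax.jit` and `jax.grad` behavior with `functools.partial` objects. The names `fast_log_prob`, `grad_log_prob`, and `x_true` are hypothetical, and the call into the flowMC sampler itself is left out.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np

# Noise-free data generated from m = 2, c = 5, as in the example above.
observation = 2 * np.arange(1, 10) + 5


def log_posterior(x, observation):
    x_model = jnp.arange(1, 10)
    y = x[0] * x_model + x[1]
    gaussians = -((y - observation) ** 2) / (2 * jnp.sqrt(observation) ** 2)
    return jnp.sum(gaussians)


# Bind the data once; the result takes only the parameter vector x.
log_prob_func = partial(log_posterior, observation=observation)

# The bound function can be jitted and differentiated like any function of x.
fast_log_prob = jax.jit(log_prob_func)
grad_log_prob = jax.grad(log_prob_func)

x_true = jnp.array([2.0, 5.0])
value = fast_log_prob(x_true)      # log-probability at the generating parameters
gradient = grad_log_prob(x_true)   # gradient with respect to (m, c)
```

Because the data are noise-free, the model reproduces the observation exactly at `x_true`, so both the log-probability and its gradient vanish there, which makes a convenient sanity check before handing `log_prob_func` to a sampler.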


kazewong commented on July 23, 2024

@bjricketts Some updates on this issue: to make sure we do not trigger recompilation every time we call the likelihood function, we do need to overhaul the API a bit, so that the sampler can accept data as its input and be aware of that during compilation.

I made some progress in this direction in #116. Currently the sampler works with MALA in the way you describe. I am going to clean up other parts of the current API, such as the examples, and implement the functionality for other local samplers such as HMC. After that it should be in the released version.

