
Comments (34)

sgbaird commented on May 17, 2024

Seems like the PR never got merged

Balandat commented on May 17, 2024

@mc-robinson, thanks for the comprehensive issue. We have designed the posterior API to also support getting the posterior predictive, via the observation_noise kwarg. So you can get the posterior predictive by calling gp.posterior(X_test, observation_noise=True). You will see that for the basic gpytorch model wrapper here, the implementation is exactly what @jacobrgardner suggested above.
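For concreteness (with gp a fitted model and X_test the test inputs):

posterior_f = gp.posterior(X_test)  # posterior over the latent function f
posterior_y = gp.posterior(X_test, observation_noise=True)  # posterior predictive over y (adds the modeled noise)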

However, since the HeteroskedasticGP subclasses SingleTaskGP, the actual code path for your case is through BatchedMultiOutputGPyTorchModel (which for performance reasons does some magic for representing multi-output models with independently modeled outputs as a batched model under the hood). Looking through the posterior code of that, I'm realizing that we appear to have missed handling the observation_noise kwarg in there. This should be easy to add though, I'll put up a PR for this tomorrow.

Note that the noise model of HeteroskedasticSingleTaskGP currently returns the log of the variance, thus the exponentiation. I believe this may be an error in the noise model.

Indeed, the actual noise returned by the likelihood should not be on the log scale. The reason the model fits the inner GP on the log noise is that this way we don't have to worry about zero (or negative) noise predictions. The correct way to handle this is to pass a custom Constraint object with a torch.exp transform as the noise_constraint to the HeteroskedasticNoise constructor. (Alternatively, the model could also just be fit on the inverse softplus transform of the noise.) I'll add this as a to-do for another PR.
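A minimal sketch of that fix, assuming noise_model is the inner GP fit on the log-noise observations (the 1e-4 lower bound here is illustrative):

import torch
from gpytorch.constraints import GreaterThan
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise

# an exp transform keeps the returned noise positive while the inner GP
# still operates on the log scale
noise_constraint = GreaterThan(1e-4, transform=torch.exp, inv_transform=torch.log)
noise_covar = HeteroskedasticNoise(noise_model, noise_constraint=noise_constraint)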

I do believe there should probably be a HeteroskedasticGP class that merely accepts train_x, train_y

Generally, a heteroskedastic GP for the case where you don't have noise observations would be a reasonable thing to implement. I haven't thought all too much about this so far though. Let me try to rephrase your suggestion: you suggest first training a model that infers the noise level, and then, based on the mean prediction of that model, estimating a heteroskedastic noise function (using the data you've already used to fit the model)? This sounds like it could easily go wrong... maybe some kind of latent-variable model, in which the noise levels for the inner GP are themselves model parameters that you optimize over, would be more appropriate?

Edit: Just realized this is all in the colab notebook - awesome job! I'll give it a closer look tomorrow.

@jacobrgardner, haha, feel free to keep helping out with issues over here, I don't mind...

mc-robinson commented on May 17, 2024

Hi guys, thanks for the work you've already put into this. I'll try to update the example notebook soon, and I'm happy to try and help build out the heteroskedastic models. Just let me know if there is a best place to start.

As for your question, @Balandat:

The only question is what exactly it means for this process to "converge" (the paper doesn't say). It's probably reasonable to just look at the change in mean predictions and noise variance estimates between iterations.

I too am a bit confused as to what they used. I assumed they were just looking at the value of the likelihood, but I think your method would also work well. I have also found the following critique in http://www.comonsens.org/documents/conferences/205_vhgpr_icml.pdf

the algorithm is not guaranteed to converge and may instead oscillate. Furthermore, it may require many iterations (each one requiring to train two standard GPs) before stabilizing. In a related, more recent work (Quadrianto et al., 2009), these issues are addressed by choosing g so as to maximize a penalized likelihood, equivalent up to a constant to p(g|D), thus introducing the MAPHGP approximation.

mc-robinson commented on May 17, 2024

Ok great, thanks for all of the helpful comments. I'll try to update the notebook with these suggestions later today.

mc-robinson commented on May 17, 2024

Cool, thanks for the advice. I'll try to implement these things and look into the effects of the kernel soon.

mc-robinson commented on May 17, 2024

Great, thanks for the feedback @Balandat and @eytan. I've been traveling a bit, but let me first try to get a PR going. After that, I'm happy to write up a tutorial -- and always appreciate your help/input once I get a draft made.

mc-robinson commented on May 17, 2024

Hey @Balandat, sorry, been bogged down with a master's thesis due Friday... Happy to finally work on it this weekend though.

That being said, totally happy for you guys to package it up yourselves -- I just want it to be useful! And if you want to do the PR, but still want me to write up the tutorial, happy for that as well. Just let me know what works best!

Balandat commented on May 17, 2024

Yeah. I have an internal set of changes with an updated version but it has some (surmountable) issues and needs to be rebased.

eytan commented on May 17, 2024

@Balandat I think this model would get a good deal of usage if we landed it. Many folks have heteroskedastic error. We can always refine the model later.

Balandat commented on May 17, 2024

Well “refine” would mean “fix”, at least for the updated internal implementation. So I’d prefer to avoid landing something that we know is broken. What I can do is rebase this and put out a PR for visibility, and then plan for some time (and someone) to work on it - while in the meantime hoping for someone lurking here to fix it :)

esantorella commented on May 17, 2024

I don't think anyone at Meta is currently working on this. But thanks for letting us know you're interested! And of course a PR is always welcome :)

jacobrgardner commented on May 17, 2024

The notebook looks great! A few comments:

  1. In general, GPyTorch example notebooks should be independent of BoTorch's convenience methods for wrapping them and fitting them. In GPyTorch, there is no posterior method. The pattern to get predictions is to call model(test_x) to get p(f*|D, x*) and likelihood(model(test_x)) to get p(y*|D, x*). In your case, this just requires a small change because your likelihood also requires test_x as an input. The following code works for me in your notebook to get the predictive posterior after you have the trained GPyTorch model:
# evaluate on a grid over the training domain
# call it X_test just for ease of use later
X_test = torch.linspace(0, 1, 100)

mll.eval()
with torch.no_grad():
    posterior_f = mll.model(X_test)  # call the GPyTorch model directly to get p(f*|D, x*)
    test_pred = mll.likelihood(posterior_f, X_test)  # call the likelihood on the result to get p(y*|D, x*)
    # this works -- we just need to pass X_test to the likelihood as well because the noise depends on the input

Then you should be able to call test_pred.confidence_region() or test_pred.variance as you'd like. This isn't as well documented as it should be because, as you've probably noticed, we are sorely lacking a heteroscedastic example notebook to document such things :-).

  2. I think this should be possible to build as a likelihood, if slightly more complicated. Basically what you want is a GP that gets trained through its predictive mean while using something like softplus(predictive_mean) as the observation noise for the outer GP, right?
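A rough sketch of that wiring, using GPyTorch's HeteroskedasticNoise (which evaluates an inner GP at the query points and passes its predictive mean through the noise constraint's transform, softplus by default); how to train the inner GP without observed noise targets is exactly the open question:

from botorch.models import SingleTaskGP
from gpytorch.likelihoods.gaussian_likelihood import _GaussianLikelihoodBase
from gpytorch.likelihoods.noise_models import HeteroskedasticNoise

# hypothetical inner GP over the latent noise level (its targets here are placeholders)
noise_gp = SingleTaskGP(train_X, train_Y)
# the outer GP's observation noise becomes softplus(noise_gp's predictive mean)
likelihood = _GaussianLikelihoodBase(noise_covar=HeteroskedasticNoise(noise_gp))
outer_gp = SingleTaskGP(train_X, train_Y, likelihood=likelihood)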

jacobrgardner commented on May 17, 2024

Whoops! It seems I'm lost :-). I've been responding to so many issues opened this weekend that, when I got another email about an example notebook for another GP model, I just automatically assumed we were on the GPyTorch repo.

Apologies! Obviously an example notebook on the botorch repo should definitely be free to use BoTorch features :-). My comments about getting the predictive posterior are still useful, hopefully.

mc-robinson commented on May 17, 2024

Thanks for the detailed responses! And @jacobrgardner, I know it was a mistake haha, but I am also happy to (eventually) try to do a detailed notebook on heteroskedastic GPs in GPyTorch. It might be helpful to show how one might build one from the ground up. I already sort of have a draft ready -- are there any examples currently showing how to embed one model in another, such as what is done in the heteroskedastic model in BoTorch?

@Balandat Sorry I missed the observation_noise parameter! I will definitely add that to the notebook when I clean it up. Also, I'll reiterate that my way of creating a heteroskedastic GP is basically the simplest/dumbest way I know, but there are lots of schemes that seem to get quite complicated. If anyone here has an idea for something better that isn't too nasty to implement, I would love to hear it.

The iterative scheme found in http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf is relatively close to what we are already doing, and might not be too hard to implement. I know that there are some convergence issues noted in other papers, but it might be worth trying. If a HeteroskedasticGP is built out that doesn't require a train_yvar input, the iteration should be relatively easy to write.

Balandat commented on May 17, 2024

The "Most Likely Heteroscedastic Gaussian Process Regression:" seems to be pretty much exactly what you're doing, just adding an EM-style fitting process where you iteratively re-estimate the mean and the observation variances from a heteroskedastic model, starting from a homoskedastic mean estimate. This should be just a few lines of code to implement. The only question is what exactly it means for this process to "converge" (the paper doesn't say). It's probably reasonable to just look at the change in mean predictions and noise variance estimates between iterations.

If a HeteroskedasticGP is built out that doesn't require a train_yvar input, the iteration should be relatively easy to write.

We could add a MostLikelyHeteroskedasticGP model that internally just uses the HeteroskedasticGP the same way you are.

Balandat commented on May 17, 2024

The issue with the observation_noise kwarg not being honored for BatchedMultiOutputGPyTorchModel models is addressed in #182

Balandat commented on May 17, 2024

I think your method would also work well

I think this would work for small models, but we will quickly run into scalability issues if we have a latent noise variable for every datapoint. You could think of many other ways to address that issue though, like parameterizing the noise function for these latent variables using some kind of spline interpolation...

For now I think the EM-style algorithm would be a good thing to try first. Essentially, the model you end up getting is exactly a Heteroskedastic GP with the specific inferred noise values, so I don't think it's necessary to build a full new model.

Maybe for now it's enough to have a helper function of the form

def construct_most_likely_heteroskedastic_GP(single_task_gp, options):
    [...]  # perform iterative fitting...
    return het_GP

that takes in a SingleTaskGP and internally performs the fitting, including the construction of the Heteroskedastic GP, and then returns that fitted model. Happy to accept a PR :)
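A minimal sketch of what that iterative fitting could look like (assuming the input is an already-fitted homoskedastic SingleTaskGP, squared residuals as the noise estimate, and the change in mean predictions as the convergence check -- all placeholders rather than a final design):

import torch
from botorch.fit import fit_gpytorch_model
from botorch.models import HeteroskedasticSingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

def construct_most_likely_heteroskedastic_GP(single_task_gp, max_iter=10, atol=1e-4):
    train_X = single_task_gp.train_inputs[0]
    train_Y = single_task_gp.train_targets.unsqueeze(-1)
    with torch.no_grad():
        mean = single_task_gp.posterior(train_X).mean
    model = single_task_gp
    for _ in range(max_iter):
        # re-estimate pointwise noise levels from the squared residuals
        # (clamped, since the noise model works on the log scale)
        observed_var = (train_Y - mean).pow(2).clamp_min(1e-6)
        # refit a heteroskedastic model with the inferred noise levels
        model = HeteroskedasticSingleTaskGP(train_X, train_Y, observed_var)
        fit_gpytorch_model(ExactMarginalLogLikelihood(model.likelihood, model))
        with torch.no_grad():
            new_mean = model.posterior(train_X).mean
        # "convergence": the mean predictions stop moving between iterations
        if torch.allclose(new_mean, mean, atol=atol):
            break
        mean = new_mean
    return model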

mc-robinson commented on May 17, 2024

Hi all,

Sorry for the delay on this one -- but I am finally attaching a slightly updated notebook:
https://colab.research.google.com/drive/1osVrSIfnrm7WWXgsYgP1rR6FDukHNjQF

In it is a first attempt at a MostLikelyHeteroskedasticGP class. As you can see in the notebook, I'm having trouble reaching "convergence", even though the solutions seem to be alright given enough iterations. The original paper, http://people.csail.mit.edu/kersting/papers/kersting07icml_mlHetGP.pdf, mentions that they were usually able to obtain convergence. However, this could be due to their method of estimating noise, which is different from mine (I'm not sure I totally get the reasoning behind their method). Please let me know if you have any suggestions. Hopefully, we can get a pull request started soon.

@Balandat , I also switched over to using the observation_noise parameter of posterior. However, for the heteroskedastic model, I am getting a slightly different result from when I did it by hand. Can you take a look at the differences and let me know if that is just my stupidity?

Balandat commented on May 17, 2024

Thanks @mc-robinson, I will take a look later today.

Balandat commented on May 17, 2024

However, for the heteroskedastic model, I am getting a slightly different result from when I did it by hand. Can you take a look at the differences and let me know if that is just my stupidity?

Looks like you're doing upper = mean + 2 * sqrt(var_f) + 2 * sqrt(var_obs) here. What you want to do instead is upper = mean + 2 * sqrt(var_f + var_obs), which is what the observation_noise=True kwarg will do for you.
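In code (with mean, var_f, and var_obs as computed in the notebook):

import torch

# add the variances first, then take a single square root
upper = mean + 2 * torch.sqrt(var_f + var_obs)
lower = mean - 2 * torch.sqrt(var_f + var_obs)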

Looking through the rest of the nb now.

Balandat commented on May 17, 2024

Thanks a lot for expanding on the nb.

However, this could be due to their method of estimating noise, which is different from mine (I'm not sure I totally get the reasoning behind their method). Please let me know if you have any suggestions.

So if I get that right, your method of estimating the noise is just doing var(x_j) = (y_obs - post_mean(x_j))**2. I'm not quite sure I understand this. You're not including any variance prediction in this estimate (well, implicitly through the mean estimate, I guess). What the paper does instead is draw s samples from the (joint) posterior (over the function values, not the predictive posterior!) and compute an MC sample of the variance E[(y - mean(y))**2]. This seems like a reasonable approach. You can use an MCSampler to do just that:

from botorch.sampling import IIDNormalSampler

sampler = IIDNormalSampler(num_samples=s, resample=True)
posterior = homo_model.posterior(X_train)
samples = sampler(posterior)
# noise estimate per the paper: 0.5 * E[(y_obs - sample)**2],
# where Y_train holds the observed training targets
observed_var = 0.5 * ((Y_train - samples) ** 2).mean(dim=0)

Another thing to note is that the paper uses an RBF kernel for the noise GP, which will smooth things more than the Matern kernel that's used by default in the SingleTaskGP. This could also help with convergence.

A side comment: If you use fit_gpytorch_model instead of fit_gpytorch_scipy, you get a couple of benefits: (i) no need to call train or eval explicitly - the function calls train at the beginning and eval after the model fitting internally, and (ii) there is some logic for retrying the optimization upon failure with initial conditions sampled from the priors - this improves the robustness of the fitting.
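For reference, a minimal sketch of that pattern (with model as in the notebook):

from botorch.fit import fit_gpytorch_model
from gpytorch.mlls import ExactMarginalLogLikelihood

mll = ExactMarginalLogLikelihood(model.likelihood, model)
# calls train() before optimizing and eval() afterwards; retries from
# prior samples if the optimization fails
fit_gpytorch_model(mll)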

Let's see how things go if you switch out the noise estimation procedure - if the results look reasonable and things converge then we can bake this into a PR.

mc-robinson commented on May 17, 2024

Alright, sorry for the delay on this one, but here is the updated notebook:
https://colab.research.google.com/drive/1M--JgobEFdSnrke_j8X-kwNM6S-M9BrM

I have to say I went down a bit of a rabbit hole trying to figure out the variance estimate and why my naive estimate does not work.

Part of my confusion is that I still don't quite understand this statement:

What the paper does instead is draw s samples from the (joint) posterior (over the function values, not the predictive posterior!) and compute an MC sample of the variance E[(y - mean(y))**2].

From my reading of the paper, the samples they draw, $t_i^{j}$, are indeed from the predictive distribution (see the quote below and the longer one in the notebook).

Consider a sample $t_{i}^{j}$ from the predictive distribution induced by the current GP at $\mathbf{x}_{i}$. Viewing $t_{i}^{obs}$ and $t_{i}^{j}$ as two independent observations of the same noise-free, unknown target, $\left(t_{i}^{obs}-t_{i}^{j}\right)^{2} / 2$ is a natural estimate for the noise level at $\mathbf{x}_{i}$. Indeed, we can improve the estimate by taking the expectation with respect to the predictive distribution.

I actually tried to show analytically that their estimate and my estimate are both consistent. The math is quite rough at the moment, but I think it may show just that. I also plotted the noise estimates for both methods, and the shapes look quite similar (not that that is a rigorous comparison).

Furthermore, I updated the function to include both estimates. My naive way is a bit quicker since no sampling is needed, but may be a bit worse at capturing the true noise levels. I am working with a colleague on a rigorous way to show this, such as the negative log predictive density (NLPD) used in the paper.

@Balandat, I was not able to get useful noise estimates from the code you suggested in your previous post -- let me know if you can see why.

Please let me know if you see any glaring errors (especially with the math, if you have time to dive into the differing variance estimates -- I'll keep looking at it). Otherwise, I think we are getting closer to a usable function.

Balandat commented on May 17, 2024

Thanks for the update, this looks great!

From my reading of the paper the samples they draw $t_i^{j}$ are indeed from the predictive distribution

I must have misread that part, sorry about that.

My naive way is a bit quicker since no sampling is needed, but may be a bit worse at capturing the true noise levels

Sampling from the posterior (predictive or latent), once computed, is quite cheap for the standard BoTorch models -- is there actually a big difference in wall time (per iteration; I see that this doesn't converge as often as your method)?

Regarding the results: It's not overly surprising that you don't fully capture the very high-variance areas. The internal GP for estimating the noise level will smooth things out, so if you have relatively abrupt changes in the noise levels, like the peak you have, it will smooth that out. If you have reason to expect these kinds of fast changes in the noise levels, you could modify the prior you're putting on the lengthscale of the noise GP (putting more probability mass on lower values). By default, the kernel used is a 5/2 Matern with GammaPrior(3.0, 6.0) on the lengthscales.
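For example, something along these lines (a sketch; the GammaPrior(2.0, 10.0) is just an illustrative choice, not a recommendation):

from gpytorch.kernels import MaternKernel, ScaleKernel
from gpytorch.priors import GammaPrior

# prior with more mass on short lengthscales than the default GammaPrior(3.0, 6.0),
# letting the noise GP track faster changes in the noise level
covar_module = ScaleKernel(
    MaternKernel(nu=2.5, lengthscale_prior=GammaPrior(2.0, 10.0))
)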

mc-robinson commented on May 17, 2024

Thanks! And no, I don't think the difference is that big in these small toy cases, but I am looking into the difference between the variance estimators for some more non-trivial datasets.

And that's a great point about changing the prior -- thanks. But just wondering, is there a reason I cannot pass a covar_module parameter to HeteroskedasticSingleTaskGP? I guess ideally I would want to be able to control the kernel of both the outer model and the wrapped noise model?

Balandat commented on May 17, 2024

is there a reason I cannot pass a covar_module parameter to HeteroskedasticSingleTaskGP? I guess ideally I would want to be able to control the kernel of both the outer model and the wrapped noise model?

No particular reason other than simplicity - we didn't want to overload the out-of-the-box models with too many options. We may revisit this though and add some optional kwargs to allow specifying the various modules.

If you want more flexibility, for now you could just subclass HeteroskedasticSingleTaskGP and specify the modules you want rather than the ones we set by default. It should be pretty straightforward to swap out the following: https://github.com/pytorch/botorch/blob/master/botorch/models/gp_regression.py#L260-L272 (note that the super().__init__(train_X=train_X, train_Y=train_Y, likelihood=likelihood) call invokes the constructor of SingleTaskGP, which also makes some default choices - you could explicitly construct your model in __init__ with other modules if you don't want that).
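For example, a hypothetical subclass along these lines (the class name and RBF choice are just illustrative):

from botorch.models import HeteroskedasticSingleTaskGP
from gpytorch.kernels import RBFKernel, ScaleKernel

class RBFHeteroskedasticGP(HeteroskedasticSingleTaskGP):
    def __init__(self, train_X, train_Y, train_Yvar):
        super().__init__(train_X=train_X, train_Y=train_Y, train_Yvar=train_Yvar)
        # swap the default Matern-5/2 module on the outer model for an RBF kernel
        self.covar_module = ScaleKernel(RBFKernel(ard_num_dims=train_X.shape[-1]))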

Balandat commented on May 17, 2024

One thing to be aware of is that the models derived from BatchedMultiOutputGPyTorchModel have some magic going on that automatically moves a multi-output dimension into a separate model batch dimension. This is done to have efficient multi-output models for independent outputs that can automatically exploit parallelized computation.

mc-robinson commented on May 17, 2024

Hi, sorry this took so long -- but here is an updated notebook with the implemented "most likely heteroskedastic GP" from the paper: https://colab.research.google.com/drive/1dvFA3LYdLcH0ObQhRvnG1yXCUNgeYIXz. The GP is benchmarked on all of the datasets mentioned in the paper and seems to achieve results comparable to or better than those reported there. I have not done a lot of work to change the default kernels yet, but I figured it would be good to get this basic, working model implemented first.

Note that the most annoying part of the implementation is the normalization and standardization of the independent and dependent variables, respectively. A lot of extra code in the notebook is devoted to this process specifically.

If it all looks alright to you, I can bake the basic heteroskedastic GP function into a pull request.

Balandat commented on May 17, 2024

@mc-robinson This is awesome, many thanks for setting this up and doing the benchmarking. I'd love to see this in a PR.

Since this really is a training procedure that returns an existing HeteroskedasticSingleTaskGP, I would suggest calling this function something like fit_most_likely_HeteroskedasticGP, so there is no confusion about MostLikelyHeteroskedasticGP not actually being a Model class.

I would suggest we put this function into botorch/models/utils.py.

What do you think of cleaning up the notebook after the PR goes in and adding it as a tutorial notebook under Advanced Usage here?

eytan commented on May 17, 2024

Hi @mc-robinson, +1 to what @Balandat said—this is a very nice writeup and it would be awesome to have it in our main BoTorch tutorials. Happy to provide feedback and contribute to some of the verbiage in the tutorial. Let us know if you need any pointers on how to add a tutorial to the site.

Balandat commented on May 17, 2024

@mc-robinson, have you made any progress on this PR? I think this is pretty nifty and I'd like to get this in soon. We can also package this up ourselves if that's ok with you.

Balandat commented on May 17, 2024

If you can put up a PR early next week, that would be great. It doesn't have to be complete -- we can take care of cleaning it up.

Good luck with the thesis!

mc-robinson commented on May 17, 2024

Thanks! And great, will try to get it up by end of this weekend.

sgbaird commented on May 17, 2024

Gotcha!

jakobzeitler commented on May 17, 2024

Updates on this? :D
