
attentive-neural-processes's People

Contributors

wassname


attentive-neural-processes's Issues

ANP-RNN 'use_deterministic_path'= False?

Hi there,

I am wondering whether you forgot to set 'use_deterministic_path' = True for ANP-RNN, since the paper's authors indicate that it should be enabled, and you have clearly already set up cross-attention.

https://github.com/3springs/attentive-neural-processes/blob/af431a267bad309b2d5698f25551986e2c4e7815/neural_processes/models/neural_process/lightning.py#L189-L195

BTW, I was trying to replicate the experiment on the GP dataset, and I have already implemented an ANP. Although my architecture is slightly different from yours, I assume I just need to replace all the MLPs with LSTMs and encode the inputs and outputs sequentially, right? However, my results were really bad: when I sorted x_context and x_target, the model only seemed able to predict a few points and then predicted flat curves. Could you share any hints about this?

Your help is very much appreciated.
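The question above asks whether an ANP becomes a recurrent variant by swapping its MLP encoders for LSTMs run over x-sorted points. A minimal PyTorch sketch of that one idea follows, assuming 1-D inputs; `LSTMEncoder` is a hypothetical name, and this is neither the repository's nor the paper's architecture. It sorts each task's (x, y) pairs by x before the LSTM, because a recurrent encoder is order-sensitive while an MLP-plus-mean-aggregation encoder is not.

```python
import torch
from torch import nn


class LSTMEncoder(nn.Module):
    """Hypothetical replacement for an ANP MLP encoder: runs an LSTM over
    (x, y) pairs ordered by x and returns one representation per point."""

    def __init__(self, x_dim: int, y_dim: int, hidden_dim: int):
        super().__init__()
        self.lstm = nn.LSTM(x_dim + y_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        # x: (batch, n, x_dim), y: (batch, n, y_dim); sorting uses the first x feature.
        order = x[..., 0].argsort(dim=1)  # (batch, n) indices that sort each task by x
        x_sorted = x.gather(1, order.unsqueeze(-1).expand(-1, -1, x.size(-1)))
        y_sorted = y.gather(1, order.unsqueeze(-1).expand(-1, -1, y.size(-1)))
        # Feed the concatenated, x-ordered pairs through the LSTM.
        reps, _ = self.lstm(torch.cat([x_sorted, y_sorted], dim=-1))
        return x_sorted, reps  # (batch, n, x_dim), (batch, n, hidden_dim)


torch.manual_seed(0)
x_context = torch.rand(2, 10, 1)  # hypothetical batch: 2 tasks, 10 context points each
y_context = torch.sin(3 * x_context)
encoder = LSTMEncoder(x_dim=1, y_dim=1, hidden_dim=16)
x_sorted, reps = encoder(x_context, y_context)
print(reps.shape)
```

Returning `x_sorted` alongside the representations keeps the x-order available to any downstream attention or decoder step, which would otherwise see points in a different order from the encoder.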

1d regression example

I was trying to run the 1d regression notebook straight out of the box, as provided, and ran into the error below: the model returns fewer values than the notebook tries to unpack. Any help would be appreciated. Thank you!

ValueError                                Traceback (most recent call last)

in <module>()
     32 optim.zero_grad()
     33 y_pred, kl, loss, mse_loss, y_std = model(context_x, context_y, target_x,
---> 34                                           target_y)
     35 loss.backward()
     36 optim.step()

ValueError: not enough values to unpack (expected 5, got 3)
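The traceback means the call to `model(...)` returned a 3-tuple while the notebook unpacks five names, which suggests the notebook and the model code it calls are out of sync. The minimal reproduction below uses `hypothetical_forward` as a stand-in for the model; which three values the repository's model actually returns is not stated in the issue, so check its `forward` before changing the unpacking.

```python
def hypothetical_forward(context_x, context_y, target_x, target_y):
    # Stand-in for a model forward() that returns 3 values rather than 5.
    y_pred, kl, loss = 0.0, 0.0, 0.0
    return y_pred, kl, loss


outputs = hypothetical_forward(None, None, None, None)

message = None
try:
    # Same unpacking as line 33 of the notebook cell.
    y_pred, kl, loss, mse_loss, y_std = outputs
except ValueError as err:
    message = str(err)
print(message)  # not enough values to unpack (expected 5, got 3)

# Checking the number of outputs first gives an error that names the API mismatch.
n_expected = 5
if len(outputs) != n_expected:
    print(f"model returned {len(outputs)} values; the notebook expects {n_expected}")
```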

No TensorBoard logs from smartmeters-ANP-RNN[-mcdropout].ipynb

Thanks for the awesome repo! I ran the notebooks smartmeters-ANP-RNN.ipynb and smartmeters-ANP-RNN-mcdropout.ipynb, which instruct you to run tensorboard --logdir ${MODEL_DIR}, but no records were found.

I tried replacing DictLogger with the vanilla TensorBoardLogger but this didn't change anything I could see.
There were no output .tfevents files in MODEL_DIR (only a model checkpoint).

Still, anp-rnn_1d_regression.ipynb logged to TensorBoard fine, since it uses SummaryWriter directly.

Although it is not a bug, I was also wondering why the training curves looked unstable :) The ANP-RNN paper reported fairly stable convergence.

Thanks a lot for your time! I'll report back here if I find anything new.
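The issue reports that writing through SummaryWriter directly produced TensorBoard logs. As a quick check of whether anything reaches a log directory at all, the sketch below logs a few placeholder scalars with `torch.utils.tensorboard.SummaryWriter` and then looks for `events.out.tfevents.*` files. It assumes PyTorch with the `tensorboard` package installed; the temporary directory is a stand-in for ${MODEL_DIR}, and the loss values are made up. It does not diagnose why the DictLogger/TensorBoardLogger setup in the smartmeters notebooks wrote nothing.

```python
import glob
import os
import tempfile

from torch.utils.tensorboard import SummaryWriter

# Stand-in for ${MODEL_DIR}; a temporary directory keeps the sketch self-contained.
model_dir = tempfile.mkdtemp()

writer = SummaryWriter(log_dir=model_dir)
for step in range(10):
    placeholder_loss = 1.0 / (step + 1)  # made-up value, not a real training loss
    writer.add_scalar("train/loss", placeholder_loss, global_step=step)
writer.flush()  # make sure buffered events hit the disk
writer.close()

# TensorBoard reads files whose names start with "events.out.tfevents.".
event_files = glob.glob(os.path.join(model_dir, "events.out.tfevents.*"))
print(event_files)
```

If this finds an event file but the notebook's MODEL_DIR stays empty, the logger configuration in the notebook, rather than TensorBoard itself, is the more likely place to look.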

Fixing the error in the kl_loss_var function

https://github.com/3springs/attentive-neural-processes/blob/016272a077a19bc51d145d1ad99d910477458876/neural_processes/utils.py#L167

There is an error in the KL-divergence computation in the kl_loss_var function. I think the fix is to remove the second ( before var_ratio_log.exp(). The updated function would look like:

def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    # KL(post || prior) for diagonal Gaussians given means and log-variances
    var_ratio_log = log_var_post - log_var_prior  # log(var_post / var_prior)
    kl_div = (
        var_ratio_log.exp()
        + ((post_mu - prior_mu) ** 2) / log_var_prior.exp()
        - 1.0
        - var_ratio_log
    )
    kl_div = 0.5 * kl_div
    return kl_div

Alternatively, torch.distributions.kl_divergence(z_post_dist, z_prior_dist) would do the job, where:

z_prior_dist = torch.distributions.normal.Normal(mu_c, sigma_c)  # mu_c, sigma_c: mean and standard deviation computed from the context points
z_post_dist = torch.distributions.normal.Normal(mu_t, sigma_t)  # mu_t, sigma_t: mean and standard deviation computed from the target points
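The two approaches in this issue should agree. The sketch below checks the corrected kl_loss_var against torch.distributions.kl_divergence on random statistics, assuming the log-variances are derived from the standard deviations as log_var = 2 * log(sigma); the shapes (batch of 4, latent size 8) are arbitrary choices for illustration.

```python
import torch
from torch.distributions import Normal, kl_divergence


def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    # Corrected version from this issue: KL(post || prior) from means and log-variances.
    var_ratio_log = log_var_post - log_var_prior
    kl_div = (
        var_ratio_log.exp()
        + ((post_mu - prior_mu) ** 2) / log_var_prior.exp()
        - 1.0
        - var_ratio_log
    )
    return 0.5 * kl_div


torch.manual_seed(0)
# Random latent statistics; float64 keeps the comparison tight.
mu_c = torch.randn(4, 8, dtype=torch.float64)
mu_t = torch.randn(4, 8, dtype=torch.float64)
sigma_c = torch.rand(4, 8, dtype=torch.float64) + 0.1  # context ("prior") std, kept positive
sigma_t = torch.rand(4, 8, dtype=torch.float64) + 0.1  # target ("posterior") std

manual = kl_loss_var(mu_c, 2 * sigma_c.log(), mu_t, 2 * sigma_t.log())
reference = kl_divergence(Normal(mu_t, sigma_t), Normal(mu_c, sigma_c))
print(torch.allclose(manual, reference))
```

Note the argument order: kl_divergence(p, q) computes KL(p || q), so the posterior (target) distribution goes first, matching the call quoted above.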

PS: Thank you for this open-source implementation! 👍
