wassname / attentive-neural-processes
implementing "recurrent attentive neural processes" to forecast power usage (w. LSTM baseline, MCDropout)
License: Apache License 2.0
Hi there,
I am wondering whether you forgot to set 'use_deterministic_path' = True for ANP-RNN, since the paper's authors indicate that it should be set, and you have obviously already set up cross-validation.
BTW, I was trying to replicate the experiment on the GP dataset, and I have already implemented an ANP. Although my architecture is slightly different from yours, I assume I just need to replace all the MLPs with LSTMs, plus sequential encodings of the inputs and outputs, right? However, my results were really bad: when I sorted x_context and x_target, the model only seemed able to predict a few points and predicted flat curves afterwards. Could you please share any hints on this?
Your help is very much appreciated.
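On the sorting step mentioned above, one thing worth checking is that x and y are permuted together: a recurrent encoder reads points in sequence order, so sorting x_context without applying the same permutation to y_context scrambles the (x, y) pairing. Below is a minimal pure-Python sketch, assuming list inputs; `sort_by_x` is a hypothetical helper, not part of the repo (with tensors the same idea is usually written with torch.argsort and indexing).

```python
def sort_by_x(xs, ys):
    """Hypothetical helper: reorder (xs, ys) so that xs is ascending.

    The same permutation is applied to both lists, so every x keeps its y.
    """
    if len(xs) != len(ys):
        raise ValueError("xs and ys must have the same length")
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    return [xs[i] for i in order], [ys[i] for i in order]


# Context and target sets are sorted independently, each keeping its pairing.
x_context, y_context = sort_by_x([0.7, -1.2, 0.1], [2.0, -3.0, 0.5])
x_target, y_target = sort_by_x([0.4, -0.5], [1.1, -0.9])
print(x_context, y_context)  # [-1.2, 0.1, 0.7] [-3.0, 0.5, 2.0]
```

If the pairing is intact and the predictions still go flat after a few points, the problem more likely lies in the model than in the data ordering.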
Hi,
Thank you for such a great and clean implementation.
It seems that anp-rnn_1d_regression.ipynb is missing from the repo, although it is mentioned in the Usage section of your Readme.
https://github.com/3springs/attentive-neural-processes/blob/master/anp-rnn_1d_regression.ipynb
Regards & thanks
Kapil
I was trying to run the 1d regression notebook straight out of the box, as provided, and ran into a mismatch between the number of values the model returns and the number the training loop unpacks. Any help would be appreciated. Thank you.
ValueError Traceback (most recent call last)
in ()
32 optim.zero_grad()
33 y_pred, kl, loss, mse_loss, y_std = model(context_x, context_y, target_x,
---> 34 target_y)
35 loss.backward()
36 optim.step()
ValueError: not enough values to unpack (expected 5, got 3)
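The traceback shows the model call returning 3 values while the training loop unpacks 5 (y_pred, kl, loss, mse_loss, y_std), which suggests the notebook and the model code have drifted apart. The sketch below makes that failure self-explanatory; `check_outputs` and `fake_model` are hypothetical names used only for illustration, and which 3 values the current model actually returns cannot be read from the traceback.

```python
EXPECTED = ("y_pred", "kl", "loss", "mse_loss", "y_std")


def check_outputs(outputs, expected_names=EXPECTED):
    """Hypothetical helper: name each model output, failing loudly on a count mismatch."""
    outputs = tuple(outputs)
    if len(outputs) != len(expected_names):
        raise ValueError(
            f"model returned {len(outputs)} values, but the training loop "
            f"expects {len(expected_names)}: {', '.join(expected_names)}"
        )
    return dict(zip(expected_names, outputs))


def fake_model(*args):
    # Hypothetical stand-in that returns 3 values, mirroring the traceback.
    return "a", "b", "c"


try:
    check_outputs(fake_model())
except ValueError as err:
    # prints: model returned 3 values, but the training loop expects 5: y_pred, kl, loss, mse_loss, y_std
    print(err)
```

Printing len(outputs) and the type of each element once, before the unpacking line, shows what the model now returns so the loop can be updated to match.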
Thanks for the awesome repo! I ran the notebooks smartmeters-ANP-RNN.ipynb and smartmeters-ANP-RNN-mcdropout.ipynb, which instruct you to run tensorboard --logdir ${MODEL_DIR}, but no records were found. I tried replacing DictLogger with the vanilla TensorBoardLogger, but this didn't change anything I could see. There were no output .tfevents files in MODEL_DIR (only a model checkpoint). anp-rnn_1d_regression.ipynb logged to TensorBoard fine, using SummaryWriter directly.
Although not a bug, I was also wondering why the training looked unstable in these plots :) The ANP-RNN paper reported fairly stable convergence.
Thanks a lot for your time! I'll report back here if I find anything new.
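To confirm whether any logger wrote event data at all, it can help to search the log directory recursively, since TensorBoard event files are named events.out.tfevents.* and loggers often place them in a nested, versioned subdirectory rather than directly in MODEL_DIR. This is a minimal standard-library sketch; `find_event_files` is a hypothetical helper, and the directory layout in the usage lines is made up for the example.

```python
from pathlib import Path
import tempfile


def find_event_files(model_dir):
    """Hypothetical helper: list TensorBoard event files anywhere under model_dir."""
    return sorted(str(p) for p in Path(model_dir).rglob("events.out.tfevents*"))


# Usage with a throwaway directory that mimics a checkpoint plus nested logs.
with tempfile.TemporaryDirectory() as model_dir:
    logs = Path(model_dir) / "version_0"
    logs.mkdir()
    (logs / "events.out.tfevents.1600000000.host").touch()
    (Path(model_dir) / "model.ckpt").touch()
    found = find_event_files(model_dir)
    print([Path(p).name for p in found])  # ['events.out.tfevents.1600000000.host']
```

An empty list for the real MODEL_DIR suggests the logger is writing somewhere else, or not at all.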
There is an issue with the computation of the KL divergence in the kl_loss_var function. I think the fix is to remove the second ( before var_ratio_log.exp(). The updated function would look like:
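def kl_loss_var(prior_mu, log_var_prior, post_mu, log_var_post):
    # KL(post || prior) for diagonal Gaussians given as mean and log-variance tensors.
    # var_ratio_log = log(var_post / var_prior)
    var_ratio_log = log_var_post - log_var_prior
    kl_div = (
        var_ratio_log.exp() + ((post_mu - prior_mu) ** 2) / log_var_prior.exp()
        - 1.0
        - var_ratio_log
    )
    kl_div = 0.5 * kl_div
    return kl_div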
Otherwise, using torch.distributions.kl_divergence(z_post_dist, z_prior_dist), where
z_prior_dist = torch.distributions.normal.Normal(mu_c, sigma_c) # mu_c, sigma_c are the computed mean and standard deviation using contexts
z_post_dist = torch.distributions.normal.Normal(mu_t, sigma_t) # mu_t, sigma_t are the computed mean and standard deviation using targets
would do the job.
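As a quick sanity check of the corrected formula, independent of torch, the closed form can be compared with a Monte Carlo estimate of E_{z~post}[log post(z) - log prior(z)], which is the definition of KL(post || prior) that kl_divergence(z_post_dist, z_prior_dist) evaluates analytically for two Normals. This is a minimal pure-Python sketch for scalar Gaussians; the function names are hypothetical and the argument order follows kl_loss_var above.

```python
import math
import random


def kl_closed_form(prior_mu, log_var_prior, post_mu, log_var_post):
    """Scalar version of the corrected kl_loss_var: KL(post || prior)."""
    var_ratio_log = log_var_post - log_var_prior
    return 0.5 * (
        math.exp(var_ratio_log)
        + (post_mu - prior_mu) ** 2 / math.exp(log_var_prior)
        - 1.0
        - var_ratio_log
    )


def log_normal_pdf(z, mu, log_var):
    """Log density of N(mu, exp(log_var)) at z."""
    return -0.5 * (math.log(2 * math.pi) + log_var + (z - mu) ** 2 / math.exp(log_var))


def kl_monte_carlo(prior_mu, log_var_prior, post_mu, log_var_post, n=100_000, seed=0):
    """Estimate E_{z ~ post}[log post(z) - log prior(z)] by sampling from post."""
    rng = random.Random(seed)
    sigma_post = math.exp(0.5 * log_var_post)
    total = 0.0
    for _ in range(n):
        z = rng.gauss(post_mu, sigma_post)
        total += log_normal_pdf(z, post_mu, log_var_post) - log_normal_pdf(
            z, prior_mu, log_var_prior
        )
    return total / n


# prior N(0, 1), posterior N(1, 0.5)
closed = kl_closed_form(0.0, 0.0, 1.0, math.log(0.5))  # about 0.5966
estimate = kl_monte_carlo(0.0, 0.0, 1.0, math.log(0.5))
print(f"closed form {closed:.4f}, Monte Carlo {estimate:.4f}")
```

The two numbers should agree to within sampling noise (a couple of thousandths at this sample size); a formula with a misplaced parenthesis would typically not.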
PS. Thank you for this open-source implementation!