Comments (4)
Hi there,
Thanks for reaching out! To be clear, I haven't had time to join the two scripts together yet. I'll get back to you ASAP with an updated answer, but for now:
init_alpha: -785.866918162 is an error (alpha must be positive).
Note that for large magnitudes of alpha, the mean of the TTE is approximately the same as the more complex estimate using logs, etc.
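For reference, the mean-of-TTE initialization can be sketched like this (toy numbers; `tte` is an assumed array of observed times-to-event, not from the repo):

```python
import numpy as np

# Hedged sketch: a naive positive init_alpha from the mean observed
# time-to-event. The Weibull mean is alpha * Gamma(1 + 1/beta), which
# equals alpha exactly when beta = 1, so the mean TTE is a reasonable
# large-alpha initialization.
tte = np.array([3.0, 7.0, 2.0, 10.0, 5.0])
init_alpha = tte.mean()
print(init_alpha)  # 5.4
assert init_alpha > 0, "a negative init_alpha signals a data/pipeline bug"
```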
Furthermore:
- Initialization is important. Gradients explode if the initial guess is too far off, and more censored data means a higher probability of exploding gradients early on.
- The learning rate depends on the data and can land in magnitudes you wouldn't expect.
- Are you feeding in masked steps? Varying-length sequences have no clean implementation at the moment; I haven't had time to get the masking layer to work. The current workaround is to set n_timesteps = None and run one training step per input sequence, with something like:
OBS: NOT TESTED:

def epoch():
    for i in range(n_samples):
        # slice with i:i+1 to keep the batch dimension, so each
        # input stays 3-D as Keras expects
        model.fit(x_train[i:i+1, :seq_length[i], :],
                  y_train[i:i+1, :seq_length[i], :],
                  epochs=1,
                  batch_size=1,
                  verbose=2)
But an even better initial debug mode is to simply transform the data to [n_non_masked_samples, 1, n_features] (i.e. feed in only the observed timesteps), train a simple ANN on that, and test the RNN only once that works.
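That debug transform can be sketched as follows (all names and shapes here are illustrative, not from the library):

```python
import numpy as np

# Collect only the observed (non-masked) timesteps of each sequence into
# shape [n_non_masked_samples, 1, n_features], so a plain feed-forward
# net can be sanity-checked before moving to the RNN.
n_samples, max_len, n_features = 4, 5, 3
rng = np.random.default_rng(0)
x = rng.normal(size=(n_samples, max_len, n_features))
seq_length = np.array([5, 3, 4, 2])  # observed steps per sequence

x_flat = np.concatenate([x[i, :seq_length[i], :] for i in range(n_samples)])
x_flat = x_flat[:, None, :]  # add a singleton time axis
print(x_flat.shape)  # (14, 1, 3)
```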
Would love to see forks!
from wtte-rnn.
There are multiple reasons for NaNs to show up, but I just found a very important one:
shift_discrete_padded_features
is currently broken: it is supposed to hide the target but apparently doesn't. This means that if the input contains the "event" indicator, the network can make a perfect prediction, causing exploding gradients.
I'm trying to fix it ASAP.
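A minimal illustration of that leakage (illustrative NumPy, not the library's implementation):

```python
import numpy as np

# If the event indicator for step t is visible in the input at step t,
# the network can predict the event perfectly; the log-likelihood then
# diverges and gradients explode. Shifting the feature hides the current
# event so only past events are visible.
events = np.array([0, 0, 1, 0, 1], dtype=float)

leaky = events               # event at t visible at t -> target leakage
shifted = np.roll(events, 1)
shifted[0] = 0.0             # step 0 sees no history
print(shifted)  # [0. 0. 0. 1. 0.]
```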
Hi Egil,
Thanks for the update! Here's a fork with the notebook Combined_data_pipeline_and_analysis in examples/keras.
https://github.com/NataliaVConnolly/wtte-rnn-1
The last cell shows an example of training with just one input sequence. It does result in a non-NaN loss, although a very large one (but I didn't optimize the initial alpha or the network config much).
Cheers,
Natalia (aka hedgy123 :))
@NataliaVConnolly Sorry for the wait. It took me some time to figure out what was wrong!
- Too much censoring leads to instability. It works with more frequent committers (<50% censoring); in the example I use only those who committed on at least 10 days.
- You train on one subject but initialize alpha using the mean over all subjects. This makes exploding gradients very likely.
- As mentioned above, if this was run before the fix to
shift_discrete_padded_features
it would also lead to NaN (a perfect fit) after some training.
Check out the new data_pipeline and let me know if you have more questions! :)
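The alpha-initialization point can be made concrete with toy numbers (nothing below comes from the repo):

```python
import numpy as np

# Initializing alpha from the global mean TTE while training on a single
# subject can be far from that subject's time scale, which raises the
# chance of exploding gradients early in training.
tte_per_subject = {
    "frequent": np.array([1.0, 2.0, 1.0]),   # short times between events
    "rare": np.array([50.0, 80.0, 60.0]),    # long times between events
}

global_init = np.mean(np.concatenate(list(tte_per_subject.values())))
subject_init = tte_per_subject["frequent"].mean()
print(global_init, subject_init)  # ~32.3 vs ~1.3
```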