Comments (4)
The library was made with regression in mind; this is the first time I've seen someone try to use it for binary classification. However, I was able to reproduce your error with the following code:
from datetime import date, timedelta
import numpy as np
import pandas as pd
from nowcast_lstm.LSTM import LSTM
import torch.nn as nn
n_rows = 20
# making example data
data = pd.DataFrame(np.random.rand(n_rows, 4))
data.columns = ["y", "x1", "x2", "x3"]
data.y = [int(np.round(x)) for x in data.y]
data["date"] = pd.to_datetime([date(2023, 1, 1) + timedelta(days=x) for x in range((timedelta(days = n_rows)).days)])
data = data.loc[:, ["date"] + list(data.columns[:-1])]
# LSTM model
result = None
n_tries = 500
counter = 0
while result is None and counter < n_tries:
    try:
        counter += 1
        loss = nn.BCELoss()
        model = LSTM(
            data = data,
            target_variable = "y",
            n_timesteps = 5,
            n_models = 3,
            train_episodes = 20,
            batch_size = 5,
            criterion = loss
        )
        model.train()
        result = "success"
    except Exception:
        # training failed on this attempt (prediction outside [0, 1]); retry
        pass
The model does train successfully every once in a while, so if you wrap the model training in a try/except retry loop like the one above, you will eventually get a correctly trained model. That is the quick-and-dirty workaround for your problem, though of course it isn't workable for tuning, variable selection, etc.
The error is due to the fact that the model is currently set up to give regression outputs, so it only trains properly when it gets lucky and happens to produce predictions between 0 and 1, which is what the BCELoss() function requires. To get the LSTM to natively output predictions between 0 and 1, which would allow proper training, hyperparameter tuning, etc., I have to edit the actual structure of the network in this file.
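To make the range requirement concrete, here is a plain-Python sketch (standard library only, not PyTorch's implementation) of binary cross-entropy for a single prediction. The function names and the example value 1.3 are made up for illustration. The loss takes log(p) and log(1 - p), so a raw regression output above 1 or below 0 has no valid loss, while the same value passed through a sigmoid does.

```python
import math


def sigmoid(z: float) -> float:
    """Squash a real number into the range 0 to 1 (numerically stable form)."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def binary_cross_entropy(p: float, y: int) -> float:
    """Binary cross-entropy for one prediction p and one 0/1 label y."""
    return -(y * math.log(p) + (1 - y) * math.log(1.0 - p))


raw_output = 1.3  # illustrative regression-style output that overshoots 1

try:
    binary_cross_entropy(raw_output, 1)
    raw_ok = True
except ValueError:
    # math.log(1 - 1.3) is undefined, so the loss cannot be computed
    raw_ok = False

squashed = sigmoid(raw_output)             # now strictly between 0 and 1
loss = binary_cross_entropy(squashed, 1)   # well defined
print(raw_ok, round(squashed, 4), round(loss, 4))
```

This is why the retry loop only succeeds on runs where every prediction happens to land inside the valid range.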
I'm looking into that now and will get back to you in a bit.
OK, I figured it out. It will take me a little while to implement in the library, though. As I said, the issue was that the model was producing regression outputs, which could frequently be <0 or >1, and that is what broke the BCELoss() function. To fix that, a sigmoid layer needs to be added to the LSTM to restrict outputs to between 0 and 1.
You can already use the fix by cloning a local version of the library and making the following addition to this file:
# model layers
x, self.hidden = self.l_lstm(x, self.hidden)
x = x.contiguous().view(batch_size, -1) # make tensor of right dimensions
x = self.l_linear(x)
x = torch.sigmoid(x) # !!!this is the new line to add
# model layers
You're just adding this final x = torch.sigmoid(x) layer to the model. If you do this to your local copy and run your code against the local library rather than the installed one (I'm not sure of your Python level, but let me know if this is not clear), the model will train correctly.
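In case the "local library rather than the installed one" part is unclear, one possible way to do it is sketched below using only the standard library. The helper name use_local_package and the path in the usage comment are hypothetical, and the sketch assumes the cloned repository's top-level directory contains the nowcast_lstm package folder.

```python
import importlib
import sys
from pathlib import Path


def use_local_package(repo_dir: str, package: str):
    """Import `package` from `repo_dir` in preference to any installed copy."""
    repo = str(Path(repo_dir).resolve())
    # Put the clone at the front of the import search path
    if repo in sys.path:
        sys.path.remove(repo)
    sys.path.insert(0, repo)
    # Forget any copy that was already imported so the local one gets loaded
    for name in list(sys.modules):
        if name == package or name.startswith(package + "."):
            del sys.modules[name]
    importlib.invalidate_caches()
    return importlib.import_module(package)


# Usage (hypothetical path; point it at wherever you cloned the repository):
# nowcast_lstm = use_local_package("/path/to/local/clone", "nowcast_lstm")
# print(nowcast_lstm.__file__)  # should point inside the clone
# from nowcast_lstm.LSTM import LSTM
```

Checking the module's __file__ is a quick way to confirm that the edited copy, not the pip-installed one, is the one being used.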
Alternatively, you can wait a day or two and I will add this functionality to the published library, so that the sigmoid layer is applied automatically if the BCELoss() function is passed as the loss function.
I've added the functionality in v0.2.6. You can update your library and you should then be able to train your model correctly. The library now automatically detects when BCELoss() is passed as the criterion and adds the sigmoid layer accordingly.
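For readers curious what this kind of detection can look like, here is a minimal PyTorch sketch. It is not the library's actual code: the class TinyLSTMHead, its arguments, and the use_sigmoid flag are all hypothetical, and the isinstance check is only one plausible way to detect the criterion.

```python
import torch
import torch.nn as nn


class TinyLSTMHead(nn.Module):
    """Hypothetical stand-in network, not nowcast_lstm's actual class."""

    def __init__(self, n_features: int, n_timesteps: int, n_hidden: int,
                 criterion: nn.Module):
        super().__init__()
        self.lstm = nn.LSTM(n_features, n_hidden, batch_first=True)
        self.linear = nn.Linear(n_hidden * n_timesteps, 1)
        # Assumption: the criterion's type decides whether outputs are squashed
        self.use_sigmoid = isinstance(criterion, nn.BCELoss)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.size(0)
        out, _ = self.lstm(x)                        # (batch, timesteps, hidden)
        out = out.contiguous().view(batch_size, -1)  # flatten the timesteps
        out = self.linear(out)                       # unbounded regression output
        if self.use_sigmoid:
            out = torch.sigmoid(out)                 # restrict to between 0 and 1
        return out


torch.manual_seed(0)
x = torch.randn(8, 5, 3)                   # 8 samples, 5 timesteps, 3 features
y = torch.randint(0, 2, (8, 1)).float()    # random 0/1 labels

criterion = nn.BCELoss()
model = TinyLSTMHead(n_features=3, n_timesteps=5, n_hidden=4, criterion=criterion)
preds = model(x)
loss = criterion(preds, y)                 # valid because preds lie in [0, 1]
```

With any other criterion, for example nn.MSELoss(), the flag stays off and the network keeps producing ordinary regression outputs.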
I tested it and it worked perfectly. Thank you very much