
Comments (4)

dhopp1 commented on May 30, 2024

The library was made with regression in mind; this is the first time I've seen someone try to use it for binary classification. However, I was able to reproduce your error with the following code:

from datetime import date, timedelta
import numpy as np
import pandas as pd
from nowcast_lstm.LSTM import LSTM
import torch.nn as nn

n_rows = 20

# making example data
data = pd.DataFrame(np.random.rand(n_rows, 4))
data.columns = ["y", "x1", "x2", "x3"]
data.y = [int(np.round(x)) for x in data.y]
data["date"] = pd.to_datetime([date(2023, 1, 1) + timedelta(days=x) for x in range(n_rows)])
data = data.loc[:, ["date"] + list(data.columns[:-1])]

# LSTM model
result = None
n_tries = 500
counter = 0
while result is None and counter < n_tries:
    counter += 1
    try:
        # BCELoss requires every prediction to lie in [0, 1]
        loss = nn.BCELoss()
        model = LSTM(
            data = data,
            target_variable = "y",
            n_timesteps = 5,
            n_models = 3,
            train_episodes = 20,
            batch_size = 5,
            criterion = loss
        )
        model.train()
        result = "success"
    except Exception:
        # training raised (e.g. a prediction fell outside [0, 1]); try again
        pass

The model is actually able to train every once in a while, so if you wrap the model training in a try/except inside a loop like I did above, you will eventually get a correctly trained model. This is the quick and dirty potential solution to your problem, though it of course isn't workable for tuning, variable selection, etc.

The error is due to the fact that the model is currently set up to give regression outputs, so it only trains properly when it gets lucky and happens to produce predictions between 0 and 1, which is what the BCELoss() function requires. In order to get the LSTM to natively output predictions between 0 and 1, which would allow proper training, hyperparameter tuning, etc., I have to edit the actual structure of the network in this file.
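The failure mode described above can be reproduced outside the library. This is a minimal sketch on made-up tensors (the values in `raw` and `target` are illustrative assumptions, not output from nowcast_lstm), run on CPU: BCELoss() rejects unbounded regression-style outputs, and accepts the same values once they have been passed through a sigmoid.

```python
import torch
import torch.nn as nn

# Hypothetical regression-style outputs: two of them fall outside [0, 1]
raw = torch.tensor([[-0.3], [0.4], [1.2]])
target = torch.tensor([[0.0], [1.0], [1.0]])
criterion = nn.BCELoss()

# On CPU, BCELoss raises a RuntimeError when any input is outside [0, 1]
try:
    criterion(raw, target)
    raw_accepted = True
except RuntimeError:
    raw_accepted = False

# A sigmoid maps every real number into (0, 1), so the loss is defined
squashed = torch.sigmoid(raw)
loss = criterion(squashed, target)

print("raw outputs accepted:", raw_accepted)
print("loss after sigmoid:", loss.item())
```

This also explains why the retry loop occasionally succeeds: a run only trains when every prediction happens to stay inside the valid range.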

I'm looking into that now and will get back to you in a bit.

from nowcast_lstm.

dhopp1 commented on May 30, 2024

OK, I figured it out, though it will take me a little while to implement in the library. As I said, the issue was that the model was producing regression outputs, which could frequently be <0 or >1, and that is what broke the BCELoss() function. To rectify that, a sigmoid layer needs to be added to the LSTM to restrict outputs to between 0 and 1.

You can already use the fix by cloning a local version of the library and making the following addition to this file:

# model layers
x, self.hidden = self.l_lstm(x, self.hidden)
x = x.contiguous().view(batch_size, -1)  # make tensor of right dimensions
x = self.l_linear(x)
x = torch.sigmoid(x) # !!!this is the new line to add
# model layers

You're just adding this final x = torch.sigmoid(x) layer to the model. If you do this to your local copy and run your code against the local library rather than the installed one (I'm not sure of your Python level, but let me know if this is not clear), the model will train correctly.
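To show the effect of that extra line in isolation, here is a self-contained sketch of a small network with the same layer order. The class name `TinyLSTM`, its constructor arguments, and the random data are all hypothetical; this is not the library's actual network, only an illustration of where the sigmoid sits.

```python
import torch
import torch.nn as nn

class TinyLSTM(nn.Module):
    """Hypothetical stand-in for the network, not the nowcast_lstm class."""

    def __init__(self, n_features, n_timesteps, n_hidden=8):
        super().__init__()
        self.l_lstm = nn.LSTM(input_size=n_features, hidden_size=n_hidden, batch_first=True)
        self.l_linear = nn.Linear(n_hidden * n_timesteps, 1)

    def forward(self, x):
        batch_size = x.size(0)
        x, _ = self.l_lstm(x)                    # (batch, timesteps, hidden)
        x = x.contiguous().view(batch_size, -1)  # flatten the timesteps
        x = self.l_linear(x)                     # unbounded regression output
        x = torch.sigmoid(x)                     # squash into (0, 1) for BCELoss
        return x

torch.manual_seed(0)
model = TinyLSTM(n_features=3, n_timesteps=5)
x = torch.randn(4, 5, 3)                         # 4 samples, 5 timesteps, 3 features
y = torch.tensor([[0.0], [1.0], [1.0], [0.0]])   # made-up binary targets

preds = model(x)
loss = nn.BCELoss()(preds, y)
loss.backward()                                  # gradients flow through the sigmoid
```

With the sigmoid in place, every prediction is a valid probability, so BCELoss() can be evaluated on every batch rather than only on lucky ones.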

Alternatively, you can wait a day or two and I will add this functionality to the published library, applied automatically if the BCELoss() function is passed as the loss function.


dhopp1 commented on May 30, 2024

I've added the functionality to v0.2.6. Once you update your library, you should be able to train your model correctly. The library now automatically detects when BCELoss() is passed as the criterion and adds the sigmoid layer accordingly.
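How the released version performs this detection is not shown in this thread. One plausible way to do it, sketched here purely as an assumption, is an `isinstance` check on the criterion; the helper name `needs_sigmoid` is hypothetical and is not part of nowcast_lstm.

```python
import torch.nn as nn

def needs_sigmoid(criterion):
    """Hypothetical helper: True when the loss expects probabilities in [0, 1]."""
    return isinstance(criterion, nn.BCELoss)

print(needs_sigmoid(nn.BCELoss()))            # binary classification loss
print(needs_sigmoid(nn.MSELoss()))            # regression loss
print(needs_sigmoid(nn.BCEWithLogitsLoss()))  # applies its own sigmoid internally
```

Note that nn.BCEWithLogitsLoss is a separate class that applies the sigmoid itself, so a check like this would correctly leave it alone.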


fdfrontdev commented on May 30, 2024

I tested it and it worked perfectly. Thank you very much.

