

YannCabanes commented on June 2, 2024

I succeeded in combining SoftDTWLossPyTorch from tslearn with NeuralProphet from neuralprophet by defining:

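from neuralprophet import NeuralProphet
from tslearn.metrics import SoftDTWLossPyTorch
from tslearn.metrics.soft_dtw_loss_pytorch import _SoftDTWLossPyTorch

def soft_dtw_loss_function(x, y, dist_func=SoftDTWLossPyTorch._euclidean_squared_dist, gamma=0.1):
    # Matrix of squared Euclidean distances between the time steps of x and y
    d_xy = dist_func(x, y)
    # Soft-DTW recursion over that matrix (an autograd Function with its own backward pass)
    return _SoftDTWLossPyTorch.apply(d_xy, gamma)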

and then passing it to the model:

m = NeuralProphet(loss_func=soft_dtw_loss_function)
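The wrapper works in two steps. dist_func builds the matrix of squared Euclidean distances between every time step of x and every time step of y. _SoftDTWLossPyTorch.apply then runs the soft-DTW dynamic program over that matrix. That program is classic DTW with the hard minimum replaced by a smooth soft-min, whose sharpness is set by gamma.

The plain-Python sketch below shows that forward recursion for a single pair of sequences, each given as a nested list of shape length × dim. It is an illustration only, not tslearn's batched, autograd-aware implementation, and the helper names are hypothetical.

```python
import math


def squared_euclidean_matrix(x, y):
    """Pairwise squared Euclidean distances between the time steps of x and y."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def softmin(values, gamma):
    """Smooth minimum -gamma * log(sum(exp(-v / gamma))), computed stably."""
    m = min(values)
    if math.isinf(m):
        return m
    total = sum(math.exp(-(v - m) / gamma) for v in values)
    return m - gamma * math.log(total)


def soft_dtw(x, y, gamma=0.1):
    """Soft-DTW value between two sequences (forward pass only, no gradients)."""
    d = squared_euclidean_matrix(x, y)
    n, m = len(x), len(y)
    # r[i][j]: soft-minimal cumulative cost of aligning x[:i] with y[:j]
    r = [[math.inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            r[i][j] = d[i - 1][j - 1] + softmin(
                (r[i - 1][j - 1], r[i - 1][j], r[i][j - 1]), gamma
            )
    return r[n][m]


def dtw(x, y):
    """Classic DTW with the same squared Euclidean cost, using a hard minimum."""
    d = squared_euclidean_matrix(x, y)
    n, m = len(x), len(y)
    r = [[math.inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            r[i][j] = d[i - 1][j - 1] + min(r[i - 1][j - 1], r[i - 1][j], r[i][j - 1])
    return r[n][m]


x = [[0.0], [1.0], [2.0]]
y = [[0.0], [1.0], [1.0], [2.0]]
print(dtw(x, y))            # 0.0: y is x with one time step repeated
print(soft_dtw(x, y, 0.1))  # a small negative number
```

The soft-min is never larger than the hard minimum, so the soft-DTW value is at most the DTW value. It can also be negative, as in this example, and smaller gamma values bring it closer to DTW. tslearn's SoftDTWLossPyTorch class also has a normalize option that uses the soft-DTW divergence instead. The wrapper above does not apply it.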


YannCabanes commented on June 2, 2024

Hello @weidongzhou1994,

NeuralProphet (https://github.com/ourownstory/neural_prophet/blob/main/neuralprophet/forecaster.py) supports PyTorch loss functions: the NeuralProphet class has an optional parameter loss_func, documented as follows:

        loss_func : str, torch.nn.functional.loss
            Type of loss to use:

            Options
                * (default) ``Huber``: Huber loss function
                * ``MSE``: Mean Squared Error loss function
                * ``MAE``: Mean Absolute Error loss function
                * ``torch.nn.functional.loss.``: loss or callable for custom loss, eg. L1-Loss

            Examples
            --------
            >>> from neuralprophet import NeuralProphet
            >>> import torch
            >>> import torch.nn as nn
            >>> m = NeuralProphet(loss_func=torch.nn.L1Loss)
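The plain-Python sketch below shows what the three built-in string options compute on a batch of residuals, averaged over elements. It follows the standard definitions used by torch.nn.MSELoss, torch.nn.L1Loss and torch.nn.HuberLoss. The delta=1.0 threshold and the mean reduction are assumptions of this sketch, not a description of how NeuralProphet configures or reduces its loss.

```python
def mse(pred, target):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def mae(pred, target):
    """Mean absolute error (the L1 loss with mean reduction)."""
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


def huber(pred, target, delta=1.0):
    """Huber loss: quadratic for residuals up to delta, linear beyond it."""
    total = 0.0
    for p, t in zip(pred, target):
        r = abs(p - t)
        total += 0.5 * r * r if r <= delta else delta * (r - 0.5 * delta)
    return total / len(pred)


pred = [1.0, 2.0, 3.0]
target = [1.0, 4.0, 3.5]
print(mse(pred, target), mae(pred, target), huber(pred, target))
```

Huber behaves like MSE for small residuals and like MAE for large ones, which makes it less sensitive to outliers. A custom loss_func fills the same slot: a callable that takes predictions and targets and returns a loss.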

However, when I pass the tslearn class directly, as in this code:

from neuralprophet import NeuralProphet
from tslearn.metrics import SoftDTWLossPyTorch
m = NeuralProphet(loss_func=SoftDTWLossPyTorch)

I obtain the following error message:

Traceback (most recent call last):
  File "/home/ycabanes/work/tslearn/codes/try_neuralprophet_with_softdtwlosspytorch.py", line 15, in <module>
    m = NeuralProphet(loss_func=SoftDTWLossPyTorch)
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/forecaster.py", line 398, in __init__
    self.config_train = configure.Train(
  File "<string>", line 18, in __init__
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/configure.py", line 112, in __post_init__
    self.set_loss_func()
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/configure.py", line 134, in set_loss_func
    raise NotImplementedError(f"Loss function {self.loss_func} not found")
NotImplementedError: Loss function <class 'tslearn.metrics.soft_dtw_loss_pytorch.SoftDTWLossPyTorch'> not found


YannCabanes commented on June 2, 2024

Here is the full code of a notebook that runs on Google Colab, adapted from a NeuralProphet tutorial notebook (https://github.com/ourownstory/neural_prophet/blob/main/docs/source/tutorials/tutorial10.ipynb):

Install the modules

try:
    import neuralprophet
except ImportError:
    !pip install neuralprophet[live]

try:
    import tslearn
except ImportError:
    !pip install tslearn
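The !pip install lines are IPython/Colab shell syntax and will not parse in a plain .py script. Below is a minimal sketch of the same "install if missing" check for a script. It assumes pip is available for the running interpreter, and the helper names missing_modules and ensure_installed are hypothetical.

```python
import importlib.util
import subprocess
import sys


def missing_modules(names):
    """Return the module names that cannot be found by the import system."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def ensure_installed(packages):
    """Install each missing package with the pip of the running interpreter.

    packages maps an importable module name to its pip requirement string.
    """
    for module in missing_modules(packages):
        subprocess.check_call([sys.executable, "-m", "pip", "install", packages[module]])
```

For this notebook, the call would be ensure_installed({"neuralprophet": "neuralprophet[live]", "tslearn": "tslearn"}). Using sys.executable ensures the packages land in the same environment that runs the script.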

Import the modules

import pandas as pd
import torch
from neuralprophet import NeuralProphet, set_log_level
from tslearn.metrics import SoftDTWLossPyTorch
from tslearn.metrics.soft_dtw_loss_pytorch import _SoftDTWLossPyTorch

Define a SoftDTW loss function using tslearn

def soft_dtw_loss_function(x, y, dist_func=SoftDTWLossPyTorch._euclidean_squared_dist, gamma=0.1):
    d_xy = dist_func(x, y)
    return _SoftDTWLossPyTorch.apply(d_xy, gamma)

# Load the dataset from the CSV file using pandas
df = pd.read_csv("https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv")

# Disable logging messages unless there is an error
set_log_level("ERROR")

# Model and prediction
m = NeuralProphet(loss_func=soft_dtw_loss_function)
m.set_plotting_backend("plotly")

df_train, df_val = m.split_df(df, valid_p=0.2)

print("Dataset size:", len(df))
print("Train dataset size:", len(df_train))
print("Validation dataset size:", len(df_val))

metrics = m.fit(df_train, validation_df=df_val, progress=None)
metrics

forecast = m.predict(df)
m.plot(forecast)
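m.split_df keeps the split in time order: earlier rows go to training and later rows to validation, rather than sampling rows at random. Below is a simplified pandas sketch of such a chronological split, assuming a DataFrame with a ds timestamp column. chronological_split is a hypothetical helper. NeuralProphet's own split_df also takes other settings into account (for example n_lags) and may fill in missing data, so its sizes can differ.

```python
import pandas as pd


def chronological_split(df, valid_p=0.2):
    """Split a time-stamped frame into earlier (train) and later (validation) rows."""
    if not 0.0 < valid_p < 1.0:
        raise ValueError("valid_p must be between 0 and 1")
    ordered = df.sort_values("ds").reset_index(drop=True)
    n_valid = max(1, int(round(len(ordered) * valid_p)))
    return ordered.iloc[:-n_valid], ordered.iloc[-n_valid:]


df = pd.DataFrame({
    "ds": pd.date_range("2024-01-01", periods=10, freq="D"),
    "y": range(10),
})
df_train, df_val = chronological_split(df, valid_p=0.2)
print(len(df_train), len(df_val))  # 8 2
```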


anandaheino commented on June 2, 2024

This helped me and my team a lot! Thank you @YannCabanes
