Comments (4)
I have succeeded in combining SoftDTWLossPyTorch from tslearn with NeuralProphet from neuralprophet by defining:
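from tslearn.metrics import SoftDTWLossPyTorch
from tslearn.metrics.soft_dtw_loss_pytorch import _SoftDTWLossPyTorch

def soft_dtw_loss_function(x, y, dist_func=SoftDTWLossPyTorch._euclidean_squared_dist, gamma=0.1):
    # Pairwise cost matrix between the time steps of x and y
    d_xy = dist_func(x, y)
    # Soft-DTW value from tslearn's autograd Function (one value per series in the batch)
    return _SoftDTWLossPyTorch.apply(d_xy, gamma)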
and then:
m = NeuralProphet(loss_func=soft_dtw_loss_function)
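To see what this wrapper computes, here is a NumPy-only sketch of the same quantity: a squared Euclidean cost matrix between time steps (the role played by `_euclidean_squared_dist`), followed by the soft-DTW dynamic program with smoothing parameter `gamma` (the role played by `_SoftDTWLossPyTorch.apply`). It is written from the standard soft-DTW recursion rather than copied from tslearn, it computes no gradients, and the function names are hypothetical. It assumes inputs shaped (batch, length, dim) and is only meant for sanity-checking values on small examples.

```python
import numpy as np


def squared_euclidean_cost(x, y):
    """Hypothetical helper: x (batch, m, d), y (batch, n, d) -> cost (batch, m, n)."""
    diff = x[:, :, None, :] - y[:, None, :, :]
    return (diff ** 2).sum(axis=-1)


def soft_min(values, gamma):
    """Smoothed minimum -gamma * log(sum(exp(-v / gamma))), computed stably."""
    values = np.asarray(values, dtype=float)
    v_min = values.min()
    return v_min - gamma * np.log(np.exp(-(values - v_min) / gamma).sum())


def soft_dtw_numpy(x, y, gamma=0.1):
    """Reference soft-DTW value for each pair of series in the batch (no gradients)."""
    cost = squared_euclidean_cost(x, y)
    batch, m, n = cost.shape
    out = np.empty(batch)
    for b in range(batch):
        # R[i, j]: smoothed cost of aligning the first i steps of x with the first j of y
        R = np.full((m + 1, n + 1), np.inf)
        R[0, 0] = 0.0
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                R[i, j] = cost[b, i - 1, j - 1] + soft_min(
                    [R[i - 1, j - 1], R[i - 1, j], R[i, j - 1]], gamma
                )
        out[b] = R[m, n]
    return out


x = np.array([0.0, 1.0, 2.0]).reshape(1, 3, 1)
y = np.array([0.0, 1.0, 1.0, 2.0]).reshape(1, 4, 1)
print(soft_dtw_numpy(x, y, gamma=1e-3))  # close to 0, the plain DTW cost of this pair
```

Because the soft minimum is never larger than the hard minimum, the soft-DTW value never exceeds the plain DTW cost, and it approaches that cost as `gamma` shrinks; this is also why soft-DTW values can be slightly negative for well-aligned series.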
from tslearn.
Hello @weidongzhou1994,
NeuralProphet (https://github.com/ourownstory/neural_prophet/blob/main/neuralprophet/forecaster.py) supports loss functions from PyTorch. Indeed, the class NeuralProphet has the optional parameter loss_func:
loss_func : str, torch.nn.functional.loss
    Type of loss to use:

    Options
        * (default) ``Huber``: Huber loss function
        * ``MSE``: Mean Squared Error loss function
        * ``MAE``: Mean Absolute Error loss function
        * ``torch.nn.functional.loss.``: loss or callable for custom loss, eg. L1-Loss

    Examples
    --------
    >>> from neuralprophet import NeuralProphet
    >>> import torch
    >>> import torch.nn as nn
    >>> m = NeuralProphet(loss_func=torch.nn.L1Loss)
However, when I run the code:
from neuralprophet import NeuralProphet
from tslearn.metrics import SoftDTWLossPyTorch
m = NeuralProphet(loss_func=SoftDTWLossPyTorch)
I obtain the following error message:
Traceback (most recent call last):
  File "/home/ycabanes/work/tslearn/codes/try_neuralprophet_with_softdtwlosspytorch.py", line 15, in <module>
    m = NeuralProphet(loss_func=SoftDTWLossPyTorch)
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/forecaster.py", line 398, in __init__
    self.config_train = configure.Train(
  File "<string>", line 18, in __init__
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/configure.py", line 112, in __post_init__
    self.set_loss_func()
  File "/home/ycabanes/.local/lib/python3.8/site-packages/neuralprophet/configure.py", line 134, in set_loss_func
    raise NotImplementedError(f"Loss function {self.loss_func} not found")
NotImplementedError: Loss function <class 'tslearn.metrics.soft_dtw_loss_pytorch.SoftDTWLossPyTorch'> not found
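The traceback suggests that NeuralProphet's `set_loss_func` only accepts certain kinds of objects: a plain Python function (as in the working `soft_dtw_loss_function` approach) gets through, while the `SoftDTWLossPyTorch` class, which appears to subclass `torch.nn.Module` rather than a PyTorch loss class, does not. The sketch below is a simplified, hypothetical stand-in for that kind of check, not NeuralProphet's actual code; `looks_like_custom_loss_function`, `SquaredErrorLoss` and `as_loss_function` are made-up names. It shows the general pattern of configuring a loss object once and hiding it behind a plain `def` function.

```python
import types


class SquaredErrorLoss:
    """Hypothetical loss object, standing in for a configurable class such as
    a torch.nn.Module: configured once, then called on (x, y)."""

    def __init__(self, scale=1.0):
        self.scale = scale

    def __call__(self, x, y):
        # Scaled squared error, one value per (prediction, target) pair
        return [self.scale * (a - b) ** 2 for a, b in zip(x, y)]


def as_loss_function(loss_cls, **kwargs):
    """Instantiate loss_cls once and expose the instance as a plain function."""
    instance = loss_cls(**kwargs)

    def loss_function(x, y):
        return instance(x, y)

    return loss_function


def looks_like_custom_loss_function(obj):
    """Hypothetical, simplified check: accept plain Python functions only."""
    return callable(obj) and isinstance(obj, types.FunctionType)


# Both the class and an instance are rejected by this check; the wrapper passes.
print(looks_like_custom_loss_function(SquaredErrorLoss))       # False
print(looks_like_custom_loss_function(SquaredErrorLoss(2.0)))  # False
loss_fn = as_loss_function(SquaredErrorLoss, scale=2.0)
print(looks_like_custom_loss_function(loss_fn))                # True
print(loss_fn([1.0, 2.0], [0.0, 0.0]))                         # [2.0, 8.0]
```

Assuming NeuralProphet accepts plain functions, as the working example at the top of this thread indicates, the same pattern could wrap tslearn's public class, e.g. `NeuralProphet(loss_func=as_loss_function(SoftDTWLossPyTorch, gamma=0.1))`, which would avoid importing the private `_SoftDTWLossPyTorch`. This variant has not been tested here.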
Here is the full code of a notebook running on Google Colab, inspired by a tutorial notebook from the NeuralProphet repository (https://github.com/ourownstory/neural_prophet/blob/main/docs/source/tutorials/tutorial10.ipynb):
Install the modules
try:
    import neuralprophet
except ImportError:
    !pip install neuralprophet[live]
try:
    import tslearn
except ImportError:
    !pip install tslearn
Import the modules
import pandas as pd
import torch
from neuralprophet import NeuralProphet, set_log_level
from tslearn.metrics import SoftDTWLossPyTorch
from tslearn.metrics.soft_dtw_loss_pytorch import _SoftDTWLossPyTorch
Define a SoftDTW loss function using tslearn
def soft_dtw_loss_function(x, y, dist_func=SoftDTWLossPyTorch._euclidean_squared_dist, gamma=0.1):
    d_xy = dist_func(x, y)
    return _SoftDTWLossPyTorch.apply(d_xy, gamma)
# Load the dataset from the CSV file using pandas
df = pd.read_csv("https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv")
# Disable logging messages unless there is an error
set_log_level("ERROR")
# Model and prediction
m = NeuralProphet(loss_func=soft_dtw_loss_function)
m.set_plotting_backend("plotly")
df_train, df_val = m.split_df(df, valid_p=0.2)
print("Dataset size:", len(df))
print("Train dataset size:", len(df_train))
print("Validation dataset size:", len(df_val))
metrics = m.fit(df_train, validation_df=df_val, progress=None)
metrics
forecast = m.predict(df)
m.plot(forecast)
This helped me and my team a lot! Thank you @YannCabanes