
Comments (4)

lostella commented on May 12, 2024

#125 solves this issue: now features must be explicitly enabled when constructing the estimator in order for the model to use them.
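
For illustration, a minimal sketch of what opting in looks like after that change, assuming the flag is named use_feat_dynamic_real (check the estimator's signature in your version):

from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer

# Explicitly opt in to the dynamic real-valued features; without this
# flag the estimator ignores the "feat_dynamic_real" field entirely.
estimator = DeepAREstimator(
    freq="5min",
    prediction_length=12,
    use_feat_dynamic_real=True,
    trainer=Trainer(epochs=10),
)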


lostella commented on May 12, 2024

@houchangtao One issue with your example is that you're feeding a plain list directly as the training dataset.

Using ListDataset would be appropriate here, but that results in an error as well:

import pandas as pd

from gluonts.dataset.common import ListDataset
from gluonts.distribution import NegativeBinomialOutput
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer

url = "https://raw.githubusercontent.com/numenta/NAB/master/data/realTweets/Twitter_volume_AMZN.csv"
df = pd.read_csv(url, header=0, index_col=0)

# Two training entries, each carrying the day-of-week as a dynamic real feature.
training_data = ListDataset(
    data_iter=[
        {"start": pd.Timestamp(df.index[0], freq='5min'), "target": df.value[:"2015-04-05 00:00:00"],
         "feat_dynamic_real": pd.to_datetime(df[:"2015-04-05 00:00:00"].index).dayofweek.values},
        {"start": pd.Timestamp(df.index[0], freq='5min'), "target": df.value[:"2015-04-10 00:00:00"],
         "feat_dynamic_real": pd.to_datetime(df[:"2015-04-10 00:00:00"].index).dayofweek.values}
    ],
    freq="5min"
)

estimator = DeepAREstimator(freq="5min", prediction_length=12, distr_output=NegativeBinomialOutput(),
                            trainer=Trainer(epochs=10))
predictor = estimator.train(training_data=training_data)

Results in:

Traceback (most recent call last):
  File "/Users/stellalo/gluon-ts/temp/run_issue_94.py", line 19, in <module>
    predictor = estimator.train(training_data=training_data)
  File "/Users/stellalo/gluon-ts/src/gluonts/model/estimator.py", line 189, in train
    training_transformation, trained_net = self.train_model(training_data)
  File "/Users/stellalo/gluon-ts/src/gluonts/model/estimator.py", line 182, in train_model
    train_iter=training_data_loader,
  File "/Users/stellalo/gluon-ts/src/gluonts/trainer/_base.py", line 251, in __call__
    for batch_no, data_entry in enumerate(it, start=1):
  File "/Users/stellalo/.virtualenvs/gluonts/lib/python3.6/site-packages/tqdm/_tqdm.py", line 930, in __iter__
    for obj in iterable:
  File "/Users/stellalo/gluon-ts/src/gluonts/dataset/loader.py", line 195, in __iter__
    self.batch_size - 1
  File "/Users/stellalo/gluon-ts/src/gluonts/dataset/loader.py", line 162, in _emit_batches_while_buffer_larger_than
    yield self._buffer.next_batch()
  File "/Users/stellalo/gluon-ts/src/gluonts/dataset/loader.py", line 54, in next_batch
    batch = {k: self.stack(v[:n]) for k, v in self._buffers.items()}
  File "/Users/stellalo/gluon-ts/src/gluonts/dataset/loader.py", line 54, in <dictcomp>
    batch = {k: self.stack(v[:n]) for k, v in self._buffers.items()}
  File "/Users/stellalo/gluon-ts/src/gluonts/dataset/loader.py", line 62, in stack
    data = np.asarray(xs)
  File "/Users/stellalo/.virtualenvs/gluonts/lib/python3.6/site-packages/numpy/core/numeric.py", line 492, in asarray
    return array(a, dtype, copy=False, order=order)
ValueError: could not broadcast input array from shape (10684) into shape (1)

Process finished with exit code 1


lostella commented on May 12, 2024

MWE:

from gluonts.dataset.common import ListDataset
from gluonts.model.deepar import DeepAREstimator
from gluonts.trainer import Trainer

# Two entries whose targets (and hence feat_dynamic_real arrays) have
# different lengths: 4 and 5 points respectively.
training_data = ListDataset(
    data_iter=[
        {"start": "2019-01-01 00:00:00", "target": [1.0, 2.0, 3.0, 4.0],
         "feat_dynamic_real": [1.0, 2.0, 3.0, 4.0]},
        {"start": "2019-01-01 00:00:00", "target": [1.0, 2.0, 3.0, 4.0, 5.0],
         "feat_dynamic_real": [1.0, 2.0, 3.0, 4.0, 5.0]},
    ],
    freq="5min"
)
estimator = DeepAREstimator(freq="5min", prediction_length=2,
                            trainer=Trainer(epochs=10))
predictor = estimator.train(training_data=training_data)


lostella commented on May 12, 2024

The problem is that the "feat_dynamic_real" field is processed as part of the ListDataset, but not stacked together with the other time-dependent fields in the transformation chain. Therefore, when a batch is formed, this field is stacked as-is across the different entries in the dataset, which fails whenever the entries have different lengths.
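
A toy illustration of the stacking failure, independent of gluonts, using the two lengths from the MWE above:

import numpy as np

# Two "feat_dynamic_real" arrays of different lengths, as in the MWE.
a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Depending on the NumPy version this raises an error (e.g. a broadcast
# error like the one in the traceback above) or silently builds a ragged
# object array; either way, no valid rectangular batch comes out.
batch = np.asarray([a, b])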

As far as I can see, solving this issue means either:

  1. (easier) at the beginning of the transformation chain, explicitly filtering out any fields that will not be consumed (see the sketch below), or
  2. (harder) completing the transformation chain of DeepAREstimator (and potentially other models) so that it consumes all possible fields, defaulting the missing ones (so that no KeyErrors are raised).
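
A minimal sketch of option 1, assuming gluonts.transform.RemoveFields is used to drop the unconsumed field before batching (where exactly to hook this into the estimator's transformation chain is up for discussion):

from gluonts.transform import RemoveFields

# Drop any field the model was not configured to consume, so it never
# reaches the batching stage where ragged arrays fail to stack.
drop_unused = RemoveFields(field_names=["feat_dynamic_real"])

# Transformations are applied lazily over a dataset:
cleaned = list(drop_unused(training_data, is_train=True))
assert "feat_dynamic_real" not in cleaned[0]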

