TemporalFusionTransformerEstimator (torch) fails when using PAST_FEAT_DYNAMIC_CAT #3215

Open
jmberutich opened this issue Sep 5, 2024 · 2 comments
Labels
bug Something isn't working

Description

When using a PAST_FEAT_DYNAMIC_CAT field with the torch implementation of TemporalFusionTransformerEstimator, training fails with a TypeError. Commenting out the past_dynamic_cardinalities argument makes the example below work (see the workaround sketch after the snippet).

To Reproduce

import numpy as np
import pandas as pd

from gluonts.dataset.common import ListDataset
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

# Example data
data = [
    {
        "start": pd.Timestamp("2021-01-01 00:00:00"),
        "target": np.random.rand(100),
        "feat_static_real": [1.0],
        "feat_static_cat": [0],
        "feat_dynamic_real": [np.random.rand(100)],
        "feat_dynamic_cat": [np.random.randint(0, 10, size=100)],
        "past_feat_dynamic_real": [np.random.rand(100)],
        "past_feat_dynamic_cat": [np.random.randint(0, 10, size=100)],
    }
]

dataset = ListDataset(data, freq="1H")

# Define the estimator
estimator = TemporalFusionTransformerEstimator(
    freq="1H",
    prediction_length=24,
    context_length=24,
    quantiles=[0.1, 0.5, 0.9],
    num_heads=4,
    hidden_dim=32,
    variable_dim=32,
    static_dims=[1],  # Size of feat_static_real
    dynamic_dims=[1],  # Size of feat_dynamic_real
    past_dynamic_dims=[1],  # Size of past_feat_dynamic_real
    static_cardinalities=[1],  # Cardinality of feat_static_cat
    dynamic_cardinalities=[10],  # Cardinality of feat_dynamic_cat
    past_dynamic_cardinalities=[10],  # Cardinality of past_feat_dynamic_cat
    time_features=None,
    lr=0.001,
    weight_decay=1e-8,
    dropout_rate=0.1,
    patience=10,
    batch_size=32,
    num_batches_per_epoch=5,
)

predictor = estimator.train(dataset)
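
As noted in the description, the same setup trains successfully once the past categorical feature is not registered with the estimator. A minimal sketch of that working configuration (the variable name is illustrative; per the report, only the commented-out argument differs):

# Same estimator as above, with past_dynamic_cardinalities commented out;
# per the report, this variant trains without error. The
# past_feat_dynamic_cat field in the dataset is then simply ignored.
estimator_no_past_cat = TemporalFusionTransformerEstimator(
    freq="1H",
    prediction_length=24,
    context_length=24,
    quantiles=[0.1, 0.5, 0.9],
    num_heads=4,
    hidden_dim=32,
    variable_dim=32,
    static_dims=[1],
    dynamic_dims=[1],
    past_dynamic_dims=[1],
    static_cardinalities=[1],
    dynamic_cardinalities=[10],
    # past_dynamic_cardinalities=[10],  # enabling this triggers the TypeError
    time_features=None,
    batch_size=32,
    num_batches_per_epoch=5,
)
predictor = estimator_no_past_cat.train(dataset)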

Error message or code output

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/home/coder/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/configuration_validator.py:74: You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.

  | Name  | Type                           | Params | In sizes                                                                           | Out sizes                     
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
0 | model | TemporalFusionTransformerModel | 156 K  | [[1, 24], [1, 24], [1, 1], [1, 1], [1, 48, 5], [1, 48, 1], [1, 24, 1], [1, 24, 1]] | [[[1, 24, 3]], [1, 1], [1, 1]]
-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
156 K     Trainable params
0         Non-trainable params
156 K     Total params
0.624     Total estimated model params size (MB)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[17], line 1
----> 1 predictor = estimator.train(dataset)

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/gluonts/torch/model/estimator.py:246, in PyTorchLightningEstimator.train(self, training_data, validation_data, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
    237 def train(
    238     self,
    239     training_data: Dataset,
   (...)
    244     **kwargs,
    245 ) -> PyTorchPredictor:
--> 246     return self.train_model(
    247         training_data,
    248         validation_data,
    249         shuffle_buffer_length=shuffle_buffer_length,
    250         cache_data=cache_data,
    251         ckpt_path=ckpt_path,
    252     ).predictor

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/gluonts/torch/model/estimator.py:209, in PyTorchLightningEstimator.train_model(self, training_data, validation_data, from_predictor, shuffle_buffer_length, cache_data, ckpt_path, **kwargs)
    200 custom_callbacks = self.trainer_kwargs.pop("callbacks", [])
    201 trainer = pl.Trainer(
    202     **{
    203         "accelerator": "auto",
   (...)
    206     }
    207 )
--> 209 trainer.fit(
    210     model=training_network,
    211     train_dataloaders=training_data_loader,
    212     val_dataloaders=validation_data_loader,
    213     ckpt_path=ckpt_path,
    214 )
    216 if checkpoint.best_model_path != "":
    217     logger.info(
    218         f"Loading best model from {checkpoint.best_model_path}"
    219     )

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:544, in Trainer.fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    542 self.state.status = TrainerStatus.RUNNING
    543 self.training = True
--> 544 call._call_and_handle_interrupt(
    545     self, self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    546 )

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:44, in _call_and_handle_interrupt(trainer, trainer_fn, *args, **kwargs)
     42     if trainer.strategy.launcher is not None:
     43         return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
---> 44     return trainer_fn(*args, **kwargs)
     46 except _TunerExitException:
     47     _call_teardown_hook(trainer)

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:580, in Trainer._fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    573 assert self.state.fn is not None
    574 ckpt_path = self._checkpoint_connector._select_ckpt_path(
    575     self.state.fn,
    576     ckpt_path,
    577     model_provided=True,
    578     model_connected=self.lightning_module is not None,
    579 )
--> 580 self._run(model, ckpt_path=ckpt_path)
    582 assert self.state.stopped
    583 self.training = False

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:989, in Trainer._run(self, model, ckpt_path)
    984 self._signal_connector.register_signal_handlers()
    986 # ----------------------------
    987 # RUN THE TRAINER
    988 # ----------------------------
--> 989 results = self._run_stage()
    991 # ----------------------------
    992 # POST-Training CLEAN UP
    993 # ----------------------------
    994 log.debug(f"{self.__class__.__name__}: trainer tearing down")

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py:1035, in Trainer._run_stage(self)
   1033         self._run_sanity_check()
   1034     with torch.autograd.set_detect_anomaly(self._detect_anomaly):
-> 1035         self.fit_loop.run()
   1036     return None
   1037 raise RuntimeError(f"Unexpected state {self.state}")

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:194, in _FitLoop.run(self)
    193 def run(self) -> None:
--> 194     self.setup_data()
    195     if self.skip:
    196         return

File ~/projects/efds-fcpf-forecasting-engine/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:258, in _FitLoop.setup_data(self)
    256 self._data_fetcher = _select_data_fetcher(trainer, RunningStage.TRAINING)
    257 self._data_fetcher.setup(combined_loader)
--> 258 iter(self._data_fetcher)  # creates the iterator inside the fetcher
    259 max_batches = sized_len(combined_loader)
    260 self.max_batches = max_batches if max_batches is not None else float("inf")
...
    547             fill_value=self.dummy_value,
    548             dtype=d[field].dtype,
    549         )

TypeError: list indices must be integers or slices, not tuple
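
The trace ends inside GluonTS's padding/instance-splitting code, where d[field] is expected to be an np.ndarray (note the d[field].dtype access just above the error), yet tuple-indexing it raises TypeError, which suggests the field is still a plain Python list at that point. A plausible reading, not verified here, is that the dataset processing converts the other feature fields to arrays but skips past_feat_dynamic_cat. Under that assumption, an untested workaround sketch is to hand the field to ListDataset as a 2D ndarray up front:

import numpy as np

# Untested workaround sketch (assumes past_feat_dynamic_cat passes through
# ListDataset unchanged): store it as a 2D ndarray of shape
# (num_features, num_timesteps) so that downstream slicing like
# arr[:, a:b] operates on an array instead of a list.
for entry in data:
    entry["past_feat_dynamic_cat"] = np.asarray(
        entry["past_feat_dynamic_cat"], dtype=np.int64
    )

dataset = ListDataset(data, freq="1H")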

Environment

  • Operating system: Debian GNU/Linux 11 (bullseye)
  • Python version: 3.11
  • GluonTS version: 0.15.1
  • Torch version: 2.4.0
  • PyTorch Lightning version: 2.1.4

@jmberutich jmberutich added the bug Something isn't working label Sep 5, 2024