Skip to content

Commit 756f61a

Browse files
authored
[ENH] tests for TiDE Model. (#1843)
### Description This PR fixes #1807 and stacks upon PR #1780. Builds upon the PR #1814 (closed due to complex commit history).
1 parent 9eae959 commit 756f61a

File tree

3 files changed

+244
-0
lines changed

3 files changed

+244
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
"""Tide model."""
22

33
from pytorch_forecasting.models.tide._tide import TiDEModel
4+
from pytorch_forecasting.models.tide._tide_metadata import TiDEModelMetadata
45
from pytorch_forecasting.models.tide.sub_modules import _TideModule
56

67
__all__ = [
78
"_TideModule",
89
"TiDEModel",
10+
"TiDEModelMetadata",
911
]
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""TiDE metadata container."""
2+
3+
from pytorch_forecasting.models.base._base_object import _BasePtForecaster
4+
5+
6+
class TiDEModelMetadata(_BasePtForecaster):
7+
"""Metadata container for TiDE Model."""
8+
9+
_tags = {
10+
"info:name": "TiDEModel",
11+
"info:compute": 3,
12+
"authors": ["Sohaib-Ahmed21"],
13+
"capability:exogenous": True,
14+
"capability:multivariate": True,
15+
"capability:pred_int": True,
16+
"capability:flexible_history_length": True,
17+
"capability:cold_start": False,
18+
}
19+
20+
@classmethod
21+
def get_model_cls(cls):
22+
"""Get model class."""
23+
from pytorch_forecasting.models.tide import TiDEModel
24+
25+
return TiDEModel
26+
27+
@classmethod
28+
def get_test_train_params(cls):
29+
"""Return testing parameter settings for the trainer.
30+
31+
Returns
32+
-------
33+
params : dict or list of dict, default = {}
34+
Parameters to create testing instances of the class.
35+
"""
36+
37+
from pytorch_forecasting.data.encoders import GroupNormalizer
38+
from pytorch_forecasting.metrics import SMAPE
39+
40+
return [
41+
{
42+
"data_loader_kwargs": dict(
43+
add_relative_time_idx=False,
44+
# must include this everytime since the data_loader_default_kwargs
45+
# include this to be True.
46+
)
47+
},
48+
{
49+
"temporal_decoder_hidden": 16,
50+
"data_loader_kwargs": dict(add_relative_time_idx=False),
51+
},
52+
{
53+
"dropout": 0.2,
54+
"use_layer_norm": True,
55+
"loss": SMAPE(),
56+
"data_loader_kwargs": dict(
57+
target_normalizer=GroupNormalizer(
58+
groups=["agency", "sku"], transformation="softplus"
59+
),
60+
add_relative_time_idx=False,
61+
),
62+
},
63+
]

tests/test_models/test_tide.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import pickle
2+
import shutil
3+
4+
import lightning.pytorch as pl
5+
from lightning.pytorch.callbacks import EarlyStopping
6+
from lightning.pytorch.loggers import TensorBoardLogger
7+
import numpy as np
8+
import pandas as pd
9+
import pytest
10+
11+
from pytorch_forecasting.data.timeseries import TimeSeriesDataSet
12+
from pytorch_forecasting.metrics import MAE, SMAPE, QuantileLoss
13+
from pytorch_forecasting.models import TiDEModel
14+
from pytorch_forecasting.tests.test_all_estimators import _integration
15+
from pytorch_forecasting.utils._dependencies import _get_installed_packages
16+
17+
18+
def _tide_integration(dataloaders, tmp_path, trainer_kwargs=None, **kwargs):
19+
"""TiDE specific wrapper around the common integration test function.
20+
21+
Args:
22+
dataloaders: Dictionary of dataloaders for train, val, and test.
23+
tmp_path: Temporary path for saving the model.
24+
trainer_kwargs: Additional arguments for the Trainer.
25+
**kwargs: Additional arguments for the TiDEModel.
26+
27+
Returns:
28+
Predictions from the trained model.
29+
"""
30+
from pytorch_forecasting.tests._data_scenarios import data_with_covariates
31+
32+
df = data_with_covariates()
33+
34+
tide_kwargs = {
35+
"temporal_decoder_hidden": 8,
36+
"temporal_width_future": 4,
37+
"dropout": 0.1,
38+
}
39+
40+
tide_kwargs.update(kwargs)
41+
train_dataset = dataloaders["train"].dataset
42+
43+
data_loader_kwargs = {
44+
"target": train_dataset.target,
45+
"group_ids": train_dataset.group_ids,
46+
"time_varying_known_reals": train_dataset.time_varying_known_reals,
47+
"time_varying_unknown_reals": train_dataset.time_varying_unknown_reals,
48+
"static_categoricals": train_dataset.static_categoricals,
49+
"static_reals": train_dataset.static_reals,
50+
"add_relative_time_idx": train_dataset.add_relative_time_idx,
51+
}
52+
return _integration(
53+
TiDEModel,
54+
df,
55+
tmp_path,
56+
data_loader_kwargs=data_loader_kwargs,
57+
trainer_kwargs=trainer_kwargs,
58+
**tide_kwargs,
59+
)
60+
61+
62+
@pytest.mark.parametrize(
63+
"kwargs",
64+
[
65+
{},
66+
{"loss": SMAPE()},
67+
{"temporal_decoder_hidden": 16},
68+
{"dropout": 0.2, "use_layer_norm": True},
69+
],
70+
)
71+
def test_integration(dataloaders_with_covariates, tmp_path, kwargs):
72+
_tide_integration(dataloaders_with_covariates, tmp_path, **kwargs)
73+
74+
75+
@pytest.mark.parametrize(
76+
"kwargs",
77+
[
78+
{},
79+
],
80+
)
81+
def test_multi_target_integration(dataloaders_multi_target, tmp_path, kwargs):
82+
_tide_integration(dataloaders_multi_target, tmp_path, **kwargs)
83+
84+
85+
@pytest.fixture
86+
def model(dataloaders_with_covariates):
87+
dataset = dataloaders_with_covariates["train"].dataset
88+
net = TiDEModel.from_dataset(
89+
dataset,
90+
hidden_size=16,
91+
dropout=0.1,
92+
temporal_width_future=4,
93+
)
94+
return net
95+
96+
97+
def test_pickle(model):
98+
pkl = pickle.dumps(model)
99+
pickle.loads(pkl) # noqa: S301
100+
101+
102+
@pytest.mark.skipif(
103+
"matplotlib" not in _get_installed_packages(),
104+
reason="skip test if required package matplotlib not installed",
105+
)
106+
def test_prediction_visualization(model, dataloaders_with_covariates):
107+
raw_predictions = model.predict(
108+
dataloaders_with_covariates["val"],
109+
mode="raw",
110+
return_x=True,
111+
fast_dev_run=True,
112+
)
113+
model.plot_prediction(raw_predictions.x, raw_predictions.output, idx=0)
114+
115+
116+
def test_prediction_with_kwargs(model, dataloaders_with_covariates):
117+
# Tests prediction works with different keyword arguments
118+
model.predict(
119+
dataloaders_with_covariates["val"], return_index=True, fast_dev_run=True
120+
)
121+
model.predict(
122+
dataloaders_with_covariates["val"],
123+
return_x=True,
124+
return_y=True,
125+
fast_dev_run=True,
126+
)
127+
128+
129+
def test_no_exogenous_variable():
130+
data = pd.DataFrame(
131+
{
132+
"target": np.ones(1600),
133+
"group_id": np.repeat(np.arange(16), 100),
134+
"time_idx": np.tile(np.arange(100), 16),
135+
}
136+
)
137+
training_dataset = TimeSeriesDataSet(
138+
data=data,
139+
time_idx="time_idx",
140+
target="target",
141+
group_ids=["group_id"],
142+
max_encoder_length=10,
143+
max_prediction_length=5,
144+
time_varying_unknown_reals=["target"],
145+
time_varying_known_reals=[],
146+
)
147+
validation_dataset = TimeSeriesDataSet.from_dataset(
148+
training_dataset, data, stop_randomization=True, predict=True
149+
)
150+
training_data_loader = training_dataset.to_dataloader(
151+
train=True, batch_size=8, num_workers=0
152+
)
153+
validation_data_loader = validation_dataset.to_dataloader(
154+
train=False, batch_size=8, num_workers=0
155+
)
156+
forecaster = TiDEModel.from_dataset(
157+
training_dataset,
158+
)
159+
from lightning.pytorch import Trainer
160+
161+
trainer = Trainer(
162+
max_epochs=2,
163+
limit_train_batches=8,
164+
limit_val_batches=8,
165+
)
166+
trainer.fit(
167+
forecaster,
168+
train_dataloaders=training_data_loader,
169+
val_dataloaders=validation_data_loader,
170+
)
171+
best_model_path = trainer.checkpoint_callback.best_model_path
172+
best_model = TiDEModel.load_from_checkpoint(best_model_path)
173+
best_model.predict(
174+
validation_data_loader,
175+
fast_dev_run=True,
176+
return_x=True,
177+
return_y=True,
178+
return_index=True,
179+
)

0 commit comments

Comments
 (0)