Skip to content

Commit

Permalink
Fix incorrect input routing for models (#3186)
Browse files Browse the repository at this point in the history
Fixes #3185

*Description of changes:*

There is currently a bug where the model inputs may be routed incorrect
by the forecast generator. This effectively results in
`past_feat_dynamic_real` and `past_feat_dynamic_cat` being ignored by
the TFT model.

MWE:
```python
from unittest import mock
import numpy as np
import pandas as pd
from gluonts.torch.model.tft import TemporalFusionTransformerEstimator

freq = "D"
N = 50
data = [
    {"target": np.arange(N), "past_feat_dynamic_real": np.random.rand(1, N).astype("float32"), "start": pd.Period("2020-01-01", freq=freq)}
]

predictor = TemporalFusionTransformerEstimator(prediction_length=1, freq=freq, past_dynamic_dims=[1], trainer_kwargs={"max_epochs": 1}).train(data)

with mock.patch("gluonts.torch.model.tft.module.TemporalFusionTransformerModel._preprocess") as mock_fwd:
    try:
        fcst = list(predictor.predict(data))
    except:
        pass   
    call_kwargs = mock_fwd.call_args[1]

call_kwargs["feat_dynamic_cat"]  
# tensor([[[0.8073]]])
call_kwargs["past_feat_dynamic_real"]  
# None
```

The bug occurs because model inputs are passed as positional arguments
instead of keyword arguments.


By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.


**Please tag this pr with at least one of these labels to make our
release process faster:** BREAKING, new feature, bug fix, other change,
dev setup
  • Loading branch information
shchur authored May 31, 2024
1 parent c4ff443 commit 5e30960
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
raise NotImplementedError


def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
class_name = prediction_net.__class__.__module__
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)


class ForecastGenerator:
"""
Classes used to bring the output of a network into a class.
Expand Down Expand Up @@ -115,7 +124,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
(outputs,), loc, scale = prediction_net(*inputs.values())
(outputs,), loc, scale = make_predictions(prediction_net, inputs)
outputs = to_numpy(outputs)
if scale is not None:
outputs = outputs * to_numpy(scale[..., None])
Expand Down Expand Up @@ -159,14 +168,16 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = to_numpy(prediction_net(*inputs.values()))
outputs = to_numpy(make_predictions(prediction_net, inputs))
if output_transform is not None:
outputs = output_transform(batch, outputs)
if num_samples:
num_collected_samples = outputs[0].shape[0]
collected_samples = [outputs]
while num_collected_samples < num_samples:
outputs = to_numpy(prediction_net(*inputs.values()))
outputs = to_numpy(
make_predictions(prediction_net, inputs)
)
if output_transform is not None:
outputs = output_transform(batch, outputs)
collected_samples.append(outputs)
Expand Down Expand Up @@ -209,7 +220,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = prediction_net(*inputs.values())
outputs = make_predictions(prediction_net, inputs)

if output_transform:
log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)
Expand Down

0 comments on commit 5e30960

Please sign in to comment.