Skip to content

Commit

Permalink
Fix input for MXNet models
Browse files Browse the repository at this point in the history
  • Loading branch information
shchur committed May 31, 2024
1 parent 25812d9 commit 4022835
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,14 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:
raise NotImplementedError


def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
if prediction_net.__class__.__module__.startswith("gluonts.mx"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)


class ForecastGenerator:
"""
Classes used to bring the output of a network into a class.
Expand Down Expand Up @@ -115,7 +123,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
(outputs,), loc, scale = prediction_net(**inputs)
(outputs,), loc, scale = make_predictions(prediction_net, inputs)
outputs = to_numpy(outputs)
if scale is not None:
outputs = outputs * to_numpy(scale[..., None])
Expand Down Expand Up @@ -159,14 +167,16 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = to_numpy(prediction_net(**inputs))
outputs = to_numpy(make_predictions(prediction_net, inputs))
if output_transform is not None:
outputs = output_transform(batch, outputs)
if num_samples:
num_collected_samples = outputs[0].shape[0]
collected_samples = [outputs]
while num_collected_samples < num_samples:
outputs = to_numpy(prediction_net(**inputs))
outputs = to_numpy(
make_predictions(prediction_net, inputs)
)
if output_transform is not None:
outputs = output_transform(batch, outputs)
collected_samples.append(outputs)
Expand Down Expand Up @@ -209,7 +219,7 @@ def __call__(
) -> Iterator[Forecast]:
for batch in inference_data_loader:
inputs = select(input_names, batch, ignore_missing=True)
outputs = prediction_net(**inputs)
outputs = make_predictions(prediction_net, inputs)

if output_transform:
log_once(OUTPUT_TRANSFORM_NOT_SUPPORTED_MSG)
Expand Down

0 comments on commit 4022835

Please sign in to comment.