Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix MXNet tests
Browse files Browse the repository at this point in the history
shchur committed May 31, 2024
1 parent 4022835 commit e8afdde
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/gluonts/model/forecast_generator.py
Original file line number Diff line number Diff line change
@@ -84,7 +84,8 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast:

def make_predictions(prediction_net, inputs: dict):
# MXNet predictors only support positional arguments
if prediction_net.__class__.__module__.startswith("gluonts.mx"):
class_name = prediction_net.__class__.__module__
if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"):
return prediction_net(*inputs.values())
else:
return prediction_net(**inputs)

0 comments on commit e8afdde

Please sign in to comment.