diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index ca4bbd1951..33b0320808 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -84,7 +84,8 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): # MXNet predictors only support positional arguments - if prediction_net.__class__.__module__.startswith("gluonts.mx"): + class_name = prediction_net.__class__.__module__ + if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): return prediction_net(*inputs.values()) else: return prediction_net(**inputs)