From e8afdde9d99ff6735578d156b47932e9addea5d7 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 31 May 2024 12:28:13 +0000 Subject: [PATCH] Fix MXNet tests --- src/gluonts/model/forecast_generator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/gluonts/model/forecast_generator.py b/src/gluonts/model/forecast_generator.py index ca4bbd1951..33b0320808 100644 --- a/src/gluonts/model/forecast_generator.py +++ b/src/gluonts/model/forecast_generator.py @@ -84,7 +84,8 @@ def make_distribution_forecast(distr, *args, **kwargs) -> Forecast: def make_predictions(prediction_net, inputs: dict): # MXNet predictors only support positional arguments - if prediction_net.__class__.__module__.startswith("gluonts.mx"): + class_name = prediction_net.__class__.__module__ + if class_name.startswith("gluonts.mx") or class_name.startswith("mxnet"): return prediction_net(*inputs.values()) else: return prediction_net(**inputs)