From 3ead095da57beae0b76b71fcf437e0cade452940 Mon Sep 17 00:00:00 2001 From: moghadas76 Date: Fri, 22 Mar 2024 10:57:31 +0000 Subject: [PATCH] Auto Conversion shape logic applied --- src/gluonts/transform/convert.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gluonts/transform/convert.py b/src/gluonts/transform/convert.py index 156ddcf638..3267cdae01 100644 --- a/src/gluonts/transform/convert.py +++ b/src/gluonts/transform/convert.py @@ -135,7 +135,9 @@ def __init__( def transform(self, data: DataEntry) -> DataEntry: value = np.asarray(data[self.field], dtype=self.dtype) - + if self.expected_ndim == 1: + if value.shape[0] == 1 or value.shape[-1] == 1: + value = value.ravel() assert_data_error( value.ndim == self.expected_ndim, 'Input for field "{self.field}" does not have the required'