From 0f8213615f16f9592a59ae024e800e2c40292af0 Mon Sep 17 00:00:00 2001 From: Anirudh Dagar Date: Thu, 31 Oct 2024 18:26:08 +0100 Subject: [PATCH] BUG: Drop train_id_column for train_x --- .../transformer/feature_transformers/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/autogluon_assistant/transformer/feature_transformers/base.py b/src/autogluon_assistant/transformer/feature_transformers/base.py index ba31957..f0c594e 100644 --- a/src/autogluon_assistant/transformer/feature_transformers/base.py +++ b/src/autogluon_assistant/transformer/feature_transformers/base.py @@ -21,7 +21,7 @@ def _fit_dataframes(self, train_X: pd.DataFrame, train_y: pd.Series, **kwargs) - def fit(self, task: TabularPredictionTask) -> "BaseFeatureTransformer": try: train_x = task.train_data.drop( - columns=task.columns_in_train_but_not_test + [task.test_id_column], + columns=task.columns_in_train_but_not_test + [task.train_id_column], errors="ignore", ) train_y = task.train_data[task.label_column] @@ -47,7 +47,7 @@ def _transform_dataframes(self, train_X: pd.DataFrame, test_X: pd.DataFrame) -> def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: try: train_x = task.train_data.drop( - columns=task.columns_in_train_but_not_test + [task.test_id_column], + columns=task.columns_in_train_but_not_test + [task.train_id_column], errors="ignore", ) train_y = task.train_data[task.label_column] @@ -61,9 +61,9 @@ def transform(self, task: TabularPredictionTask) -> TabularPredictionTask: # add back id and label columns transformed_train_data = pd.concat([train_x, train_y.rename(task.label_column)], axis=1) - if task.test_id_column in task.train_data.columns: + if task.train_id_column in task.train_data.columns: transformed_train_data = pd.concat( - [transformed_train_data, task.train_data[task.test_id_column]], axis=1 + [transformed_train_data, task.train_data[task.train_id_column]], axis=1 ) if task.test_id_column in task.test_data.columns: transformed_test_data = pd.concat([test_x, task.test_data[task.test_id_column]], axis=1)