Skip to content

Commit

Permalink
BUG: Drop train_id_column for train_x
Browse files Browse the repository at this point in the history
  • Loading branch information
AnirudhDagar committed Oct 31, 2024
1 parent 6ecef8c commit 0f82136
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def _fit_dataframes(self, train_X: pd.DataFrame, train_y: pd.Series, **kwargs) -
def fit(self, task: TabularPredictionTask) -> "BaseFeatureTransformer":
try:
train_x = task.train_data.drop(
columns=task.columns_in_train_but_not_test + [task.test_id_column],
columns=task.columns_in_train_but_not_test + [task.train_id_column],
errors="ignore",
)
train_y = task.train_data[task.label_column]
Expand All @@ -47,7 +47,7 @@ def _transform_dataframes(self, train_X: pd.DataFrame, test_X: pd.DataFrame) ->
def transform(self, task: TabularPredictionTask) -> TabularPredictionTask:
try:
train_x = task.train_data.drop(
columns=task.columns_in_train_but_not_test + [task.test_id_column],
columns=task.columns_in_train_but_not_test + [task.train_id_column],
errors="ignore",
)
train_y = task.train_data[task.label_column]
Expand All @@ -61,9 +61,9 @@ def transform(self, task: TabularPredictionTask) -> TabularPredictionTask:

# add back id and label columns
transformed_train_data = pd.concat([train_x, train_y.rename(task.label_column)], axis=1)
if task.test_id_column in task.train_data.columns:
if task.train_id_column in task.train_data.columns:
transformed_train_data = pd.concat(
[transformed_train_data, task.train_data[task.test_id_column]], axis=1
[transformed_train_data, task.train_data[task.train_id_column]], axis=1
)
if task.test_id_column in task.test_data.columns:
transformed_test_data = pd.concat([test_x, task.test_data[task.test_id_column]], axis=1)
Expand Down

0 comments on commit 0f82136

Please sign in to comment.