From 33b9b3e9be57b58bf40d3911c22a12c49cc14693 Mon Sep 17 00:00:00 2001 From: Nishant Guvvada Date: Fri, 1 Mar 2024 10:25:08 +0530 Subject: [PATCH 1/2] Added users parameter in cross-validate --- rectools/model_selection/cross_validate.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rectools/model_selection/cross_validate.py b/rectools/model_selection/cross_validate.py index e1e39c15..87f9b378 100644 --- a/rectools/model_selection/cross_validate.py +++ b/rectools/model_selection/cross_validate.py @@ -48,6 +48,7 @@ def cross_validate( # pylint: disable=too-many-locals models: tp.Dict[str, ModelBase], k: int, filter_viewed: bool, + users: tp.Optional[ExternalIds] = None, items_to_recommend: tp.Optional[ExternalIds] = None, ) -> tp.Dict[str, tp.Any]: """ @@ -115,6 +116,8 @@ def cross_validate( # pylint: disable=too-many-locals fold_dataset = _gen_2x_internal_ids_dataset(interactions_df_train, dataset.user_features, dataset.item_features) interactions_df_test = interactions.df.iloc[test_ids] # 1x internal + if users is not None: + interactions_df_test = interactions_df_test[interactions_df_test[Columns.User].isin(users)] test_users = interactions_df_test[Columns.User].unique() # 1x internal catalog = interactions_df_train[Columns.Item].unique() # 1x internal From 425bedc17d52f5a59b453c9a543fc4c3376003e8 Mon Sep 17 00:00:00 2001 From: Nishant Guvvada Date: Wed, 6 Mar 2024 23:27:19 +0530 Subject: [PATCH 2/2] Update cross_validate.py - Moved users to the last - used dataset.user_id_map.convert_to_internal --- rectools/model_selection/cross_validate.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rectools/model_selection/cross_validate.py b/rectools/model_selection/cross_validate.py index 87f9b378..ebb986de 100644 --- a/rectools/model_selection/cross_validate.py +++ b/rectools/model_selection/cross_validate.py @@ -48,8 +48,8 @@ def cross_validate( # pylint: disable=too-many-locals models: tp.Dict[str, ModelBase], k: int, filter_viewed: bool, - users: tp.Optional[ExternalIds] = None, items_to_recommend: tp.Optional[ExternalIds] = None, + users: tp.Optional[ExternalIds] = None, ) -> tp.Dict[str, tp.Any]: """ Run cross validation on multiple models with multiple metrics. @@ -117,7 +117,8 @@ def cross_validate( # pylint: disable=too-many-locals interactions_df_test = interactions.df.iloc[test_ids] # 1x internal if users is not None: - interactions_df_test = interactions_df_test[interactions_df_test[Columns.User].isin(users)] + internal_users = dataset.user_id_map.convert_to_internal(users, strict=False) + interactions_df_test = interactions_df_test[interactions_df_test[Columns.User].isin(internal_users)] test_users = interactions_df_test[Columns.User].unique() # 1x internal catalog = interactions_df_train[Columns.Item].unique() # 1x internal