From 1fdc000fbe5720575fabdb6ef3feb69953d8ddf9 Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Tue, 24 Jun 2025 10:52:44 +0300 Subject: [PATCH 1/7] add extra columns --- rectools/models/nn/transformers/bert4rec.py | 16 ++- .../models/nn/transformers/data_preparator.py | 64 ++++++++--- rectools/models/nn/transformers/sasrec.py | 18 +-- rectools/models/nn/transformers/similarity.py | 6 +- tests/models/nn/transformers/test_bert4rec.py | 4 +- .../nn/transformers/test_data_preparator.py | 105 ++++++++++-------- 6 files changed, 136 insertions(+), 77 deletions(-) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 8e31d6ff..552e5ae8 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -128,7 +128,7 @@ def _mask_session( def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> Dict[str, torch.Tensor]: """ Mask session elements to receive `x`. @@ -141,7 +141,7 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): masked_session, target = self._mask_session(ses) x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len] @@ -154,12 +154,14 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val( + self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] + ) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # until only leave-one-strategy yw = np.zeros((batch_size, 1)) # until only leave-one-strategy - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] session = input_session.copy() @@ -179,14 +181,16 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend( + self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] + ) -> Dict[str, torch.Tensor]: """ Right truncation, left padding to `session_max_len` During inference model will use (`session_max_len` - 1) interactions and one extra "MASK" token will be added for making predictions. """ x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in enumerate(batch): + for i, (ses, _, _) in enumerate(batch): session = ses.copy() session = session + [self.extra_token_ids[MASKING_VALUE]] x[i, -len(ses) - 1 :] = session[-self.session_max_len :] diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index b13ec87d..b6028d6d 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -46,23 +46,33 @@ class SequenceDataset(TorchDataset): Weight of each interaction from the session. 
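+    extras : Optional[Dict[str, List[Any]]], default ``None``
+        Extra values for each interaction from the session, keyed by column name.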
""" - def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]): + def __init__( + self, + sessions: tp.List[tp.List[int]], + weights: tp.List[tp.List[float]], + extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None, + ): self.sessions = sessions self.weights = weights + self.extras = extras def __len__(self) -> int: return len(self.sessions) - def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]: + def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]: session = self.sessions[index] # [session_len] weights = self.weights[index] # [session_len] - return session, weights + extras = ( + {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {} + ) + return session, weights, extras @classmethod def from_interactions( cls, interactions: pd.DataFrame, sort_users: bool = False, + extra_cols: tp.Optional[tp.List[str]] = None, ) -> "SequenceDataset": """ Group interactions by user. @@ -73,17 +83,18 @@ def from_interactions( interactions : pd.DataFrame User-item interactions. """ + cols_to_agg = [col for col in interactions.columns if col != Columns.User] sessions = ( interactions.sort_values(Columns.Datetime, kind="stable") - .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]] + .groupby(Columns.User, sort=sort_users)[cols_to_agg] .agg(list) ) - sessions, weights = ( + sessions_items, weights = ( sessions[Columns.Item].to_list(), sessions[Columns.Weight].to_list(), ) - - return cls(sessions=sessions, weights=weights) + extras = {col: sessions[col].to_list() for col in extra_cols} if extra_cols else None + return cls(sessions=sessions_items, weights=weights, extras=extras) class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes @@ -133,6 +144,7 @@ def __init__( n_negatives: tp.Optional[int] = None, negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, + extra_cols_kwargs: tp.Optional[InitKwargs] = None, **kwargs: tp.Any, ) -> None: self.item_id_map: IdMap @@ -148,6 +160,7 @@ def __init__( self.shuffle_train = shuffle_train self.get_val_mask_func = get_val_mask_func self.get_val_mask_func_kwargs = get_val_mask_func_kwargs + self.extra_cols_kwargs = extra_cols_kwargs def get_known_items_sorted_internal_ids(self) -> np.ndarray: """Return internal item ids from processed dataset in sorted order.""" @@ -231,7 +244,10 @@ def process_dataset_train(self, dataset: Dataset) -> None: # Prepare train dataset # User features are dropped for now because model doesn't support them - final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True) + keep_extra_cols = self.extra_cols_kwargs is not None + final_interactions = Interactions.from_raw( + interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols + ) self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features) self.item_id_map = self.train_dataset.item_id_map self._init_extra_token_ids() @@ -246,7 +262,9 @@ def process_dataset_train(self, dataset: Dataset) -> None: val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy() val_interactions[Columns.Weight] = 0 val_interactions = pd.concat([val_interactions, val_targets], axis=0) - self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df + 
self.val_interactions = Interactions.from_raw( + val_interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols + ).df def _init_extra_token_ids(self) -> None: extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens) @@ -261,7 +279,9 @@ def get_dataloader_train(self) -> DataLoader: DataLoader Train dataloader. """ - sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df) + sequence_dataset = SequenceDataset.from_interactions( + self.train_dataset.interactions.df, **self._ensure_kwargs_dict(self.extra_cols_kwargs) + ) train_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_train, @@ -283,7 +303,9 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]: if self.val_interactions is None: return None - sequence_dataset = SequenceDataset.from_interactions(self.val_interactions) + sequence_dataset = SequenceDataset.from_interactions( + self.val_interactions, **self._ensure_kwargs_dict(self.extra_cols_kwargs) + ) val_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_val, @@ -306,7 +328,9 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa # User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations. # Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix # that will be returned by the net. - sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True) + sequence_dataset = SequenceDataset.from_interactions( + interactions=dataset.interactions.df, sort_users=True, **self._ensure_kwargs_dict(self.extra_cols_kwargs) + ) recommend_dataloader = DataLoader( sequence_dataset, batch_size=batch_size, @@ -359,7 +383,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset if n_filtered > 0: explanation = f"""{n_filtered} target users were considered cold because of missing known items""" warnings.warn(explanation) - filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map) + keep_extra_cols = self.extra_cols_kwargs is not None + filtered_interactions = Interactions.from_raw( + interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=keep_extra_cols + ) filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset @@ -383,24 +410,27 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: """ interactions = dataset.get_raw_interactions() interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] - filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map) + keep_extra_cols = self.extra_cols_kwargs is not None + filtered_interactions = Interactions.from_raw( + interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols + ) filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_val( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def 
_collate_fn_recommend( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index b8350f72..3491b33e 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -13,7 +13,7 @@ # limitations under the License. import typing as tp -from typing import Dict, List, Tuple +from typing import Dict import numpy as np import torch @@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase): def _collate_fn_train( self, - batch: List[Tuple[List[int], List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> Dict[str, torch.Tensor]: """ Truncate each session from right to keep `session_max_len` items. @@ -91,7 +91,7 @@ def _collate_fn_train( x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len] y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len] yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len] @@ -103,12 +103,14 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_val( + self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] + ) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0] # take only first target for leave-one-strategy @@ -126,10 +128,12 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st ) return batch_dict - def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend( + self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] + ) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) - for i, (ses, _) in enumerate(batch): + for i, (ses, _, _) in enumerate(batch): x[i, -len(ses) :] = ses[-self.session_max_len :] return {"x": torch.LongTensor(x)} diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index da1ac615..40f82011 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -62,7 +62,11 @@ class DistanceSimilarityModule(SimilarityModuleBase): dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE] epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8]) - def __init__(self, distance: str = "dot") -> None: + def __init__( + self, + distance: str = "dot", + **kwargs: tp.Any, + ) -> None: super().__init__() if 
distance not in self.dist_available: raise ValueError("`dist` can only be either `dot` or `cosine`.") diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 140389aa..526dcc5f 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -640,13 +640,13 @@ def __init__( def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]], + batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], ) -> tp.Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, self.session_max_len)) yw = np.zeros((batch_size, self.session_max_len)) - for i, (ses, ses_weights) in enumerate(batch): + for i, (ses, ses_weights, _) in enumerate(batch): y[i, -self.n_last_targets] = ses[-self.n_last_targets] yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets] x[i, -len(ses) :] = ses diff --git a/tests/models/nn/transformers/test_data_preparator.py b/tests/models/nn/transformers/test_data_preparator.py index 5f41ea8e..33a56b77 100644 --- a/tests/models/nn/transformers/test_data_preparator.py +++ b/tests/models/nn/transformers/test_data_preparator.py @@ -31,38 +31,53 @@ class TestSequenceDataset: def interactions_df(self) -> pd.DataFrame: interactions_df = pd.DataFrame( [ - [10, 13, 1, "2021-11-30"], - [10, 11, 1, "2021-11-29"], - [10, 12, 4, "2021-11-29"], - [30, 11, 1, "2021-11-27"], - [30, 12, 2, "2021-11-26"], - [30, 15, 1, "2021-11-25"], - [40, 11, 1, "2021-11-25"], - [40, 17, 8, "2021-11-26"], - [50, 16, 1, "2021-11-25"], - [10, 14, 1, "2021-11-28"], + [10, 13, 1, "2021-11-30", 0], + [10, 11, 1, "2021-11-29", 1], + [10, 12, 4, "2021-11-29", 1], + [30, 11, 1, "2021-11-27", 0], + [30, 12, 2, "2021-11-26", 1], + [30, 15, 1, "2021-11-25", 1], + [40, 11, 1, "2021-11-25", 2], + [40, 17, 8, "2021-11-26", 1], + [50, 16, 1, "2021-11-25", 0], + [10, 14, 1, "2021-11-28", 0], ], - columns=Columns.Interactions, + columns=Columns.Interactions + ["extra_column"], ) return interactions_df @pytest.mark.parametrize( - "expected_sessions, expected_weights", - (([[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]]),), + "expected_sessions, expected_weights, expected_extra_column", + ( + ( + [[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], + [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]], + [[0, 1, 1, 0], [1, 1, 0], [2, 1], [0]], + ), + ), ) def test_from_interactions( self, interactions_df: pd.DataFrame, expected_sessions: tp.List[tp.List[int]], expected_weights: tp.List[tp.List[float]], + expected_extra_column: tp.Dict[str, tp.List[tp.Any]], ) -> None: - actual = SequenceDataset.from_interactions(interactions=interactions_df, sort_users=True) + actual = SequenceDataset.from_interactions( + interactions=interactions_df, sort_users=True, extra_cols=["extra_column"] + ) assert len(actual.sessions) == len(expected_sessions) assert all( actual_list == expected_list for actual_list, expected_list in zip(actual.sessions, expected_sessions) ) assert len(actual.weights) == len(expected_weights) assert all(actual_list == expected_list for actual_list, expected_list in zip(actual.weights, expected_weights)) + assert actual.extras is not None + assert len(actual.extras["extra_column"]) == len(expected_extra_column) + assert all( + actual_list == expected_list + for actual_list, expected_list in zip(actual.extras["extra_column"], expected_extra_column) + ) class 
TestTransformerDataPreparatorBase: @@ -71,26 +86,26 @@ class TestTransformerDataPreparatorBase: def interactions_df(self) -> pd.DataFrame: interactions_df = pd.DataFrame( [ - [10, 13, 1, "2021-11-30"], - [10, 11, 1, "2021-11-29"], - [10, 12, 1, "2021-11-29"], - [30, 11, 1, "2021-11-27"], - [30, 12, 2, "2021-11-26"], - [30, 15, 1, "2021-11-25"], - [40, 11, 1, "2021-11-25"], - [40, 17, 1, "2021-11-26"], - [50, 16, 1, "2021-11-25"], - [10, 14, 1, "2021-11-28"], - [10, 16, 1, "2021-11-27"], - [20, 13, 9, "2021-11-28"], + [10, 13, 1, "2021-11-30", 0], + [10, 11, 1, "2021-11-29", 2], + [10, 12, 1, "2021-11-29", 3], + [30, 11, 1, "2021-11-27", 4], + [30, 12, 2, "2021-11-26", 1], + [30, 15, 1, "2021-11-25", 0], + [40, 11, 1, "2021-11-25", 1], + [40, 17, 1, "2021-11-26", 1], + [50, 16, 1, "2021-11-25", 2], + [10, 14, 1, "2021-11-28", 2], + [10, 16, 1, "2021-11-27", 1], + [20, 13, 9, "2021-11-28", 1], ], - columns=Columns.Interactions, + columns=Columns.Interactions + ["extra_column"], ) return interactions_df @pytest.fixture def dataset(self, interactions_df: pd.DataFrame) -> Dataset: - return Dataset.construct(interactions_df) + return Dataset.construct(interactions_df, keep_extra_cols=True) @pytest.fixture def dataset_dense_item_features(self, interactions_df: pd.DataFrame) -> Dataset: @@ -130,17 +145,17 @@ def data_preparator(self) -> TransformerDataPreparatorBase: Interactions( pd.DataFrame( [ - [0, 1, 1.0, "2021-11-25"], - [1, 2, 1.0, "2021-11-25"], - [0, 3, 2.0, "2021-11-26"], - [1, 4, 1.0, "2021-11-26"], - [0, 2, 1.0, "2021-11-27"], - [2, 5, 1.0, "2021-11-28"], - [2, 2, 1.0, "2021-11-29"], - [2, 3, 1.0, "2021-11-29"], - [2, 6, 1.0, "2021-11-30"], + [0, 1, 1.0, "2021-11-25", 0], + [1, 2, 1.0, "2021-11-25", 1], + [0, 3, 2.0, "2021-11-26", 1], + [1, 4, 1.0, "2021-11-26", 1], + [0, 2, 1.0, "2021-11-27", 4], + [2, 5, 1.0, "2021-11-28", 2], + [2, 2, 1.0, "2021-11-29", 2], + [2, 3, 1.0, "2021-11-29", 3], + [2, 6, 1.0, "2021-11-30", 0], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ), @@ -154,6 +169,7 @@ def test_process_dataset_train( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: + data_preparator.extra_cols_kwargs = {"extra_cols": "extra_column"} data_preparator.process_dataset_train(dataset) actual = data_preparator.train_dataset assert_id_map_equal(actual.user_id_map, expected_user_id_map) @@ -192,13 +208,13 @@ def test_process_dataset_train_with_dense_item_features( Interactions( pd.DataFrame( [ - [0, 6, 1.0, "2021-11-30"], - [0, 2, 1.0, "2021-11-29"], - [0, 3, 1.0, "2021-11-29"], - [0, 5, 1.0, "2021-11-28"], - [1, 6, 9.0, "2021-11-28"], + [0, 6, 1.0, "2021-11-30", 0], + [0, 2, 1.0, "2021-11-29", 2], + [0, 3, 1.0, "2021-11-29", 3], + [0, 5, 1.0, "2021-11-28", 2], + [1, 6, 9.0, "2021-11-28", 1], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ), @@ -212,6 +228,7 @@ def test_transform_dataset_u2i( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: + data_preparator.extra_cols_kwargs = {"extra_cols": "extra_column"} data_preparator.process_dataset_train(dataset) users = [10, 20] actual = data_preparator.transform_dataset_u2i(dataset, users) From 2f8d7cfd4eab878c986401da1718125498fb15a0 Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Tue, 24 Jun 2025 20:50:32 +0300 Subject: [PATCH 2/7] change 
extr_cols kwarg --- .../models/nn/transformers/data_preparator.py | 49 ++++++++++--------- .../nn/transformers/test_data_preparator.py | 8 ++- 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index b6028d6d..36a3d65c 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -72,7 +72,6 @@ def from_interactions( cls, interactions: pd.DataFrame, sort_users: bool = False, - extra_cols: tp.Optional[tp.List[str]] = None, ) -> "SequenceDataset": """ Group interactions by user. @@ -93,7 +92,8 @@ def from_interactions( sessions[Columns.Item].to_list(), sessions[Columns.Weight].to_list(), ) - extras = {col: sessions[col].to_list() for col in extra_cols} if extra_cols else None + extra_cols = [col for col in interactions.columns if col not in Columns.Interactions] + extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None return cls(sessions=sessions_items, weights=weights, extras=extras) @@ -144,7 +144,7 @@ def __init__( n_negatives: tp.Optional[int] = None, negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None, get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, - extra_cols_kwargs: tp.Optional[InitKwargs] = None, + extra_cols: tp.Optional[tp.List[str]] = None, **kwargs: tp.Any, ) -> None: self.item_id_map: IdMap @@ -160,7 +160,7 @@ def __init__( self.shuffle_train = shuffle_train self.get_val_mask_func = get_val_mask_func self.get_val_mask_func_kwargs = get_val_mask_func_kwargs - self.extra_cols_kwargs = extra_cols_kwargs + self.extra_cols = extra_cols def get_known_items_sorted_internal_ids(self) -> np.ndarray: """Return internal item ids from processed dataset in sorted order.""" @@ -216,7 +216,10 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat def process_dataset_train(self, dataset: Dataset) -> None: """Process train dataset and save data.""" - raw_interactions = dataset.get_raw_interactions() + if self.extra_cols is None: + raw_interactions = dataset.get_raw_interactions(include_extra_cols=False) + else: + raw_interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols] # Exclude val interaction targets from train if needed interactions = raw_interactions @@ -244,9 +247,11 @@ def process_dataset_train(self, dataset: Dataset) -> None: # Prepare train dataset # User features are dropped for now because model doesn't support them - keep_extra_cols = self.extra_cols_kwargs is not None final_interactions = Interactions.from_raw( - interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols + interactions, + user_id_map, + item_id_map, + keep_extra_cols=self.extra_cols is not None, ) self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features) self.item_id_map = self.train_dataset.item_id_map @@ -263,7 +268,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: val_interactions[Columns.Weight] = 0 val_interactions = pd.concat([val_interactions, val_targets], axis=0) self.val_interactions = Interactions.from_raw( - val_interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols + val_interactions, user_id_map, item_id_map, keep_extra_cols=self.extra_cols is not None ).df def _init_extra_token_ids(self) -> None: @@ -279,9 +284,7 @@ def get_dataloader_train(self) -> DataLoader: DataLoader Train dataloader. 
""" - sequence_dataset = SequenceDataset.from_interactions( - self.train_dataset.interactions.df, **self._ensure_kwargs_dict(self.extra_cols_kwargs) - ) + sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df) train_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_train, @@ -303,9 +306,7 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]: if self.val_interactions is None: return None - sequence_dataset = SequenceDataset.from_interactions( - self.val_interactions, **self._ensure_kwargs_dict(self.extra_cols_kwargs) - ) + sequence_dataset = SequenceDataset.from_interactions(self.val_interactions) val_dataloader = DataLoader( sequence_dataset, collate_fn=self._collate_fn_val, @@ -328,9 +329,7 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa # User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations. # Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix # that will be returned by the net. - sequence_dataset = SequenceDataset.from_interactions( - interactions=dataset.interactions.df, sort_users=True, **self._ensure_kwargs_dict(self.extra_cols_kwargs) - ) + sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True) recommend_dataloader = DataLoader( sequence_dataset, batch_size=batch_size, @@ -364,7 +363,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset Final item_id_map is model item_id_map constructed during training. """ # Filter interactions in dataset internal ids - interactions = dataset.interactions.df + if self.extra_cols is None: + interactions = dataset.interactions.df[Columns.Interactions] + else: + interactions = dataset.interactions.df[Columns.Interactions + self.extra_cols] users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) interactions = interactions[interactions[Columns.User].isin(users_internal)] @@ -383,9 +385,8 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset if n_filtered > 0: explanation = f"""{n_filtered} target users were considered cold because of missing known items""" warnings.warn(explanation) - keep_extra_cols = self.extra_cols_kwargs is not None filtered_interactions = Interactions.from_raw( - interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=keep_extra_cols + interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=self.extra_cols is not None ) filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset @@ -408,11 +409,13 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: Final user_id_map is the same as dataset original. Final item_id_map is model item_id_map constructed during training. 
""" - interactions = dataset.get_raw_interactions() + if self.extra_cols is None: + interactions = dataset.get_raw_interactions(include_extra_cols=False) + else: + interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols] interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] - keep_extra_cols = self.extra_cols_kwargs is not None filtered_interactions = Interactions.from_raw( - interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols + interactions, dataset.user_id_map, self.item_id_map, self.extra_cols is not None ) filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset diff --git a/tests/models/nn/transformers/test_data_preparator.py b/tests/models/nn/transformers/test_data_preparator.py index 33a56b77..dac33aaf 100644 --- a/tests/models/nn/transformers/test_data_preparator.py +++ b/tests/models/nn/transformers/test_data_preparator.py @@ -63,9 +63,7 @@ def test_from_interactions( expected_weights: tp.List[tp.List[float]], expected_extra_column: tp.Dict[str, tp.List[tp.Any]], ) -> None: - actual = SequenceDataset.from_interactions( - interactions=interactions_df, sort_users=True, extra_cols=["extra_column"] - ) + actual = SequenceDataset.from_interactions(interactions=interactions_df, sort_users=True) assert len(actual.sessions) == len(expected_sessions) assert all( actual_list == expected_list for actual_list, expected_list in zip(actual.sessions, expected_sessions) @@ -169,7 +167,7 @@ def test_process_dataset_train( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: - data_preparator.extra_cols_kwargs = {"extra_cols": "extra_column"} + data_preparator.extra_cols = ["extra_column"] data_preparator.process_dataset_train(dataset) actual = data_preparator.train_dataset assert_id_map_equal(actual.user_id_map, expected_user_id_map) @@ -228,7 +226,7 @@ def test_transform_dataset_u2i( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: - data_preparator.extra_cols_kwargs = {"extra_cols": "extra_column"} + data_preparator.extra_cols = ["extra_column"] data_preparator.process_dataset_train(dataset) users = [10, 20] actual = data_preparator.transform_dataset_u2i(dataset, users) From 87822f7fe104a1469dc8a31628112107d1c75fdf Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Tue, 24 Jun 2025 21:45:00 +0300 Subject: [PATCH 3/7] test --- tests/models/nn/transformers/test_data_preparator.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/models/nn/transformers/test_data_preparator.py b/tests/models/nn/transformers/test_data_preparator.py index dac33aaf..c8e83c94 100644 --- a/tests/models/nn/transformers/test_data_preparator.py +++ b/tests/models/nn/transformers/test_data_preparator.py @@ -47,12 +47,12 @@ def interactions_df(self) -> pd.DataFrame: return interactions_df @pytest.mark.parametrize( - "expected_sessions, expected_weights, expected_extra_column", + "expected_sessions, expected_weights, expected_extras", ( ( [[14, 11, 12, 13], [15, 12, 11], [11, 17], [16]], [[1, 1, 4, 1], [1, 2, 1], [1, 8], [1]], - [[0, 1, 1, 0], [1, 1, 0], [2, 1], [0]], + {"extra_column": [[0, 1, 1, 0], [1, 1, 0], [2, 1], [0]]}, ), ), ) @@ -61,7 +61,7 @@ def test_from_interactions( interactions_df: pd.DataFrame, expected_sessions: tp.List[tp.List[int]], expected_weights: tp.List[tp.List[float]], - expected_extra_column: tp.Dict[str, tp.List[tp.Any]], + expected_extras: tp.Dict[str, tp.List[tp.Any]], ) -> None: 
actual = SequenceDataset.from_interactions(interactions=interactions_df, sort_users=True) assert len(actual.sessions) == len(expected_sessions) @@ -71,10 +71,10 @@ def test_from_interactions( assert len(actual.weights) == len(expected_weights) assert all(actual_list == expected_list for actual_list, expected_list in zip(actual.weights, expected_weights)) assert actual.extras is not None - assert len(actual.extras["extra_column"]) == len(expected_extra_column) + assert len(actual.extras["extra_column"]) == len(expected_extras["extra_column"]) assert all( actual_list == expected_list - for actual_list, expected_list in zip(actual.extras["extra_column"], expected_extra_column) + for actual_list, expected_list in zip(actual.extras["extra_column"], expected_extras["extra_column"]) ) From f8b422d87328b1ed2a00b929e8a60da3b03d5dba Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Wed, 25 Jun 2025 16:51:44 +0300 Subject: [PATCH 4/7] add session, item towers to similarity module --- .../models/nn/transformers/data_preparator.py | 6 +++--- rectools/models/nn/transformers/lightning.py | 2 ++ rectools/models/nn/transformers/similarity.py | 16 ++++++++++++++++ 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 36a3d65c..d2ba9973 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -251,7 +251,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: interactions, user_id_map, item_id_map, - keep_extra_cols=self.extra_cols is not None, + keep_extra_cols=True, ) self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features) self.item_id_map = self.train_dataset.item_id_map @@ -268,7 +268,7 @@ def process_dataset_train(self, dataset: Dataset) -> None: val_interactions[Columns.Weight] = 0 val_interactions = pd.concat([val_interactions, val_targets], axis=0) self.val_interactions = Interactions.from_raw( - val_interactions, user_id_map, item_id_map, keep_extra_cols=self.extra_cols is not None + val_interactions, user_id_map, item_id_map, keep_extra_cols=True ).df def _init_extra_token_ids(self) -> None: @@ -386,7 +386,7 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset explanation = f"""{n_filtered} target users were considered cold because of missing known items""" warnings.warn(explanation) filtered_interactions = Interactions.from_raw( - interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=self.extra_cols is not None + interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True ) filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py index df97f882..15eeba8c 100644 --- a/rectools/models/nn/transformers/lightning.py +++ b/rectools/models/nn/transformers/lightning.py @@ -387,7 +387,9 @@ def _get_user_item_embeddings( for batch in recommend_dataloader: batch = {k: v.to(device) for k, v in batch.items()} batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :] + batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs) user_embs.append(batch_embs.cpu()) + item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs) return torch.cat(user_embs), item_embs diff --git 
a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index 40f82011..711d0621 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -34,6 +34,14 @@ def _get_pos_neg_logits( ) -> torch.Tensor: raise NotImplementedError() + def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for session tower.""" + raise NotImplementedError() + + def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for item tower.""" + raise NotImplementedError() + def forward( self, session_embs: torch.Tensor, @@ -91,6 +99,14 @@ def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings)) return embeddings + def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for session tower.""" + return session_embs + + def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor: + """Forward pass for item tower.""" + return item_embs + def forward( self, session_embs: torch.Tensor, From 670d0c26725943d45750134cb62f5373c0389acd Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Thu, 26 Jun 2025 13:11:40 +0300 Subject: [PATCH 5/7] BatchElement --- rectools/models/nn/transformers/bert4rec.py | 12 +++------ .../models/nn/transformers/data_preparator.py | 11 ++++---- rectools/models/nn/transformers/sasrec.py | 12 +++------ rectools/models/nn/transformers/similarity.py | 12 ++------- tests/models/nn/transformers/test_bert4rec.py | 4 +-- .../nn/transformers/test_data_preparator.py | 26 +++++++++---------- 6 files changed, 31 insertions(+), 46 deletions(-) diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 552e5ae8..a58a4502 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -36,7 +36,7 @@ ValMaskCallable, ) from .constants import MASKING_VALUE, PADDING_VALUE -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -128,7 +128,7 @@ def _mask_session( def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Mask session elements to receive `x`. 
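For reference, `BatchElement`, imported above, is the tuple alias this patch introduces in data_preparator.py (see the hunk below). A minimal sketch of how a collate function unpacks one element; the sample values are illustrative:

    import typing as tp

    # (user session, session weights, extra columns) -- alias added by this patch
    BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]

    # Illustrative element: a three-item session, its weights, and one extra column
    element: BatchElement = ([3, 7, 2], [1.0, 1.0, 4.0], {"extra_column": [0, 1, 1]})
    ses, ses_weights, extras = element  # matches `for i, (ses, ses_weights, _) in enumerate(batch)`
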
@@ -154,9 +154,7 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val( - self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] - ) -> Dict[str, torch.Tensor]: + def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # until only leave-one-strategy @@ -181,9 +179,7 @@ def _collate_fn_val( ) return batch_dict - def _collate_fn_recommend( - self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] - ) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: """ Right truncation, left padding to `session_max_len` During inference model will use (`session_max_len` - 1) interactions diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index d2ba9973..56c43a6d 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -32,6 +32,7 @@ from .negative_sampler import TransformerNegativeSamplerBase InitKwargs = tp.Dict[str, tp.Any] +BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]] class SequenceDataset(TorchDataset): @@ -59,7 +60,7 @@ def __init__( def __len__(self) -> int: return len(self.sessions) - def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]: + def __getitem__(self, index: int) -> BatchElement: session = self.sessions[index] # [session_len] weights = self.weights[index] # [session_len] extras = ( @@ -415,25 +416,25 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols] interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] filtered_interactions = Interactions.from_raw( - interactions, dataset.user_id_map, self.item_id_map, self.extra_cols is not None + interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True ) filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions) return filtered_dataset def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_val( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() def _collate_fn_recommend( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: raise NotImplementedError() diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index 3491b33e..a3f0c73b 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -36,7 +36,7 @@ TransformerModelConfig, ValMaskCallable, ) -from .data_preparator import InitKwargs, TransformerDataPreparatorBase +from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -80,7 +80,7 @@ class 
SASRecDataPreparator(TransformerDataPreparatorBase): def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> Dict[str, torch.Tensor]: """ Truncate each session from right to keep `session_max_len` items. @@ -103,9 +103,7 @@ def _collate_fn_train( ) return batch_dict - def _collate_fn_val( - self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] - ) -> Dict[str, torch.Tensor]: + def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses @@ -128,9 +126,7 @@ def _collate_fn_val( ) return batch_dict - def _collate_fn_recommend( - self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]] - ) -> Dict[str, torch.Tensor]: + def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]: """Right truncation, left padding to session_max_len""" x = np.zeros((len(batch), self.session_max_len)) for i, (ses, _, _) in enumerate(batch): diff --git a/rectools/models/nn/transformers/similarity.py b/rectools/models/nn/transformers/similarity.py index 711d0621..006f1ba0 100644 --- a/rectools/models/nn/transformers/similarity.py +++ b/rectools/models/nn/transformers/similarity.py @@ -36,11 +36,11 @@ def _get_pos_neg_logits( def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor: """Forward pass for session tower.""" - raise NotImplementedError() + return session_embs def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor: """Forward pass for item tower.""" - raise NotImplementedError() + return item_embs def forward( self, @@ -99,14 +99,6 @@ def _get_embeddings_norm(self, embeddings: torch.Tensor) -> torch.Tensor: embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings)) return embeddings - def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor: - """Forward pass for session tower.""" - return session_embs - - def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor: - """Forward pass for item tower.""" - return item_embs - def forward( self, session_embs: torch.Tensor, diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 526dcc5f..f89e82ee 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -34,7 +34,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable -from rectools.models.nn.transformers.data_preparator import InitKwargs +from rectools.models.nn.transformers.data_preparator import BatchElement, InitKwargs from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone @@ -640,7 +640,7 @@ def __init__( def _collate_fn_train( self, - batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]], + batch: tp.List[BatchElement], ) -> tp.Dict[str, torch.Tensor]: batch_size = len(batch) x = np.zeros((batch_size, self.session_max_len)) diff --git a/tests/models/nn/transformers/test_data_preparator.py 
b/tests/models/nn/transformers/test_data_preparator.py index c8e83c94..ec6bcd67 100644 --- a/tests/models/nn/transformers/test_data_preparator.py +++ b/tests/models/nn/transformers/test_data_preparator.py @@ -132,6 +132,7 @@ def data_preparator(self) -> TransformerDataPreparatorBase: session_max_len=4, batch_size=4, dataloader_num_workers=0, + extra_cols=["extra_column"], ) @pytest.mark.parametrize( @@ -167,7 +168,6 @@ def test_process_dataset_train( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: - data_preparator.extra_cols = ["extra_column"] data_preparator.process_dataset_train(dataset) actual = data_preparator.train_dataset assert_id_map_equal(actual.user_id_map, expected_user_id_map) @@ -179,6 +179,7 @@ def test_process_dataset_train_with_dense_item_features( dataset_dense_item_features: Dataset, data_preparator: TransformerDataPreparatorBase, ) -> None: + data_preparator.extra_cols = None data_preparator.process_dataset_train(dataset_dense_item_features) actual = data_preparator.train_dataset.item_features expected_values = np.array( @@ -226,7 +227,6 @@ def test_transform_dataset_u2i( expected_item_id_map: IdMap, expected_user_id_map: IdMap, ) -> None: - data_preparator.extra_cols = ["extra_column"] data_preparator.process_dataset_train(dataset) users = [10, 20] actual = data_preparator.transform_dataset_u2i(dataset, users) @@ -243,18 +243,18 @@ def test_transform_dataset_u2i( Interactions( pd.DataFrame( [ - [0, 6, 1.0, "2021-11-30"], - [0, 2, 1.0, "2021-11-29"], - [0, 3, 1.0, "2021-11-29"], - [1, 2, 1.0, "2021-11-27"], - [1, 3, 2.0, "2021-11-26"], - [1, 1, 1.0, "2021-11-25"], - [2, 2, 1.0, "2021-11-25"], - [2, 4, 1.0, "2021-11-26"], - [0, 5, 1.0, "2021-11-28"], - [4, 6, 9.0, "2021-11-28"], + [0, 6, 1.0, "2021-11-30", 0], + [0, 2, 1.0, "2021-11-29", 2], + [0, 3, 1.0, "2021-11-29", 3], + [1, 2, 1.0, "2021-11-27", 4], + [1, 3, 2.0, "2021-11-26", 1], + [1, 1, 1.0, "2021-11-25", 0], + [2, 2, 1.0, "2021-11-25", 1], + [2, 4, 1.0, "2021-11-26", 1], + [0, 5, 1.0, "2021-11-28", 2], + [4, 6, 9.0, "2021-11-28", 1], ], - columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime, "extra_column"], ), ), ), From 24c19bc907cf5b09500b5b2d5c9446a57e63cfa8 Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Thu, 26 Jun 2025 15:11:36 +0300 Subject: [PATCH 6/7] changelog --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 558ee220..acc13ced 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). 
+## Unreleased +### Added +- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287)) + ## [0.14.0] - 16.05.2025 ### Added From 0a0a2aa00e35d6b1404a91817bd48bae0c0e3122 Mon Sep 17 00:00:00 2001 From: Spirina Majya Aleksandrovna Date: Mon, 30 Jun 2025 16:50:58 +0300 Subject: [PATCH 7/7] change get_raw_interactions --- rectools/dataset/dataset.py | 5 +++- rectools/dataset/interactions.py | 15 ++++++++---- .../models/nn/transformers/data_preparator.py | 23 +++++++++---------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/rectools/dataset/dataset.py b/rectools/dataset/dataset.py index 6d7a7d52..22cd71fd 100644 --- a/rectools/dataset/dataset.py +++ b/rectools/dataset/dataset.py @@ -348,7 +348,10 @@ def get_user_item_matrix( return matrix def get_raw_interactions( - self, include_weight: bool = True, include_datetime: bool = True, include_extra_cols: bool = True + self, + include_weight: bool = True, + include_datetime: bool = True, + include_extra_cols: tp.Union[bool, tp.List[str]] = True, ) -> pd.DataFrame: """ Return interactions as a `pd.DataFrame` object with replacing internal user and item ids to external ones. diff --git a/rectools/dataset/interactions.py b/rectools/dataset/interactions.py index 3f06ba70..2cfda50a 100644 --- a/rectools/dataset/interactions.py +++ b/rectools/dataset/interactions.py @@ -167,7 +167,7 @@ def to_external( item_id_map: IdMap, include_weight: bool = True, include_datetime: bool = True, - include_extra_cols: bool = True, + include_extra_cols: tp.Union[bool, tp.List[str]] = True, ) -> pd.DataFrame: """ Convert itself to `pd.DataFrame` with replacing internal user and item ids to external ones. @@ -182,8 +182,9 @@ def to_external( Whether to include weight column into resulting table or not include_datetime : bool, default ``True`` Whether to include datetime column into resulting table or not. - include_extra_cols: bool, default ``True`` - Whether to include extra columns into resulting table or not. + include_extra_cols: bool or List[str], default ``True`` + If bool, indicates whether to include all extra columns into resulting table or not. + If list of strings, indicates which extra columns to include into resulting table. 
Returns ------- @@ -201,9 +202,13 @@ def to_external( cols_to_add.append(Columns.Weight) if include_datetime: cols_to_add.append(Columns.Datetime) - if include_extra_cols: + + extra_cols = [] + if isinstance(include_extra_cols, list): + extra_cols = [col for col in include_extra_cols if col in self.df and col not in Columns.Interactions] + elif include_extra_cols: extra_cols = [col for col in self.df if col not in Columns.Interactions] - cols_to_add.extend(extra_cols) + cols_to_add.extend(extra_cols) for col in cols_to_add: res[col] = self.df[col] diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 56c43a6d..40993d40 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -32,6 +32,7 @@ from .negative_sampler import TransformerNegativeSamplerBase InitKwargs = tp.Dict[str, tp.Any] +# (user session, session weights, extra columns) BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]] @@ -126,6 +127,8 @@ class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attrib get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` Additional keyword arguments for the get_val_mask_func. Make sure all dict values have JSON serializable types. + extra_cols: optional(List[str]), default ``None`` + Extra columns to keep in train and recommend datasets. """ # We sometimes need data preparators to add +1 to actual session_max_len @@ -217,10 +220,8 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat def process_dataset_train(self, dataset: Dataset) -> None: """Process train dataset and save data.""" - if self.extra_cols is None: - raw_interactions = dataset.get_raw_interactions(include_extra_cols=False) - else: - raw_interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols] + extra_cols = False if self.extra_cols is None else self.extra_cols + raw_interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols) # Exclude val interaction targets from train if needed interactions = raw_interactions @@ -364,10 +365,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset Final item_id_map is model item_id_map constructed during training. """ # Filter interactions in dataset internal ids - if self.extra_cols is None: - interactions = dataset.interactions.df[Columns.Interactions] - else: - interactions = dataset.interactions.df[Columns.Interactions + self.extra_cols] + required_cols = Columns.Interactions + if self.extra_cols is not None: + required_cols = required_cols + self.extra_cols + interactions = dataset.interactions.df[required_cols] users_internal = dataset.user_id_map.convert_to_internal(users, strict=False) items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False) interactions = interactions[interactions[Columns.User].isin(users_internal)] @@ -410,10 +411,8 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset: Final user_id_map is the same as dataset original. Final item_id_map is model item_id_map constructed during training. 
""" - if self.extra_cols is None: - interactions = dataset.get_raw_interactions(include_extra_cols=False) - else: - interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols] + extra_cols = False if self.extra_cols is None else self.extra_cols + interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols) interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())] filtered_interactions = Interactions.from_raw( interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True