def leave_one_out_mask(interactions: pd.DataFrame, val_users: tp.Union[ExternalIds, int, None] = None) -> np.ndarray:
    """
    Create a boolean mask for leave-one-out validation by selecting the last interaction per user.

    The last interaction of a user is the row with the greatest ``Columns.Datetime`` value among
    that user's rows. When several rows of one user share the greatest timestamp, the row that
    appears last in `interactions` is selected, so exactly one row per user is marked.
    The set of users can be restricted with `val_users`, either by explicit user IDs or by
    asking for a random sample of a given size.

    Parameters
    ----------
    interactions : pd.DataFrame
        User-item interactions data with at least three columns:
        Columns.User, Columns.Item and Columns.Datetime.
        Only Columns.User and Columns.Datetime are read by this function.
    val_users : Optional[Union[ExternalIds, int]], default ``None``
        Validation user filter. Can be:
        - None: use all users
        - int (Python or NumPy integer): randomly sample N users from unique user list
          without replacement, using the global NumPy random state (``np.random``)
        - array-like: explicit list of user IDs to include

    Returns
    -------
    np.ndarray
        Boolean array with one element per row of `interactions`, where True indicates
        the interaction is the last one for its user and the user is in the validation set.

    Raises
    ------
    ValueError
        If `val_users` is an integer that `np.random.choice` rejects for sampling without
        replacement (e.g. it is greater than the number of unique users).
    """
    by_user = interactions.groupby(Columns.User)
    # `method="first"` breaks timestamp ties by order of appearance, so within each user
    # the ranks are a permutation of 1..group_size and only one row has rank == group_size.
    time_rank = by_user[Columns.Datetime].rank(method="first", ascending=True).astype(int)
    group_size = by_user[Columns.Datetime].transform("size").astype(int)
    is_last = (time_rank == group_size).to_numpy()

    if val_users is None:
        return is_last

    # NumPy integer scalars (e.g. `np.int64`) are not instances of `int`; without this check
    # they would be passed to `Series.isin`, which only accepts list-like values.
    if isinstance(val_users, (int, np.integer)):
        unique_users = interactions[Columns.User].unique()
        val_users = np.random.choice(unique_users, size=val_users, replace=False)

    in_validation = interactions[Columns.User].isin(val_users).to_numpy()
    return in_validation & is_last
a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -38,13 +38,14 @@ from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone +from rectools.models.nn.transformers.utils import leave_one_out_mask from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, assert_second_fit_refits_model, ) -from .utils import custom_trainer, leave_one_out_mask +from .utils import custom_trainer class TestBERT4RecModel: @@ -1027,7 +1028,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler", - "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index d1bd5911..f947b440 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -37,6 +37,7 @@ from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone +from rectools.models.nn.transformers.utils import 
leave_one_out_mask from tests.models.data import DATASET from tests.models.utils import ( assert_default_config_and_default_model_params_are_the_same, @@ -44,7 +45,7 @@ ) from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal -from .utils import custom_trainer, leave_one_out_mask +from .utils import custom_trainer class TestSASRecModel: @@ -1017,7 +1018,7 @@ def test_get_config( "data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator", "lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule", "negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler", - "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask", "similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule", "backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone", } diff --git a/tests/models/nn/transformers/test_utils.py b/tests/models/nn/transformers/test_utils.py new file mode 100644 index 00000000..9bf77e7f --- /dev/null +++ b/tests/models/nn/transformers/test_utils.py @@ -0,0 +1,79 @@ +# Copyright 2022-2025 MTS (Mobile Telesystems) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestLeaveOneOutMask:
    """Tests for `leave_one_out_mask`: selected rows must be the latest interaction of each validation user."""

    def setup_method(self) -> None:
        # Integer `val_users` values are sampled through the global NumPy RNG,
        # so the seed is pinned to keep the parametrized expectations reproducible.
        np.random.seed(32)

    @pytest.fixture
    def interactions(self) -> Interactions:
        header = [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
        rows = [
            [1, 1, 1, "2021-09-01"],  # 0
            [1, 2, 1, "2021-09-02"],  # 1
            [1, 1, 1, "2021-09-03"],  # 2
            [1, 2, 1, "2021-09-04"],  # 3
            [1, 3, 1, "2021-09-05"],  # 4
            [2, 3, 1, "2021-09-06"],  # 5
            [2, 2, 1, "2021-08-20"],  # 6
            [2, 2, 1, "2021-09-06"],  # 7
            [3, 1, 1, "2021-09-05"],  # 8
            [1, 6, 1, "2021-09-05"],  # 9
        ]
        frame = pd.DataFrame(rows, columns=header)
        return Interactions(frame.astype({Columns.Datetime: "datetime64[ns]"}))

    @pytest.mark.parametrize(
        "swap_interactions,expected_val_index, expected_val_item, val_users",
        (
            ([9, 9], [7, 8, 9], 6, None),
            ([9, 9], [7, 8, 9], 6, 3),
            ([9, 9], [8, 9], 6, 2),
            ([4, 9], [7, 8, 9], 3, None),
            ([4, 9], [7, 8, 9], 3, 3),
            ([4, 9], [8, 9], 3, 2),
            ([7, 7], [7, 8], 2, [2, 3]),
            ([5, 7], [7, 8], 3, [2, 3]),
            ([8, 8], [8], 1, [3]),
        ),
    )
    def test_correct_last_interactions(
        self,
        interactions: Interactions,
        swap_interactions: tuple,
        expected_val_index: tp.List[int],
        expected_val_item: int,
        val_users: tp.Optional[tp.List[int]],
    ) -> None:
        """Check both the selected row labels and the item found at the later swapped position."""
        frame = interactions.df

        # Exchange the rows at the two given positions (a no-op when both positions are equal).
        # The swapped pairs in the parametrization share user and timestamp, so the swap only
        # changes which item sits at the later position.
        reversed_positions = swap_interactions[::-1]
        frame.iloc[swap_interactions] = frame.iloc[reversed_positions]

        selected = frame[leave_one_out_mask(frame, val_users)]

        assert selected.index.tolist() == expected_val_index

        later_position = max(swap_interactions)
        observed_item = selected.loc[later_position, [Columns.Item]].values[0]
        assert observed_item == expected_val_item
index 7f6954a6..57fc1982 100644 --- a/tests/models/nn/transformers/utils.py +++ b/tests/models/nn/transformers/utils.py @@ -12,21 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pandas as pd from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint -from rectools import Columns - - -def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series: - rank = ( - interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") - .groupby(Columns.User, sort=False) - .cumcount() - ) - return rank == 0 - def custom_trainer() -> Trainer: return Trainer(