Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- `leave_one_out_mask` function (`rectools.models.nn.transformers.utils.leave_one_out_mask`) for applying leave-one-out validation during transformer model training ([#292](https://github.com/MobileTeleSystems/RecTools/pull/292))

## [0.15.0] - 17.07.2025

### Added
Expand Down
47 changes: 47 additions & 0 deletions rectools/models/nn/transformers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import typing as tp

import numpy as np
import pandas as pd

from rectools import Columns, ExternalIds


def leave_one_out_mask(
    interactions: pd.DataFrame, val_users: tp.Optional[tp.Union[ExternalIds, int]] = None
) -> np.ndarray:
    """
    Create a boolean mask for leave-one-out validation by selecting the last interaction per user.

    The last interaction of a user is the row with the greatest `Columns.Datetime` value
    within that user's rows. When several rows of one user share the greatest timestamp,
    the row that appears latest in `interactions` is selected (ranks are assigned with
    ``method="first"``, i.e. in order of appearance).
    The set of validation users can be restricted with the `val_users` parameter.

    Parameters
    ----------
    interactions : pd.DataFrame
        User-item interactions data with at least three columns:
        Columns.User, Columns.Item and Columns.Datetime
    val_users : Optional[Union[ExternalIds, int]], default ``None``
        Validation user filter. Can be:
        - None: use all users
        - int (Python or NumPy integer): take first N users from the unique user list,
          in order of their first appearance in `interactions`
        - array-like: explicit list of user IDs to include

    Returns
    -------
    np.ndarray
        Boolean array aligned with the rows of `interactions`, where True indicates
        the interaction is the last one for its user in the validation set.
    """
    groups = interactions.groupby(Columns.User)
    # 1-based chronological position of every row inside its user group.
    time_order = groups[Columns.Datetime].rank(method="first", ascending=True).astype(int)
    # Group size broadcast back to every row of the group.
    n_interactions = groups.transform("size").astype(int)
    # A row is the user's last one when its rank equals the group size.
    last_interact_mask = (n_interactions - time_order) == 0

    if val_users is None:
        return last_interact_mask.values

    # NumPy integer scalars are not instances of the builtin ``int``; without this check
    # they would reach ``isin`` below, which rejects non list-like arguments.
    if isinstance(val_users, (int, np.integer)):
        users = interactions[Columns.User].unique()
        val_users = users[:val_users]

    mask = interactions[Columns.User].isin(val_users) & last_interact_mask
    return mask.values
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from rectools.models import BERT4RecModel, SASRecModel, load_model
from rectools.models.nn.transformers.base import TransformerModelBase
from rectools.models.nn.transformers.lightning import TransformerLightningModule
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import INTERACTIONS
from tests.models.utils import assert_save_load_do_not_change_model

from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt


def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_save_load_for_fitted_model(
"model_params_update",
(
{
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
},
{
Expand Down
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import DATASET
from tests.models.utils import (
assert_default_config_and_default_model_params_are_the_same,
assert_second_fit_refits_model,
)

from .utils import custom_trainer, leave_one_out_mask
from .utils import custom_trainer


class TestBERT4RecModel:
Expand Down Expand Up @@ -1027,7 +1028,7 @@ def test_get_config(
"data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator",
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
}
Expand Down
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import DATASET
from tests.models.utils import (
assert_default_config_and_default_model_params_are_the_same,
assert_second_fit_refits_model,
)
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal

from .utils import custom_trainer, leave_one_out_mask
from .utils import custom_trainer


class TestSASRecModel:
Expand Down Expand Up @@ -1017,7 +1018,7 @@ def test_get_config(
"data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator",
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
}
Expand Down
75 changes: 75 additions & 0 deletions tests/models/nn/transformers/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import typing as tp

import pandas as pd
import pytest

from rectools import Columns
from rectools.dataset import Interactions
from rectools.models.nn.transformers.utils import leave_one_out_mask


class TestLeaveOneOutMask:
    """Tests that `leave_one_out_mask` selects the latest interaction of each validation user."""

    @pytest.fixture
    def interactions(self) -> Interactions:
        """Build ten interactions of three users; the trailing comment on each row is its position."""
        column_names = [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
        rows = [
            [1, 1, 1, "2021-09-01"],  # 0
            [1, 2, 1, "2021-09-02"],  # 1
            [1, 1, 1, "2021-09-03"],  # 2
            [1, 2, 1, "2021-09-04"],  # 3
            [1, 3, 1, "2021-09-05"],  # 4
            [2, 3, 1, "2021-09-06"],  # 5
            [2, 2, 1, "2021-08-20"],  # 6
            [2, 2, 1, "2021-09-06"],  # 7
            [3, 1, 1, "2021-09-05"],  # 8
            [1, 6, 1, "2021-09-05"],  # 9
        ]
        frame = pd.DataFrame(rows, columns=column_names)
        frame = frame.astype({Columns.Datetime: "datetime64[ns]"})
        return Interactions(frame)

    @pytest.mark.parametrize(
        "swap_interactions,expected_val_index, expected_val_item, val_users",
        (
            ([9, 9], [7, 8, 9], 6, None),
            ([9, 9], [7, 8, 9], 6, 3),
            ([9, 9], [7, 9], 6, 2),
            ([4, 9], [7, 8, 9], 3, None),
            ([4, 9], [7, 8, 9], 3, 3),
            ([4, 9], [7, 9], 3, 2),
            ([7, 7], [7, 8], 2, [2, 3]),
            ([5, 7], [7, 8], 3, [2, 3]),
            ([8, 8], [8], 1, [3]),
        ),
    )
    def test_correct_last_interactions(
        self,
        interactions: Interactions,
        swap_interactions: tuple,
        expected_val_index: tp.List[int],
        expected_val_item: int,
        val_users: tp.Optional[tp.List[int]],
    ) -> None:
        """Swap two rows, build the mask and check the selected row labels and the item at the last one."""
        frame = interactions.df
        last_index = max(swap_interactions)

        # Exchange the two row positions in place (a no-op when both positions are equal),
        # so that rows with tied timestamps trade places while the index labels stay put.
        reversed_positions = swap_interactions[::-1]
        frame.iloc[swap_interactions] = frame.iloc[reversed_positions]

        selected = frame[leave_one_out_mask(frame, val_users)]

        assert list(selected.index) == expected_val_index
        item_at_last = selected.loc[last_index, [Columns.Item]].values[0]
        assert item_at_last == expected_val_item
12 changes: 0 additions & 12 deletions tests/models/nn/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from rectools import Columns


def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
    """
    Mark the most recent interaction of every user.

    Rows are ordered by `Columns.Datetime` from newest to oldest with a stable sort,
    then numbered within each user; the row numbered zero is the user's latest one.

    Parameters
    ----------
    interactions : pd.DataFrame
        Interactions with at least `Columns.User` and `Columns.Datetime` columns.

    Returns
    -------
    pd.Series
        Boolean series, True for the first row of each user in the newest-first ordering.
        Its index follows the sorted frame, not the original row order.
    """
    newest_first = interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
    position_in_user = newest_first.groupby(Columns.User, sort=False).cumcount()
    return position_in_user == 0


def custom_trainer() -> Trainer:
return Trainer(
Expand Down