Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions rectools/models/nn/transformers/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import typing as tp

import numpy as np
import pandas as pd

from rectools import Columns, ExternalIds


def leave_one_out_mask(
interactions: pd.DataFrame, val_users: tp.Optional[tp.Union[ExternalIds, int]] = None
) -> np.ndarray:
"""
Create a boolean mask for leave-one-out validation by selecting the last interaction per user.
Identifies the most recent interaction for specified validation users based on timestamp ranking.
Users can be filtered using `val_users` parameter which supports slicing or explicit user IDs.
Parameters
----------
interactions : pd.DataFrame
User-item interaction data with at least two columns:
val_users : Optional[Union[ExternalIds, int]]
Validation user filter. Can be:
- None: use all users
- int: take first N users from unique user list
- array-like: explicit list of user IDs to include
Returns
-------
np.ndarray
Boolean array where True indicates the interaction is the last one for its user
in the validation set.
"""
groups = interactions.groupby(Columns.User)
time_order = groups[Columns.Datetime].rank(method="first", ascending=True).astype(int)
n_interactions = groups.transform("size").astype(int)
inv_ranks = n_interactions - time_order
last_interact_mask = inv_ranks == 0
users = interactions[Columns.User].unique()
if isinstance(val_users, int):
val_users = users[:val_users]
elif val_users is None:
val_users = users

mask = (interactions[Columns.User].isin(val_users)) & last_interact_mask
return mask.values
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,11 @@
from rectools.models import BERT4RecModel, SASRecModel, load_model
from rectools.models.nn.transformers.base import TransformerModelBase
from rectools.models.nn.transformers.lightning import TransformerLightningModule
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import INTERACTIONS
from tests.models.utils import assert_save_load_do_not_change_model

from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt


def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
Expand Down Expand Up @@ -200,7 +201,7 @@ def test_save_load_for_fitted_model(
"model_params_update",
(
{
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
},
{
Expand Down
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import DATASET
from tests.models.utils import (
assert_default_config_and_default_model_params_are_the_same,
assert_second_fit_refits_model,
)

from .utils import custom_trainer, leave_one_out_mask
from .utils import custom_trainer


class TestBERT4RecModel:
Expand Down Expand Up @@ -1027,7 +1028,7 @@ def test_get_config(
"data_preparator_type": "rectools.models.nn.transformers.bert4rec.BERT4RecDataPreparator",
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
}
Expand Down
5 changes: 3 additions & 2 deletions tests/models/nn/transformers/test_sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,15 @@
from rectools.models.nn.transformers.sasrec import SASRecDataPreparator, SASRecTransformerLayers
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
from rectools.models.nn.transformers.utils import leave_one_out_mask
from tests.models.data import DATASET
from tests.models.utils import (
assert_default_config_and_default_model_params_are_the_same,
assert_second_fit_refits_model,
)
from tests.testing_utils import assert_id_map_equal, assert_interactions_set_equal

from .utils import custom_trainer, leave_one_out_mask
from .utils import custom_trainer


class TestSASRecModel:
Expand Down Expand Up @@ -1017,7 +1018,7 @@ def test_get_config(
"data_preparator_type": "rectools.models.nn.transformers.sasrec.SASRecDataPreparator",
"lightning_module_type": "rectools.models.nn.transformers.lightning.TransformerLightningModule",
"negative_sampler_type": "rectools.models.nn.transformers.negative_sampler.CatalogUniformSampler",
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_val_mask_func": "rectools.models.nn.transformers.utils.leave_one_out_mask",
"similarity_module_type": "rectools.models.nn.transformers.similarity.DistanceSimilarityModule",
"backbone_type": "rectools.models.nn.transformers.torch_backbone.TransformerTorchBackbone",
}
Expand Down
75 changes: 75 additions & 0 deletions tests/models/nn/transformers/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright 2022-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=attribute-defined-outside-init

import typing as tp

import pandas as pd
import pytest

from rectools import Columns
from rectools.dataset import Interactions
from rectools.models.nn.transformers.utils import leave_one_out_mask


class TestUtils:
@pytest.fixture
def interactions(self) -> Interactions:
df = pd.DataFrame(
[
[1, 1, 1, "2021-09-01"], # 0
[1, 2, 1, "2021-09-02"], # 1
[1, 1, 1, "2021-09-03"], # 2
[1, 2, 1, "2021-09-04"], # 3
[1, 3, 1, "2021-09-05"], # 4
[2, 3, 1, "2021-09-06"], # 5
[2, 2, 1, "2021-08-20"], # 6
[2, 2, 1, "2021-09-06"], # 7
[3, 1, 1, "2021-09-05"], # 8
[1, 6, 1, "2021-09-05"], # 9
],
columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
).astype({Columns.Datetime: "datetime64[ns]"})
return Interactions(df)

@pytest.mark.parametrize(
"swap_interactions,expected_val_index, expected_val_item, val_users",
(
([9, 9], [7, 8, 9], 6, None),
([9, 9], [7, 8, 9], 6, 3),
([4, 9], [7, 8, 9], 3, None),
([4, 9], [7, 8, 9], 3, 3),
([7, 7], [7, 8], 2, [2, 3]),
([5, 7], [7, 8], 3, [2, 3]),
([8, 8], [8], 1, [3]),
),
)
def test_correct_last_interactions(
self,
interactions: Interactions,
swap_interactions: tuple,
expected_val_index: tp.List[int],
expected_val_item: int,
val_users: tp.Optional[tp.List[int]],
) -> None:
interactions_df = interactions.df
swap_revert = swap_interactions[::-1]
interactions_df.iloc[swap_interactions] = interactions_df.iloc[swap_revert]
val_mask = leave_one_out_mask(interactions_df, val_users)
val_interactions = interactions_df[val_mask]
last_index = max(swap_interactions)

assert list(val_interactions.index) == expected_val_index
assert val_interactions.loc[last_index, [Columns.Item]].values[0] == expected_val_item
12 changes: 0 additions & 12 deletions tests/models/nn/transformers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pandas as pd
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from rectools import Columns


def leave_one_out_mask(interactions: pd.DataFrame) -> pd.Series:
rank = (
interactions.sort_values(Columns.Datetime, ascending=False, kind="stable")
.groupby(Columns.User, sort=False)
.cumcount()
)
return rank == 0


def custom_trainer() -> Trainer:
return Trainer(
Expand Down