9 changes: 5 additions & 4 deletions rectools/model_selection/last_n_split.py
@@ -1,4 +1,4 @@
# Copyright 2023-2024 MTS (Mobile Telesystems)
# Copyright 2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""LastNSplitter."""

import typing as tp

@@ -103,8 +102,10 @@ def _split_without_filter(
df = interactions.df
idx = pd.RangeIndex(0, len(df))

# last event - rank=1
inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
# Guarantee that the interaction appearing last in df gets the lowest rank when datetimes are not unique
time_order = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=True).astype(int)
n_interactions = df.groupby(Columns.User).transform("size").astype(int)
inv_ranks = n_interactions - time_order + 1

for i_split in range(self.n_splits)[::-1]:
min_rank = i_split * self.n # excluded
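To illustrate the behaviour change above, here is a minimal pandas-only sketch (not part of the PR; "user_id" and "datetime" stand in for Columns.User and Columns.Datetime). With rank(method="first", ascending=False), tied datetimes are ranked in their order of appearance, so among equal timestamps the first row would be treated as the last event; inverting an ascending rank instead gives the lowest rank to the interaction that appears last in the dataframe.

import pandas as pd

df = pd.DataFrame(
    {
        "user_id": [1, 1, 1],
        "datetime": pd.to_datetime(["2021-09-05", "2021-09-05", "2021-09-05"]),
    }
)

# Old approach: among tied datetimes the first appearing row gets rank 1 ("last event")
old_inv_ranks = df.groupby("user_id")["datetime"].rank(method="first", ascending=False)
print(old_inv_ranks.tolist())  # [1.0, 2.0, 3.0]

# New approach: rank ascending, then invert, so the last appearing row gets inv_rank 1
time_order = df.groupby("user_id")["datetime"].rank(method="first", ascending=True).astype(int)
n_interactions = df.groupby("user_id")["datetime"].transform("size").astype(int)
new_inv_ranks = n_interactions - time_order + 1
print(new_inv_ranks.tolist())  # [3, 2, 1]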
2 changes: 1 addition & 1 deletion rectools/models/nn/item_net.py
@@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor:
@property
def out_dim(self) -> int:
"""Return item net constructor output dimension."""
return self.item_net_blocks[0].out_dim # type: ignore[return-value]
return self.item_net_blocks[0].out_dim
46 changes: 46 additions & 0 deletions tests/model_selection/test_last_n_split.py
@@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:

return _shuffle

@pytest.fixture
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
df = pd.DataFrame(
[
[1, 1, 1, "2021-09-01"], # 0
[1, 2, 1, "2021-09-02"], # 1
[1, 1, 1, "2021-09-03"], # 2
[1, 2, 1, "2021-09-04"], # 3
[1, 3, 1, "2021-09-05"], # 4
[2, 3, 1, "2021-09-05"], # 5
[2, 2, 1, "2021-08-20"], # 6
[2, 2, 1, "2021-09-06"], # 7
[3, 1, 1, "2021-09-05"], # 8
[1, 6, 1, "2021-09-05"], # 9
],
columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
).astype({Columns.Datetime: "datetime64[ns]"})
return Interactions(df)

@pytest.mark.parametrize(
"swap_targets,expected_test_ids, target_item",
(
(False, {9, 7, 8}, 6),
(True, {9, 7, 8}, 3),
),
)
def test_correct_last_interactions(
self,
interactions_equal_timestamps: Interactions,
swap_targets: bool,
expected_test_ids: tp.Set[int],
target_item: int,
) -> None:
# Do not use the shuffle fixture here, otherwise the expected answers are not valid
interactions_et = interactions_equal_timestamps
splitter = LastNSplitter(1, 1, False, False, False)
if swap_targets:
df_swap = interactions_equal_timestamps.df
df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]]
interactions_et = Interactions(df_swap)
loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
target_ids = loo_split[0][1]
assert set(target_ids) == expected_test_ids
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])

@pytest.fixture
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:
df = pd.DataFrame(
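For reference, a minimal usage sketch of the splitter exercised by the new test (a sketch assuming the public rectools imports; the positional constructor arguments mirror LastNSplitter(1, 1, False, False, False) from the test above):

import pandas as pd

from rectools import Columns
from rectools.dataset import Interactions
from rectools.model_selection import LastNSplitter

# Two users; user 1 has two interactions with identical timestamps
df = pd.DataFrame(
    [
        [1, 10, 1, "2021-09-05"],  # 0
        [1, 11, 1, "2021-09-05"],  # 1
        [2, 10, 1, "2021-09-06"],  # 2
    ],
    columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
).astype({Columns.Datetime: "datetime64[ns]"})

# Same positional arguments as in the test above: n=1, one split, no filtering
splitter = LastNSplitter(1, 1, False, False, False)
for train_ids, test_ids, fold_info in splitter.split(Interactions(df), collect_fold_stats=True):
    # With the fix, user 1's test interaction is row 1 (the tied interaction appearing last in df);
    # row 2 is user 2's only interaction and also lands in test since cold users are not filtered.
    print(train_ids, test_ids, fold_info)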