From 735e2d7b04d26ce659b7aff8814c63f87a7e2d96 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Tue, 1 Jul 2025 11:52:06 +0300 Subject: [PATCH 1/7] new splitter & test --- rectools/model_selection/last_n_split.py | 14 ++++-- tests/model_selection/test_last_n_split.py | 53 ++++++++++++++++++++++ 2 files changed, 62 insertions(+), 5 deletions(-) diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index 4b3db4a7..aa5dc6ea 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 MTS (Mobile Telesystems) +# Copyright 2025 MTS (Mobile Telesystems) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""LastNSplitter.""" import typing as tp import numpy as np import pandas as pd - from rectools import Columns from rectools.dataset import Interactions from rectools.model_selection.splitter import Splitter @@ -103,8 +101,14 @@ def _split_without_filter( df = interactions.df idx = pd.RangeIndex(0, len(df)) - # last event - rank=1 - inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False) + # Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique + time_order = ( + df.groupby(Columns.User)[Columns.Datetime] + .rank(method="first", ascending=True) + .astype(int) + ) + n_interactions = df.groupby(Columns.User).transform("size").astype(int) + inv_ranks = n_interactions - time_order + 1 for i_split in range(self.n_splits)[::-1]: min_rank = i_split * self.n # excluded diff --git a/tests/model_selection/test_last_n_split.py b/tests/model_selection/test_last_n_split.py index bb02a928..aea81200 100644 --- a/tests/model_selection/test_last_n_split.py +++ b/tests/model_selection/test_last_n_split.py @@ -40,6 +40,59 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]: return sorted(inv_shuffle_arr[values]) return _shuffle + @pytest.fixture + def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions: + df = pd.DataFrame( + [ + [1, 1, 1, "2021-09-01"], # 0 + [1, 2, 1, "2021-09-02"], # 1 + [1, 1, 1, "2021-09-03"], # 2 + [1, 2, 1, "2021-09-04"], # 3 + [1, 3, 1, "2021-09-05"], # 4 + [2, 3, 1, "2021-09-05"], # 5 + [2, 2, 1, "2021-08-20"], # 6 + [2, 2, 1, "2021-09-06"], # 7 + [3, 1, 1, "2021-09-05"], # 8 + [1, 6, 1, "2021-09-05"], # 9 + ], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + ).astype({Columns.Datetime: "datetime64[ns]"}) + return Interactions(df) + + @pytest.mark.parametrize( + "swap_targets,expected_test_ids, target_item", + ( + ( + False, + {9, 7, 8}, + 6 + ), + ( + True, + {9, 7, 8}, + 3 + ), + ), + ) + def test_correct_last_interactions( + self, + interactions_equal_timestamps: Interactions, + swap_targets: bool, + expected_test_ids: tp.List[int], + target_item: int, + ) -> None: + # Do not using shuffle fixture, otherwise no valid answers + interactions_et = interactions_equal_timestamps + splitter = LastNSplitter(1, 1, False, False, False) + if swap_targets: + df_swap = interactions_equal_timestamps.df + df_swap.iloc[[4,9]] = df_swap.iloc[[9,4]] + interactions_et = Interactions(df_swap) + loo_split = list(splitter.split(interactions_et, collect_fold_stats=True)) + target_ids = loo_split[0][1] + assert set(target_ids) == expected_test_ids + assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids + assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item]) @pytest.fixture def interactions(self, shuffle_arr: np.ndarray) -> Interactions: From 2ee023a43266b9c2dd9a1bed7664ada8dadf32c4 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Tue, 1 Jul 2025 11:59:17 +0300 Subject: [PATCH 2/7] + linter --- rectools/model_selection/last_n_split.py | 7 ++----- rectools/models/nn/item_net.py | 2 +- tests/model_selection/test_last_n_split.py | 21 +++++++-------------- 3 files changed, 10 insertions(+), 20 deletions(-) diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index aa5dc6ea..41a08d12 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -17,6 +17,7 @@ import numpy as np import pandas as pd + from rectools import Columns from rectools.dataset import Interactions from rectools.model_selection.splitter import Splitter @@ -102,11 +103,7 @@ def _split_without_filter( idx = pd.RangeIndex(0, len(df)) # Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique - time_order = ( - df.groupby(Columns.User)[Columns.Datetime] - .rank(method="first", ascending=True) - .astype(int) - ) + time_order = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=True).astype(int) n_interactions = df.groupby(Columns.User).transform("size").astype(int) inv_ranks = n_interactions - time_order + 1 diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index cf0be461..65c2f98f 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: @property def out_dim(self) -> int: """Return item net constructor output dimension.""" - return self.item_net_blocks[0].out_dim # type: ignore[return-value] + return self.item_net_blocks[0].out_dim diff --git a/tests/model_selection/test_last_n_split.py b/tests/model_selection/test_last_n_split.py index aea81200..b1cab308 100644 --- a/tests/model_selection/test_last_n_split.py +++ b/tests/model_selection/test_last_n_split.py @@ -40,6 +40,7 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]: return sorted(inv_shuffle_arr[values]) return _shuffle + @pytest.fixture def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions: df = pd.DataFrame( @@ -62,23 +63,15 @@ def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions @pytest.mark.parametrize( "swap_targets,expected_test_ids, target_item", ( - ( - False, - {9, 7, 8}, - 6 - ), - ( - True, - {9, 7, 8}, - 3 - ), + (False, {9, 7, 8}, 6), + (True, {9, 7, 8}, 3), ), ) def test_correct_last_interactions( self, interactions_equal_timestamps: Interactions, swap_targets: bool, - expected_test_ids: tp.List[int], + expected_test_ids: tp.Set[int], target_item: int, ) -> None: # Do not using shuffle fixture, otherwise no valid answers @@ -86,13 +79,13 @@ def test_correct_last_interactions( splitter = LastNSplitter(1, 1, False, False, False) if swap_targets: df_swap = interactions_equal_timestamps.df - df_swap.iloc[[4,9]] = df_swap.iloc[[9,4]] + df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]] interactions_et = Interactions(df_swap) loo_split = list(splitter.split(interactions_et, collect_fold_stats=True)) target_ids = loo_split[0][1] assert set(target_ids) == expected_test_ids - assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids - assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item]) + assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids + assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item]) @pytest.fixture def interactions(self, shuffle_arr: np.ndarray) -> Interactions: From 39e3c1e9130d1f9cc390ca26cbce17efecb83100 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Wed, 2 Jul 2025 18:48:43 +0300 Subject: [PATCH 3/7] minor fix & changlog --- CHANGELOG.md | 6 ++++++ rectools/model_selection/last_n_split.py | 5 +++-- rectools/models/nn/item_net.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 558ee220..6665cd9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Fixed + +- [Breaking] Now `LastNSplitter` guarantees correct splitting in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288)) + ## [0.14.0] - 16.05.2025 ### Added diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index 41a08d12..2d819ca9 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -103,8 +103,9 @@ def _split_without_filter( idx = pd.RangeIndex(0, len(df)) # Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique - time_order = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=True).astype(int) - n_interactions = df.groupby(Columns.User).transform("size").astype(int) + grouped = df.groupby(Columns.User) + time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int) + n_interactions = grouped[Columns.User].transform('size').astype(int) inv_ranks = n_interactions - time_order + 1 for i_split in range(self.n_splits)[::-1]: diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index 65c2f98f..c0be8f4b 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -230,7 +230,7 @@ def from_dataset_schema( @property def out_dim(self) -> int: """Return categorical item embedding output dimension.""" - return self.embedding_bag.embedding_dim + return int(self.embedding_bag.embedding_dim) class IdEmbeddingsItemNet(ItemNetBase): From 61f41a43d13ba7264a86c8f6ac8bfde7afebead0 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Wed, 2 Jul 2025 19:02:03 +0300 Subject: [PATCH 4/7] +linter --- rectools/model_selection/last_n_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index 2d819ca9..f2f527b2 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -105,7 +105,7 @@ def _split_without_filter( # Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique grouped = df.groupby(Columns.User) time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int) - n_interactions = grouped[Columns.User].transform('size').astype(int) + n_interactions = grouped[Columns.User].transform("size").astype(int) inv_ranks = n_interactions - time_order + 1 for i_split in range(self.n_splits)[::-1]: From 92c53e6be4bff474ff9cc70dcf13ed4dc1c222f7 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Thu, 3 Jul 2025 12:24:46 +0300 Subject: [PATCH 5/7] Feature/correct_splitter(#288) - Fixed `LasNSplitter` class. --- rectools/models/nn/item_net.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index c0be8f4b..5020a433 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -486,4 +486,4 @@ def forward(self, items: torch.Tensor) -> torch.Tensor: @property def out_dim(self) -> int: """Return item net constructor output dimension.""" - return self.item_net_blocks[0].out_dim + return self.item_net_blocks[0].out_dim # type: ignore[return-value] From bdcf2dcdc91e15416a3edc0a81fb81716d8eb0b6 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Thu, 10 Jul 2025 17:29:56 +0300 Subject: [PATCH 6/7] Feature/correct_splitter(#288) - Fixed `LasNSplitter` class. --- CHANGELOG.md | 5 ++--- tests/model_selection/test_last_n_split.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6a819366..80742792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,13 +5,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -## [Unreleased] +## Unreleased ### Added - `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287)) ### Fixed - -- [Breaking] Now `LastNSplitter` class guarantees correct splitting in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288)) +- [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288)) ## [0.14.0] - 16.05.2025 diff --git a/tests/model_selection/test_last_n_split.py b/tests/model_selection/test_last_n_split.py index b1cab308..6ce2093c 100644 --- a/tests/model_selection/test_last_n_split.py +++ b/tests/model_selection/test_last_n_split.py @@ -1,4 +1,4 @@ -# Copyright 2023 MTS (Mobile Telesystems) +# Copyright 2023-2025 MTS (Mobile Telesystems) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. From c41bb441cd1bb328e1bfe2027ed919697d77d067 Mon Sep 17 00:00:00 2001 From: Aleksey Kuzin Date: Thu, 10 Jul 2025 17:51:26 +0300 Subject: [PATCH 7/7] Feature/correct_splitter(#288) - Fixed `LasNSplitter` class --- rectools/model_selection/last_n_split.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index f2f527b2..44faa0ed 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -1,4 +1,4 @@ -# Copyright 2025 MTS (Mobile Telesystems) +# Copyright 2023-2025 MTS (Mobile Telesystems) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License.