diff --git a/CHANGELOG.md b/CHANGELOG.md index acc13ced..80742792 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287)) +### Fixed +- [Breaking] `LastNSplitter` now guarantees that, when timestamps are identical, the interaction that appears last in the dataframe is treated as the most recent one ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288)) + ## [0.14.0] - 16.05.2025 ### Added diff --git a/rectools/model_selection/last_n_split.py b/rectools/model_selection/last_n_split.py index 4b3db4a7..44faa0ed 100644 --- a/rectools/model_selection/last_n_split.py +++ b/rectools/model_selection/last_n_split.py @@ -1,4 +1,4 @@ -# Copyright 2023-2024 MTS (Mobile Telesystems) +# Copyright 2023-2025 MTS (Mobile Telesystems) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""LastNSplitter.""" import typing as tp @@ -103,8 +102,11 @@ def _split_without_filter( df = interactions.df idx = pd.RangeIndex(0, len(df)) - # last event - rank=1 - inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False) + # Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique + grouped = df.groupby(Columns.User) + time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int) + n_interactions = grouped[Columns.User].transform("size").astype(int) + inv_ranks = n_interactions - time_order + 1 for i_split in range(self.n_splits)[::-1]: min_rank = i_split * self.n # excluded diff --git a/rectools/models/nn/item_net.py b/rectools/models/nn/item_net.py index cf0be461..5020a433 100644 --- a/rectools/models/nn/item_net.py +++ b/rectools/models/nn/item_net.py @@ -230,7 +230,7 @@ def from_dataset_schema( @property def out_dim(self) -> int: """Return categorical item embedding output dimension.""" - return self.embedding_bag.embedding_dim + return int(self.embedding_bag.embedding_dim) class IdEmbeddingsItemNet(ItemNetBase): diff --git a/tests/model_selection/test_last_n_split.py b/tests/model_selection/test_last_n_split.py index bb02a928..6ce2093c 100644 --- a/tests/model_selection/test_last_n_split.py +++ b/tests/model_selection/test_last_n_split.py @@ -1,4 +1,4 @@ -# Copyright 2023 MTS (Mobile Telesystems) +# Copyright 2023-2025 MTS (Mobile Telesystems) # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]: return _shuffle + @pytest.fixture + def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions: + df = pd.DataFrame( + [ + [1, 1, 1, "2021-09-01"], # 0 + [1, 2, 1, "2021-09-02"], # 1 + [1, 1, 1, "2021-09-03"], # 2 + [1, 2, 1, "2021-09-04"], # 3 + [1, 3, 1, "2021-09-05"], # 4 + [2, 3, 1, "2021-09-05"], # 5 + [2, 2, 1, "2021-08-20"], # 6 + [2, 2, 1, "2021-09-06"], # 7 + [3, 1, 1, "2021-09-05"], # 8 + [1, 6, 1, "2021-09-05"], # 9 + ], + columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime], + ).astype({Columns.Datetime: "datetime64[ns]"}) + return Interactions(df) + + @pytest.mark.parametrize( + "swap_targets,expected_test_ids, target_item", + ( + (False, {9, 7, 8}, 6), + (True, {9, 7, 8}, 3), + ), + ) + def test_correct_last_interactions( + self, + interactions_equal_timestamps: Interactions, + swap_targets: bool, + expected_test_ids: tp.Set[int], + target_item: int, + ) -> None: + # Do not use the shuffle fixture here, otherwise the expected ids would no longer be valid + interactions_et = interactions_equal_timestamps + splitter = LastNSplitter(1, 1, False, False, False) + if swap_targets: + df_swap = interactions_equal_timestamps.df + df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]] + interactions_et = Interactions(df_swap) + loo_split = list(splitter.split(interactions_et, collect_fold_stats=True)) + target_ids = loo_split[0][1] + assert set(target_ids) == expected_test_ids + assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids + assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item]) + @pytest.fixture def interactions(self, shuffle_arr: np.ndarray) -> Interactions: df = pd.DataFrame(