Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))

### Fixed
- [Breaking] `LastNSplitter` now guarantees that, when several interactions of a user share the same timestamp, the one appearing last in the dataframe is treated as the last interaction ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))

## [0.14.0] - 16.05.2025

### Added
Expand Down
10 changes: 6 additions & 4 deletions rectools/model_selection/last_n_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 MTS (Mobile Telesystems)
# Copyright 2023-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""LastNSplitter."""

import typing as tp

Expand Down Expand Up @@ -103,8 +102,11 @@ def _split_without_filter(
df = interactions.df
idx = pd.RangeIndex(0, len(df))

# last event - rank=1
inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
# Guarantee that, when datetimes are not unique, the interaction appearing last in df gets the lowest rank
grouped = df.groupby(Columns.User)
time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int)
n_interactions = grouped[Columns.User].transform("size").astype(int)
inv_ranks = n_interactions - time_order + 1

for i_split in range(self.n_splits)[::-1]:
min_rank = i_split * self.n # excluded
Expand Down
2 changes: 1 addition & 1 deletion rectools/models/nn/item_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def from_dataset_schema(
@property
def out_dim(self) -> int:
    """Return categorical item embedding output dimension.

    The value is read from ``self.embedding_bag.embedding_dim`` and converted
    with ``int`` so the result is a built-in ``int``, matching the annotation.
    """
    return int(self.embedding_bag.embedding_dim)


class IdEmbeddingsItemNet(ItemNetBase):
Expand Down
48 changes: 47 additions & 1 deletion tests/model_selection/test_last_n_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 MTS (Mobile Telesystems)
# Copyright 2023-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:

return _shuffle

@pytest.fixture
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
    """Build interactions in which several rows share the same timestamp.

    Rows 4 and 9 both belong to user 1 and are both dated 2021-09-05, so only
    their position in the frame tells them apart in time order.
    ``shuffle_arr`` is requested but not referenced in the body: the rows keep
    exactly the order written below.
    """
    column_names = [Columns.User, Columns.Item, Columns.Weight, Columns.Datetime]
    # The trailing comment on each row is its positional index in the frame.
    rows = [
        [1, 1, 1, "2021-09-01"],  # 0
        [1, 2, 1, "2021-09-02"],  # 1
        [1, 1, 1, "2021-09-03"],  # 2
        [1, 2, 1, "2021-09-04"],  # 3
        [1, 3, 1, "2021-09-05"],  # 4
        [2, 3, 1, "2021-09-05"],  # 5
        [2, 2, 1, "2021-08-20"],  # 6
        [2, 2, 1, "2021-09-06"],  # 7
        [3, 1, 1, "2021-09-05"],  # 8
        [1, 6, 1, "2021-09-05"],  # 9
    ]
    frame = pd.DataFrame(rows, columns=column_names)
    frame = frame.astype({Columns.Datetime: "datetime64[ns]"})
    return Interactions(frame)

@pytest.mark.parametrize(
    "swap_targets,expected_test_ids, target_item",
    (
        (False, {9, 7, 8}, 6),
        (True, {9, 7, 8}, 3),
    ),
)
def test_correct_last_interactions(
    self,
    interactions_equal_timestamps: Interactions,
    swap_targets: bool,
    expected_test_ids: tp.Set[int],
    target_item: int,
) -> None:
    """Check that the test fold holds the last-positioned row for tied timestamps.

    With ``swap_targets`` the contents of rows 4 and 9 are exchanged, so the
    expected row ids stay the same while the expected held-out item changes.
    """
    # The shuffle fixture is intentionally not applied here: the expected ids
    # depend on the rows keeping their original order.
    dataset = interactions_equal_timestamps
    if swap_targets:
        swapped = interactions_equal_timestamps.df
        # Exchange the contents of positional rows 4 and 9.
        swapped.iloc[[4, 9]] = swapped.iloc[[9, 4]]
        dataset = Interactions(swapped)

    splitter = LastNSplitter(1, 1, False, False, False)
    folds = list(splitter.split(dataset, collect_fold_stats=True))
    first_fold = folds[0]
    train_ids = first_fold[0]
    test_ids = first_fold[1]

    all_ids = set(range(len(dataset.df)))
    assert set(test_ids) == expected_test_ids
    assert set(train_ids) == all_ids - expected_test_ids
    held_out_items = set(dataset.df.iloc[test_ids][Columns.Item])
    assert target_item in held_out_items

@pytest.fixture
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:
df = pd.DataFrame(
Expand Down