Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))

### Fixed
- [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))

## [0.14.0] - 16.05.2025

### Added
Expand Down
10 changes: 6 additions & 4 deletions rectools/model_selection/last_n_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 MTS (Mobile Telesystems)
# Copyright 2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""LastNSplitter."""

import typing as tp

Expand Down Expand Up @@ -103,8 +102,11 @@ def _split_without_filter(
df = interactions.df
idx = pd.RangeIndex(0, len(df))

# last event - rank=1
inv_ranks = df.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
# Here we guarantee that last appeared interaction in df will have lowest rank when datetime is not unique
grouped = df.groupby(Columns.User)
time_order = grouped[Columns.Datetime].rank(method="first", ascending=True).astype(int)
n_interactions = grouped[Columns.User].transform("size").astype(int)
inv_ranks = n_interactions - time_order + 1

for i_split in range(self.n_splits)[::-1]:
min_rank = i_split * self.n # excluded
Expand Down
2 changes: 1 addition & 1 deletion rectools/models/nn/item_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ def from_dataset_schema(
@property
def out_dim(self) -> int:
"""Return categorical item embedding output dimension."""
return self.embedding_bag.embedding_dim
return int(self.embedding_bag.embedding_dim)


class IdEmbeddingsItemNet(ItemNetBase):
Expand Down
48 changes: 47 additions & 1 deletion tests/model_selection/test_last_n_split.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 MTS (Mobile Telesystems)
# Copyright 2023-2025 MTS (Mobile Telesystems)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,6 +41,52 @@ def _shuffle(values: tp.Sequence[int]) -> tp.List[int]:

return _shuffle

@pytest.fixture
def interactions_equal_timestamps(self, shuffle_arr: np.ndarray) -> Interactions:
df = pd.DataFrame(
[
[1, 1, 1, "2021-09-01"], # 0
[1, 2, 1, "2021-09-02"], # 1
[1, 1, 1, "2021-09-03"], # 2
[1, 2, 1, "2021-09-04"], # 3
[1, 3, 1, "2021-09-05"], # 4
[2, 3, 1, "2021-09-05"], # 5
[2, 2, 1, "2021-08-20"], # 6
[2, 2, 1, "2021-09-06"], # 7
[3, 1, 1, "2021-09-05"], # 8
[1, 6, 1, "2021-09-05"], # 9
],
columns=[Columns.User, Columns.Item, Columns.Weight, Columns.Datetime],
).astype({Columns.Datetime: "datetime64[ns]"})
return Interactions(df)

@pytest.mark.parametrize(
"swap_targets,expected_test_ids, target_item",
(
(False, {9, 7, 8}, 6),
(True, {9, 7, 8}, 3),
),
)
def test_correct_last_interactions(
self,
interactions_equal_timestamps: Interactions,
swap_targets: bool,
expected_test_ids: tp.Set[int],
target_item: int,
) -> None:
# Do not using shuffle fixture, otherwise no valid answers
interactions_et = interactions_equal_timestamps
splitter = LastNSplitter(1, 1, False, False, False)
if swap_targets:
df_swap = interactions_equal_timestamps.df
df_swap.iloc[[4, 9]] = df_swap.iloc[[9, 4]]
interactions_et = Interactions(df_swap)
loo_split = list(splitter.split(interactions_et, collect_fold_stats=True))
target_ids = loo_split[0][1]
assert set(target_ids) == expected_test_ids
assert set(loo_split[0][0]) == set(range(len(interactions_et.df))) - expected_test_ids
assert target_item in set(interactions_et.df.iloc[target_ids][Columns.Item])

@pytest.fixture
def interactions(self, shuffle_arr: np.ndarray) -> Interactions:
df = pd.DataFrame(
Expand Down