4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
 
+## Unreleased
+### Added
+- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
+
 ## [0.14.0] - 16.05.2025
 
 ### Added
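For orientation, a minimal sketch of the new `SequenceDataset` surface listed above (the session values and the `"channel"` extras column are illustrative, not part of this PR):

```python
from rectools.models.nn.transformers.data_preparator import SequenceDataset

# Two user sessions as internal item ids, with aligned per-interaction
# weights and one extra column.
ds = SequenceDataset(
    sessions=[[1, 2, 3], [4, 5]],
    weights=[[1.0, 1.0, 1.0], [1.0, 1.0]],
    extras={"channel": [["web", "app", "web"], ["app", "app"]]},  # new `extras` argument
)

session, weights, extras = ds[0]  # __getitem__ now returns a 3-element tuple
```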
5 changes: 4 additions & 1 deletion rectools/dataset/dataset.py
@@ -348,7 +348,10 @@ def get_user_item_matrix(
         return matrix
 
     def get_raw_interactions(
-        self, include_weight: bool = True, include_datetime: bool = True, include_extra_cols: bool = True
+        self,
+        include_weight: bool = True,
+        include_datetime: bool = True,
+        include_extra_cols: tp.Union[bool, tp.List[str]] = True,
     ) -> pd.DataFrame:
         """
         Return interactions as a `pd.DataFrame` object with replacing internal user and item ids to external ones.
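With the widened parameter type, the same call now accepts three forms. A sketch, assuming `dataset` was built with `keep_extra_cols=True` and carries an illustrative extra column named `"channel"`:

```python
df_all = dataset.get_raw_interactions(include_extra_cols=True)          # all extra columns (previous behavior)
df_core = dataset.get_raw_interactions(include_extra_cols=False)        # core interaction columns only
df_some = dataset.get_raw_interactions(include_extra_cols=["channel"])  # only the listed extra columns
```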
15 changes: 10 additions & 5 deletions rectools/dataset/interactions.py
@@ -167,7 +167,7 @@ def to_external(
         item_id_map: IdMap,
         include_weight: bool = True,
         include_datetime: bool = True,
-        include_extra_cols: bool = True,
+        include_extra_cols: tp.Union[bool, tp.List[str]] = True,
     ) -> pd.DataFrame:
         """
         Convert itself to `pd.DataFrame` with replacing internal user and item ids to external ones.
@@ -182,8 +182,9 @@
             Whether to include weight column into resulting table or not
         include_datetime : bool, default ``True``
             Whether to include datetime column into resulting table or not.
-        include_extra_cols: bool, default ``True``
-            Whether to include extra columns into resulting table or not.
+        include_extra_cols: bool or List[str], default ``True``
+            If bool, indicates whether to include all extra columns into resulting table or not.
+            If list of strings, indicates which extra columns to include into resulting table.
 
         Returns
         -------
@@ -201,9 +202,13 @@
             cols_to_add.append(Columns.Weight)
         if include_datetime:
             cols_to_add.append(Columns.Datetime)
-        if include_extra_cols:
+
+        extra_cols = []
+        if isinstance(include_extra_cols, list):
+            extra_cols = [col for col in include_extra_cols if col in self.df and col not in Columns.Interactions]
+        elif include_extra_cols:
             extra_cols = [col for col in self.df if col not in Columns.Interactions]
-            cols_to_add.extend(extra_cols)
+        cols_to_add.extend(extra_cols)
 
         for col in cols_to_add:
             res[col] = self.df[col]
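An end-to-end sketch at the `Interactions` level (the `"channel"` column is illustrative). Note from the hunk above that listed names missing from the table, or colliding with the core interaction columns, are silently skipped rather than raising:

```python
import pandas as pd

from rectools import Columns
from rectools.dataset import IdMap, Interactions

user_id_map = IdMap.from_values(["u1", "u2"])
item_id_map = IdMap.from_values(["i1", "i2"])

df = pd.DataFrame(
    {
        Columns.User: ["u1", "u1", "u2"],
        Columns.Item: ["i1", "i2", "i1"],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-03"]),
        "channel": ["web", "app", "web"],
    }
)
interactions = Interactions.from_raw(df, user_id_map, item_id_map, keep_extra_cols=True)

# Keeps "channel"; the unknown "missing" entry is dropped without an error.
external = interactions.to_external(user_id_map, item_id_map, include_extra_cols=["channel", "missing"])
```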
14 changes: 7 additions & 7 deletions rectools/models/nn/transformers/bert4rec.py
@@ -36,7 +36,7 @@
     ValMaskCallable,
 )
 from .constants import MASKING_VALUE, PADDING_VALUE
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -128,7 +128,7 @@ def _mask_session(
 
     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[BatchElement],
     ) -> Dict[str, torch.Tensor]:
         """
         Mask session elements to receive `x`.
@@ -141,7 +141,7 @@
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             masked_session, target = self._mask_session(ses)
             x[i, -len(ses) :] = masked_session  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) :] = target  # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,12 @@
         )
         return batch_dict
 
-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # until only leave-one-strategy
         yw = np.zeros((batch_size, 1))  # until only leave-one-strategy
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
             session = input_session.copy()
 
@@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
         )
         return batch_dict
 
-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         """
         Right truncation, left padding to `session_max_len`
         During inference model will use (`session_max_len` - 1) interactions
         and one extra "MASK" token will be added for making predictions.
         """
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             session = ses.copy()
             session = session + [self.extra_token_ids[MASKING_VALUE]]
             x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
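The BERT4Rec collate functions above only needed their tuple unpacking widened: the third member of each batch element is accepted but discarded via `_`. A standalone sketch of the recommend-time padding with a stand-in mask id (values illustrative):

```python
import numpy as np
import torch

session_max_len = 5
mask_token_id = 1  # stand-in for self.extra_token_ids[MASKING_VALUE]

batch = [
    ([3, 7, 15], [1.0, 1.0, 1.0], {"channel": ["web", "app", "web"]}),
    ([4], [1.0], {"channel": ["app"]}),
]

# Weights and extras are dropped by the `(ses, _, _)` unpacking; a "MASK"
# token is appended before right truncation and left padding.
x = np.zeros((len(batch), session_max_len))
for i, (ses, _, _) in enumerate(batch):
    session = ses + [mask_token_id]
    x[i, -len(session):] = session[-session_max_len:]

x_tensor = torch.LongTensor(x)  # [[0, 3, 7, 15, 1], [0, 0, 0, 4, 1]]
```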
67 changes: 50 additions & 17 deletions rectools/models/nn/transformers/data_preparator.py
@@ -32,6 +32,8 @@
 from .negative_sampler import TransformerNegativeSamplerBase
 
 InitKwargs = tp.Dict[str, tp.Any]
+# (user session, session weights, extra columns)
+BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]
 
 
 class SequenceDataset(TorchDataset):
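Spelled out, the alias bundles everything one dataset row now yields (values illustrative):

```python
from rectools.models.nn.transformers.data_preparator import BatchElement

# One BatchElement: (user session, per-interaction weights, extra columns).
element: BatchElement = (
    [3, 7, 7, 15],                              # internal item ids, ordered by datetime
    [1.0, 1.0, 2.0, 1.0],                       # aligned interaction weights
    {"channel": ["web", "web", "app", "web"]},  # one aligned list per extra column
)
```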
@@ -46,17 +48,26 @@ class SequenceDataset(TorchDataset):
         Weight of each interaction from the session.
     """
 
-    def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+    def __init__(
+        self,
+        sessions: tp.List[tp.List[int]],
+        weights: tp.List[tp.List[float]],
+        extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+    ):
         self.sessions = sessions
         self.weights = weights
+        self.extras = extras
 
     def __len__(self) -> int:
         return len(self.sessions)
 
-    def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+    def __getitem__(self, index: int) -> BatchElement:
         session = self.sessions[index]  # [session_len]
         weights = self.weights[index]  # [session_len]
-        return session, weights
+        extras = (
+            {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+        )
+        return session, weights, extras
 
     @classmethod
     def from_interactions(
@@ -73,17 +84,19 @@ def from_interactions(
         interactions : pd.DataFrame
             User-item interactions.
         """
+        cols_to_agg = [col for col in interactions.columns if col != Columns.User]
         sessions = (
             interactions.sort_values(Columns.Datetime, kind="stable")
-            .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+            .groupby(Columns.User, sort=sort_users)[cols_to_agg]
             .agg(list)
         )
-        sessions, weights = (
+        sessions_items, weights = (
             sessions[Columns.Item].to_list(),
             sessions[Columns.Weight].to_list(),
         )
 
-        return cls(sessions=sessions, weights=weights)
+        extra_cols = [col for col in interactions.columns if col not in Columns.Interactions]
+        extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None
+        return cls(sessions=sessions_items, weights=weights, extras=extras)
 
 
 class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
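A sketch of the new aggregation path in `from_interactions` (the `"channel"` column is illustrative): every non-user column is collected into per-session lists, and whatever falls outside the core interaction columns lands in `extras`:

```python
import pandas as pd

from rectools import Columns
from rectools.models.nn.transformers.data_preparator import SequenceDataset

interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 2],
        Columns.Item: [10, 11, 10],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2025-01-01", "2025-01-02", "2025-01-01"]),
        "channel": ["web", "app", "web"],
    }
)

seq_ds = SequenceDataset.from_interactions(interactions)
# With default user ordering, index 0 is user 1:
# seq_ds[0] -> ([10, 11], [1.0, 1.0], {"channel": ["web", "app"]})
```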
@@ -114,6 +127,8 @@ class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
     get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
         Additional keyword arguments for the get_val_mask_func.
         Make sure all dict values have JSON serializable types.
+    extra_cols: optional(List[str]), default ``None``
+        Extra columns to keep in train and recommend datasets.
     """
 
     # We sometimes need data preparators to add +1 to actual session_max_len
@@ -133,6 +148,7 @@ def __init__(
         n_negatives: tp.Optional[int] = None,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+        extra_cols: tp.Optional[tp.List[str]] = None,
         **kwargs: tp.Any,
     ) -> None:
         self.item_id_map: IdMap
@@ -148,6 +164,7 @@
         self.shuffle_train = shuffle_train
         self.get_val_mask_func = get_val_mask_func
         self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+        self.extra_cols = extra_cols
 
     def get_known_items_sorted_internal_ids(self) -> np.ndarray:
         """Return internal item ids from processed dataset in sorted order."""
@@ -203,7 +220,8 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.DataFrame:
 
     def process_dataset_train(self, dataset: Dataset) -> None:
         """Process train dataset and save data."""
-        raw_interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        raw_interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
 
         # Exclude val interaction targets from train if needed
         interactions = raw_interactions
@@ -231,7 +249,12 @@
 
         # Prepare train dataset
         # User features are dropped for now because model doesn't support them
-        final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
+        final_interactions = Interactions.from_raw(
+            interactions,
+            user_id_map,
+            item_id_map,
+            keep_extra_cols=True,
+        )
         self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
         self.item_id_map = self.train_dataset.item_id_map
         self._init_extra_token_ids()
@@ -246,7 +269,9 @@
         val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
         val_interactions[Columns.Weight] = 0
         val_interactions = pd.concat([val_interactions, val_targets], axis=0)
-        self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
+        self.val_interactions = Interactions.from_raw(
+            val_interactions, user_id_map, item_id_map, keep_extra_cols=True
+        ).df
 
     def _init_extra_token_ids(self) -> None:
         extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
@@ -340,7 +365,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset:
         Final item_id_map is model item_id_map constructed during training.
         """
         # Filter interactions in dataset internal ids
-        interactions = dataset.interactions.df
+        required_cols = Columns.Interactions
+        if self.extra_cols is not None:
+            required_cols = required_cols + self.extra_cols
+        interactions = dataset.interactions.df[required_cols]
         users_internal = dataset.user_id_map.convert_to_internal(users, strict=False)
         items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False)
         interactions = interactions[interactions[Columns.User].isin(users_internal)]
@@ -359,7 +387,9 @@
         if n_filtered > 0:
             explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
             warnings.warn(explanation)
-        filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
@@ -381,26 +411,29 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
         Final user_id_map is the same as dataset original.
         Final item_id_map is model item_id_map constructed during training.
         """
-        interactions = dataset.get_raw_interactions()
+        extra_cols = False if self.extra_cols is None else self.extra_cols
+        interactions = dataset.get_raw_interactions(include_extra_cols=extra_cols)
         interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
-        filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
+        filtered_interactions = Interactions.from_raw(
+            interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True
+        )
         filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
         return filtered_dataset
 
     def _collate_fn_train(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_val(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
    ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
 
     def _collate_fn_recommend(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[BatchElement],
     ) -> tp.Dict[str, torch.Tensor]:
         raise NotImplementedError()
2 changes: 2 additions & 0 deletions rectools/models/nn/transformers/lightning.py
@@ -387,7 +387,9 @@ def _get_user_item_embeddings(
         for batch in recommend_dataloader:
             batch = {k: v.to(device) for k, v in batch.items()}
             batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
+            batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs)
             user_embs.append(batch_embs.cpu())
+        item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs)
 
         return torch.cat(user_embs), item_embs
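The two new hooks let the similarity module post-process each tower right before embeddings leave the model. A hedged sketch of a custom override (assuming the distance-based module lives at `rectools.models.nn.transformers.similarity`; the L2 normalization is purely illustrative, not what this PR prescribes):

```python
import torch

from rectools.models.nn.transformers.similarity import DistanceSimilarityModule


class NormalizingSimilarityModule(DistanceSimilarityModule):
    """Illustrative override: L2-normalize both towers before scoring."""

    def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.normalize(session_embs, dim=-1)

    def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
        return torch.nn.functional.normalize(item_embs, dim=-1)
```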
16 changes: 8 additions & 8 deletions rectools/models/nn/transformers/sasrec.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import typing as tp
-from typing import Dict, List, Tuple
+from typing import Dict
 
 import numpy as np
 import torch
@@ -36,7 +36,7 @@
     TransformerModelConfig,
     ValMaskCallable,
 )
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
 from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
 from .net_blocks import (
     LearnableInversePositionalEncoding,
@@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):
 
     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[BatchElement],
     ) -> Dict[str, torch.Tensor]:
         """
         Truncate each session from right to keep `session_max_len` items.
@@ -91,7 +91,7 @@
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             x[i, -len(ses) + 1 :] = ses[:-1]  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) + 1 :] = ses[1:]  # ses: [session_len] -> y[i]: [session_max_len]
             yw[i, -len(ses) + 1 :] = ses_weights[1:]  # ses_weights: [session_len] -> yw[i]: [session_max_len]
@@ -103,12 +103,12 @@
         )
         return batch_dict
 
-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
         yw = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
 
             # take only first target for leave-one-strategy
@@ -126,10 +126,10 @@
         )
         return batch_dict
 
-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
         """Right truncation, left padding to session_max_len"""
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             x[i, -len(ses) :] = ses[-self.session_max_len :]
         return {"x": torch.LongTensor(x)}
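For contrast with the masking flow in BERT4Rec, the SASRec train collate shifts each session by one position. A standalone sketch for a single session (values illustrative):

```python
import numpy as np

session_max_len = 5
ses = [3, 7, 15, 2]
ses_weights = [1.0, 1.0, 2.0, 1.0]

x = np.zeros((1, session_max_len))
y = np.zeros((1, session_max_len))
yw = np.zeros((1, session_max_len))

# Inputs are the session minus its last item; targets are the session
# shifted left by one, with weights aligned to the targets.
x[0, -len(ses) + 1:] = ses[:-1]          # [0, 0, 3, 7, 15]
y[0, -len(ses) + 1:] = ses[1:]           # [0, 0, 7, 15, 2]
yw[0, -len(ses) + 1:] = ses_weights[1:]  # [0, 0, 1, 2, 1]
```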