Merged. Changes from 6 commits.
CHANGELOG.md (4 additions, 0 deletions)
@@ -5,6 +5,10 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

+## Unreleased
+### Added
+- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))

## [0.14.0] - 16.05.2025

### Added
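Editor's note: the thread running through this PR is a third per-sample element carrying per-interaction extra features. A minimal sketch of the new structure (the `watched_pct` column is hypothetical, used only for illustration):

```python
from typing import Any, Dict, List, Tuple

# BatchElement, as introduced in data_preparator.py below:
# (item ids per session, interaction weights, per-interaction extras)
BatchElement = Tuple[List[int], List[float], Dict[str, List[Any]]]

batch_element: BatchElement = (
    [10, 11, 12],                    # session item ids
    [1.0, 1.0, 1.0],                 # interaction weights
    {"watched_pct": [100, 50, 80]},  # extras: one value per interaction
)
```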
rectools/models/nn/transformers/bert4rec.py (7 additions, 7 deletions)
@@ -36,7 +36,7 @@
ValMaskCallable,
)
from .constants import MASKING_VALUE, PADDING_VALUE
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
@@ -128,7 +128,7 @@ def _mask_session(

def _collate_fn_train(
self,
-batch: List[Tuple[List[int], List[float]]],
+batch: tp.List[BatchElement],
) -> Dict[str, torch.Tensor]:
"""
Mask session elements to receive `x`.
@@ -141,7 +141,7 @@ def _collate_fn_train(
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
-for i, (ses, ses_weights) in enumerate(batch):
+for i, (ses, ses_weights, _) in enumerate(batch):
masked_session, target = self._mask_session(ses)
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,12 @@ def _collate_fn_train(
)
return batch_dict

-def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
-for i, (ses, ses_weights) in enumerate(batch):
+for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
session = input_session.copy()

@@ -179,14 +179,14 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
)
return batch_dict

-def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
"""
Right truncation, left padding to `session_max_len`.
During inference the model will use (`session_max_len` - 1) interactions,
and one extra "MASK" token will be added for making predictions.
"""
x = np.zeros((len(batch), self.session_max_len))
-for i, (ses, _) in enumerate(batch):
+for i, (ses, _, _) in enumerate(batch):
session = ses.copy()
session = session + [self.extra_token_ids[MASKING_VALUE]]
x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
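Editor's note: the collate changes above are mechanical (the third tuple element is unpacked and discarded with `_`), but the padding logic is easy to misread. Below is a toy re-enactment of `_collate_fn_recommend`, assuming a mask token id of 2 (the real value comes from `self.extra_token_ids[MASKING_VALUE]`) and `session_max_len` of 5:

```python
import numpy as np
import torch

session_max_len = 5
mask_token_id = 2  # assumed; actually taken from extra_token_ids[MASKING_VALUE]

# One BatchElement: (item ids, weights, extras); extras are ignored here.
batch = [([3, 4, 5], [1.0, 1.0, 1.0], {})]

x = np.zeros((len(batch), session_max_len))
for i, (ses, _, _) in enumerate(batch):
    session = ses + [mask_token_id]  # append "MASK" as the prediction slot
    x[i, -len(ses) - 1 :] = session[-session_max_len:]

print(torch.LongTensor(x))  # tensor([[0, 3, 4, 5, 2]])
```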
rectools/models/nn/transformers/data_preparator.py (51 additions, 17 deletions)
@@ -32,6 +32,7 @@
from .negative_sampler import TransformerNegativeSamplerBase

InitKwargs = tp.Dict[str, tp.Any]
+BatchElement = tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]


class SequenceDataset(TorchDataset):
@@ -46,17 +47,26 @@ class SequenceDataset(TorchDataset):
Weight of each interaction from the session.
"""

-def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+def __init__(
+    self,
+    sessions: tp.List[tp.List[int]],
+    weights: tp.List[tp.List[float]],
+    extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+):
self.sessions = sessions
self.weights = weights
+self.extras = extras

def __len__(self) -> int:
return len(self.sessions)

-def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+def __getitem__(self, index: int) -> BatchElement:
session = self.sessions[index] # [session_len]
weights = self.weights[index] # [session_len]
-return session, weights
+extras = (
+    {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+)
+return session, weights, extras

@classmethod
def from_interactions(
@@ -73,17 +83,19 @@ def from_interactions(
interactions : pd.DataFrame
User-item interactions.
"""
+cols_to_agg = [col for col in interactions.columns if col != Columns.User]
sessions = (
interactions.sort_values(Columns.Datetime, kind="stable")
-.groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+.groupby(Columns.User, sort=sort_users)[cols_to_agg]
.agg(list)
)
-sessions, weights = (
+sessions_items, weights = (
sessions[Columns.Item].to_list(),
sessions[Columns.Weight].to_list(),
)

-return cls(sessions=sessions, weights=weights)
+extra_cols = [col for col in interactions.columns if col not in Columns.Interactions]
+extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None
+return cls(sessions=sessions_items, weights=weights, extras=extras)


class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes
@@ -133,6 +145,7 @@ def __init__(
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+extra_cols: tp.Optional[tp.List[str]] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
@@ -148,6 +161,7 @@ def __init__(
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+self.extra_cols = extra_cols

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
"""Return internal item ids from processed dataset in sorted order."""
@@ -203,7 +217,10 @@ def _filter_train_interactions(self, train_interactions: pd.DataFrame) -> pd.Dat

def process_dataset_train(self, dataset: Dataset) -> None:
"""Process train dataset and save data."""
-raw_interactions = dataset.get_raw_interactions()
+if self.extra_cols is None:
+    raw_interactions = dataset.get_raw_interactions(include_extra_cols=False)
+else:
+    raw_interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols]

# Exclude val interaction targets from train if needed
interactions = raw_interactions
@@ -231,7 +248,12 @@

# Prepare train dataset
# User features are dropped for now because model doesn't support them
-final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
+final_interactions = Interactions.from_raw(
+    interactions,
+    user_id_map,
+    item_id_map,
+    keep_extra_cols=True,
+)
self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
self.item_id_map = self.train_dataset.item_id_map
self._init_extra_token_ids()
@@ -246,7 +268,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
val_interactions[Columns.Weight] = 0
val_interactions = pd.concat([val_interactions, val_targets], axis=0)
-self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
+self.val_interactions = Interactions.from_raw(
+    val_interactions, user_id_map, item_id_map, keep_extra_cols=True
+).df

def _init_extra_token_ids(self) -> None:
extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
@@ -340,7 +364,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
Final item_id_map is model item_id_map constructed during training.
"""
# Filter interactions in dataset internal ids
-interactions = dataset.interactions.df
+if self.extra_cols is None:
+    interactions = dataset.interactions.df[Columns.Interactions]
+else:
+    interactions = dataset.interactions.df[Columns.Interactions + self.extra_cols]
users_internal = dataset.user_id_map.convert_to_internal(users, strict=False)
items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False)
interactions = interactions[interactions[Columns.User].isin(users_internal)]
@@ -359,7 +386,9 @@
if n_filtered > 0:
explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
warnings.warn(explanation)
-filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
+filtered_interactions = Interactions.from_raw(
+    interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True
+)
filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

Expand All @@ -381,26 +410,31 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
Final user_id_map is the same as dataset original.
Final item_id_map is model item_id_map constructed during training.
"""
-interactions = dataset.get_raw_interactions()
+if self.extra_cols is None:
+    interactions = dataset.get_raw_interactions(include_extra_cols=False)
+else:
+    interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols]
interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
-filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
+filtered_interactions = Interactions.from_raw(
+    interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols=True
+)
filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

def _collate_fn_train(
self,
-batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+batch: tp.List[BatchElement],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_val(
self,
-batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+batch: tp.List[BatchElement],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_recommend(
self,
-batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+batch: tp.List[BatchElement],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()
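Editor's note: end to end, `from_interactions` now aggregates every non-user column per session and routes anything outside `Columns.Interactions` (user, item, weight, datetime) into `extras`. A minimal sketch, assuming a standard interactions frame (`watched_pct` is a hypothetical extra column):

```python
import pandas as pd

from rectools import Columns
from rectools.models.nn.transformers.data_preparator import SequenceDataset

interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 2],
        Columns.Item: [10, 11, 10],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: ["2021-01-01", "2021-01-02", "2021-01-01"],
        "watched_pct": [100, 50, 80],
    }
)

sequences = SequenceDataset.from_interactions(interactions)
session, weights, extras = sequences[0]  # BatchElement for the first user
print(session, extras)  # [10, 11] {'watched_pct': [100, 50]}
```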
rectools/models/nn/transformers/lightning.py (2 additions, 0 deletions)
@@ -387,7 +387,9 @@ def _get_user_item_embeddings(
for batch in recommend_dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
+batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs)
user_embs.append(batch_embs.cpu())
+item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs)

return torch.cat(user_embs), item_embs

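Editor's note: these two calls route the final embeddings through the similarity module's towers, so any learned projection is applied at inference time consistently with training. The scoring step below is a sketch under a dot-product assumption; the helper name is ours:

```python
import torch

from rectools.models.nn.transformers.similarity import SimilarityModuleBase

def score_sessions_against_catalog(
    similarity_module: SimilarityModuleBase,
    session_embs: torch.Tensor,  # (n_sessions, n_factors)
    item_embs: torch.Tensor,     # (n_items, n_factors)
) -> torch.Tensor:
    # Same pattern as _get_user_item_embeddings: project both sides first.
    users = similarity_module.session_tower_forward(session_embs)
    items = similarity_module.item_tower_forward(item_embs)
    return users @ items.T  # (n_sessions, n_items) scores
```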
rectools/models/nn/transformers/sasrec.py (8 additions, 8 deletions)
@@ -13,7 +13,7 @@
# limitations under the License.

import typing as tp
-from typing import Dict, List, Tuple
+from typing import Dict

import numpy as np
import torch
@@ -36,7 +36,7 @@
TransformerModelConfig,
ValMaskCallable,
)
-from .data_preparator import InitKwargs, TransformerDataPreparatorBase
+from .data_preparator import BatchElement, InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
@@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):

def _collate_fn_train(
self,
-batch: List[Tuple[List[int], List[float]]],
+batch: tp.List[BatchElement],
) -> Dict[str, torch.Tensor]:
"""
Truncate each session from right to keep `session_max_len` items.
@@ -91,7 +91,7 @@ def _collate_fn_train(
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
-for i, (ses, ses_weights) in enumerate(batch):
+for i, (ses, ses_weights, _) in enumerate(batch):
x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len]
yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len]
@@ -103,12 +103,12 @@ def _collate_fn_train(
)
return batch_dict

-def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+def _collate_fn_val(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
-for i, (ses, ses_weights) in enumerate(batch):
+for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]

# take only first target for leave-one-strategy
@@ -126,10 +126,10 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
)
return batch_dict

-def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+def _collate_fn_recommend(self, batch: tp.List[BatchElement]) -> Dict[str, torch.Tensor]:
"""Right truncation, left padding to session_max_len"""
x = np.zeros((len(batch), self.session_max_len))
-for i, (ses, _) in enumerate(batch):
+for i, (ses, _, _) in enumerate(batch):
x[i, -len(ses) :] = ses[-self.session_max_len :]
return {"x": torch.LongTensor(x)}

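Editor's note: unlike BERT4Rec's masking, SASRec's train collate shifts each session by one step: inputs are `ses[:-1]`, targets are `ses[1:]`, both left-padded. A toy run of the shift, assuming `session_max_len` of 4:

```python
import numpy as np

session_max_len = 4
ses, ses_weights = [7, 8, 9], [1.0, 1.0, 1.0]

x = np.zeros((1, session_max_len))
y = np.zeros((1, session_max_len))
yw = np.zeros((1, session_max_len))

x[0, -len(ses) + 1 :] = ses[:-1]          # inputs:  [0, 0, 7, 8]
y[0, -len(ses) + 1 :] = ses[1:]           # targets: [0, 0, 8, 9]
yw[0, -len(ses) + 1 :] = ses_weights[1:]  # weights: [0, 0, 1, 1]
```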
rectools/models/nn/transformers/similarity.py (13 additions, 1 deletion)
@@ -34,6 +34,14 @@ def _get_pos_neg_logits(
) -> torch.Tensor:
raise NotImplementedError()

+def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
+    """Forward pass for session tower."""
+    return session_embs
+
+def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
+    """Forward pass for item tower."""
+    return item_embs

def forward(
self,
session_embs: torch.Tensor,
@@ -62,7 +70,11 @@ class DistanceSimilarityModule(SimilarityModuleBase):
dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])

-def __init__(self, distance: str = "dot") -> None:
+def __init__(
+    self,
+    distance: str = "dot",
+    **kwargs: tp.Any,
+) -> None:
super().__init__()
if distance not in self.dist_available:
raise ValueError("`dist` can only be either `dot` or `cosine`.")
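Editor's note: the identity defaults keep existing behavior, while subclasses can now project each side separately. A sketch of a two-tower variant; the class name, `n_factors`, and whether the training-time `forward` should also route through these hooks are assumptions, not part of this PR:

```python
from typing import Any

import torch
from torch import nn

from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

class TwoTowerSimilarityModule(DistanceSimilarityModule):
    """Project sessions and items separately before applying the distance."""

    def __init__(self, n_factors: int = 64, distance: str = "dot", **kwargs: Any) -> None:
        super().__init__(distance=distance, **kwargs)
        self.session_proj = nn.Linear(n_factors, n_factors)
        self.item_proj = nn.Linear(n_factors, n_factors)

    def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
        return self.session_proj(session_embs)

    def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
        return self.item_proj(item_embs)
```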
tests/models/nn/transformers/test_bert4rec.py (3 additions, 3 deletions)
@@ -34,7 +34,7 @@
TransformerLightningModule,
)
from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable
-from rectools.models.nn.transformers.data_preparator import InitKwargs
+from rectools.models.nn.transformers.data_preparator import BatchElement, InitKwargs
from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule
from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone
@@ -640,13 +640,13 @@ def __init__(

def _collate_fn_train(
self,
-batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+batch: tp.List[BatchElement],
) -> tp.Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
-for i, (ses, ses_weights) in enumerate(batch):
+for i, (ses, ses_weights, _) in enumerate(batch):
y[i, -self.n_last_targets] = ses[-self.n_last_targets]
yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets]
x[i, -len(ses) :] = ses