16 changes: 10 additions & 6 deletions rectools/models/nn/transformers/bert4rec.py

def _collate_fn_train(
self,
batch: List[Tuple[List[int], List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> Dict[str, torch.Tensor]:
"""
        Mask session elements to obtain `x`.
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
masked_session, target = self._mask_session(ses)
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len]
)
return batch_dict
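
For illustration, a minimal sketch of the three-element batch items this collate function now receives (the ids, weights and "feature" extras key are hypothetical); the trailing extras dict is unpacked into the `_` placeholder and ignored here:

batch = [
    ([3, 7, 2], [1.0, 1.0, 1.0], {"feature": ["a", "b", "c"]}),
    ([5, 4], [1.0, 1.0], {}),  # extras dict is empty when extra_cols is not set
]
# Each session is masked via _mask_session and left-padded, producing
# "x", "y" and "yw" tensors of shape (batch_size, session_max_len).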

def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_val(
self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, 1))  # only leave-one-out strategy is supported for now
        yw = np.zeros((batch_size, 1))  # only leave-one-out strategy is supported for now
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
session = input_session.copy()

)
return batch_dict
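
The `weight == 0` check above relies on how validation interactions are assembled in the data preparator (input interactions get weight 0, targets keep their real weight, see data_preparator.py below); a small sketch with made-up ids:

ses = [3, 7, 2, 9]          # full validation session in internal ids
ses_weights = [0, 0, 0, 1]  # only the leave-one-out target has non-zero weight
input_session = [ses[i] for i, w in enumerate(ses_weights) if w == 0]  # [3, 7, 2]
targets = [ses[i] for i, w in enumerate(ses_weights) if w != 0]        # [9]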

def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_recommend(
self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
) -> Dict[str, torch.Tensor]:
"""
        Right truncation, left padding to `session_max_len`.
        During inference the model will use (`session_max_len` - 1) interactions,
        and one extra "MASK" token will be added for making predictions.
"""
x = np.zeros((len(batch), self.session_max_len))
for i, (ses, _) in enumerate(batch):
for i, (ses, _, _) in enumerate(batch):
session = ses.copy()
session = session + [self.extra_token_ids[MASKING_VALUE]]
x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
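
A worked example of this padding scheme, assuming session_max_len = 5 and a hypothetical internal MASK id of 1:

import numpy as np

ses = [3, 7, 2]
session = ses + [1]                # append the "MASK" token
x = np.zeros(5)
x[-len(ses) - 1 :] = session[-5:]  # x -> [0., 3., 7., 2., 1.]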
67 changes: 50 additions & 17 deletions rectools/models/nn/transformers/data_preparator.py
Weight of each interaction from the session.
"""

def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
def __init__(
self,
sessions: tp.List[tp.List[int]],
weights: tp.List[tp.List[float]],
extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
):
self.sessions = sessions
self.weights = weights
self.extras = extras

def __len__(self) -> int:
return len(self.sessions)

def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]:
session = self.sessions[index] # [session_len]
weights = self.weights[index] # [session_len]
return session, weights
extras = (
{feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
)
return session, weights, extras
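
A hedged sketch of the new dataset item layout ("feature" is a hypothetical extra column): every extras list is indexed in lockstep with sessions and weights, so a single item stays aligned across all three:

ds = SequenceDataset(
    sessions=[[3, 7], [5]],
    weights=[[1.0, 1.0], [1.0]],
    extras={"feature": [["a", "b"], ["c"]]},
)
ds[0]  # -> ([3, 7], [1.0, 1.0], {"feature": ["a", "b"]})
ds[1]  # -> ([5], [1.0], {"feature": ["c"]})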

@classmethod
def from_interactions(
interactions : pd.DataFrame
User-item interactions.
"""
cols_to_agg = [col for col in interactions.columns if col != Columns.User]
sessions = (
interactions.sort_values(Columns.Datetime, kind="stable")
.groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
.groupby(Columns.User, sort=sort_users)[cols_to_agg]
.agg(list)
)
sessions, weights = (
sessions_items, weights = (
sessions[Columns.Item].to_list(),
sessions[Columns.Weight].to_list(),
)

return cls(sessions=sessions, weights=weights)
extra_cols = [col for col in interactions.columns if col not in Columns.Interactions]
extras = {col: sessions[col].to_list() for col in extra_cols} if len(extra_cols) > 0 else None
return cls(sessions=sessions_items, weights=weights, extras=extras)
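
For example, any interactions column beyond the standard Columns.Interactions set is now aggregated into per-session lists, ordered by datetime just like items and weights. A sketch with a hypothetical "feature" column:

import pandas as pd
from rectools import Columns

interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 20],
        Columns.Item: [1, 2, 1],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: ["2021-01-01", "2021-01-02", "2021-01-01"],
        "feature": ["a", "b", "c"],
    }
)
ds = SequenceDataset.from_interactions(interactions)
ds[0]  # -> ([1, 2], [1.0, 1.0], {"feature": ["a", "b"]})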


class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
extra_cols: tp.Optional[tp.List[str]] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
self.extra_cols = extra_cols

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
"""Return internal item ids from processed dataset in sorted order."""

def process_dataset_train(self, dataset: Dataset) -> None:
"""Process train dataset and save data."""
raw_interactions = dataset.get_raw_interactions()
if self.extra_cols is None:
raw_interactions = dataset.get_raw_interactions(include_extra_cols=False)
else:
raw_interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols]

# Exclude val interaction targets from train if needed
interactions = raw_interactions

# Prepare train dataset
# User features are dropped for now because model doesn't support them
final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
final_interactions = Interactions.from_raw(
interactions,
user_id_map,
item_id_map,
keep_extra_cols=True,
)
self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
self.item_id_map = self.train_dataset.item_id_map
self._init_extra_token_ids()
val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
val_interactions[Columns.Weight] = 0
val_interactions = pd.concat([val_interactions, val_targets], axis=0)
self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
self.val_interactions = Interactions.from_raw(
val_interactions, user_id_map, item_id_map, keep_extra_cols=True
).df

def _init_extra_token_ids(self) -> None:
extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
        Final item_id_map is the model item_id_map constructed during training.
"""
# Filter interactions in dataset internal ids
interactions = dataset.interactions.df
if self.extra_cols is None:
interactions = dataset.interactions.df[Columns.Interactions]
else:
interactions = dataset.interactions.df[Columns.Interactions + self.extra_cols]
users_internal = dataset.user_id_map.convert_to_internal(users, strict=False)
items_internal = dataset.item_id_map.convert_to_internal(self.get_known_item_ids(), strict=False)
interactions = interactions[interactions[Columns.User].isin(users_internal)]
if n_filtered > 0:
explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
warnings.warn(explanation)
filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
filtered_interactions = Interactions.from_raw(
interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=True
)
filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

        Final user_id_map is the same as in the original dataset.
        Final item_id_map is the model item_id_map constructed during training.
"""
interactions = dataset.get_raw_interactions()
if self.extra_cols is None:
interactions = dataset.get_raw_interactions(include_extra_cols=False)
else:
interactions = dataset.get_raw_interactions()[Columns.Interactions + self.extra_cols]

Codecov / codecov/patch warning on rectools/models/nn/transformers/data_preparator.py#L415: added line was not covered by tests.
interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
filtered_interactions = Interactions.from_raw(
interactions, dataset.user_id_map, self.item_id_map, self.extra_cols is not None
)
filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

def _collate_fn_train(
self,
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_val(
self,
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_recommend(
self,
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()
2 changes: 2 additions & 0 deletions rectools/models/nn/transformers/lightning.py
for batch in recommend_dataloader:
batch = {k: v.to(device) for k, v in batch.items()}
batch_embs = self.torch_model.encode_sessions(batch, item_embs)[:, -1, :]
batch_embs = self.torch_model.similarity_module.session_tower_forward(batch_embs)
user_embs.append(batch_embs.cpu())
item_embs = self.torch_model.similarity_module.item_tower_forward(item_embs)

return torch.cat(user_embs), item_embs
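
Both tower calls are no-ops for DistanceSimilarityModule (each returns its input unchanged, see similarity.py below), so embeddings exported by existing models are unaffected; a sketch of the contract:

import torch

module = DistanceSimilarityModule(distance="dot")
session_embs = torch.randn(2, 64)
item_embs = torch.randn(100, 64)
# Identity towers: exported user/item embeddings pass through unchanged.
assert torch.equal(module.session_tower_forward(session_embs), session_embs)
assert torch.equal(module.item_tower_forward(item_embs), item_embs)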

18 changes: 11 additions & 7 deletions rectools/models/nn/transformers/sasrec.py
# limitations under the License.

import typing as tp
from typing import Dict, List, Tuple
from typing import Dict

import numpy as np
import torch

def _collate_fn_train(
self,
batch: List[Tuple[List[int], List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> Dict[str, torch.Tensor]:
"""
Truncate each session from right to keep `session_max_len` items.
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len]
yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len]
)
return batch_dict
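
A worked example of the shift-by-one scheme above (made-up ids, session_max_len = 5): the model sees all but the last item and is trained to predict the next item at every position:

import numpy as np

ses = [3, 7, 2, 9]
x = np.zeros(5)
y = np.zeros(5)
x[-len(ses) + 1 :] = ses[:-1]  # x -> [0., 0., 3., 7., 2.]
y[-len(ses) + 1 :] = ses[1:]   # y -> [0., 0., 7., 2., 9.]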

def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_val(
self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
        y = np.zeros((batch_size, 1))  # Only leave-one-out strategy is supported for losses
        yw = np.zeros((batch_size, 1))  # Only leave-one-out strategy is supported for losses
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]

            # take only the first target for the leave-one-out strategy
)
return batch_dict

def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
def _collate_fn_recommend(
self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
) -> Dict[str, torch.Tensor]:
"""Right truncation, left padding to session_max_len"""
x = np.zeros((len(batch), self.session_max_len))
for i, (ses, _) in enumerate(batch):
for i, (ses, _, _) in enumerate(batch):
x[i, -len(ses) :] = ses[-self.session_max_len :]
return {"x": torch.LongTensor(x)}

22 changes: 21 additions & 1 deletion rectools/models/nn/transformers/similarity.py
) -> torch.Tensor:
raise NotImplementedError()

def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
"""Forward pass for session tower."""
raise NotImplementedError()

def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
"""Forward pass for item tower."""
raise NotImplementedError()

def forward(
self,
session_embs: torch.Tensor,
dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])

def __init__(self, distance: str = "dot") -> None:
def __init__(
self,
distance: str = "dot",
**kwargs: tp.Any,
) -> None:
super().__init__()
if distance not in self.dist_available:
raise ValueError("`dist` can only be either `dot` or `cosine`.")
embeddings = embeddings / torch.max(embedding_norm, self.epsilon_cosine_dist.to(embeddings))
return embeddings

def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
"""Forward pass for session tower."""
return session_embs

def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
"""Forward pass for item tower."""
return item_embs
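
The identity implementations keep DistanceSimilarityModule backward compatible; the hooks exist so custom similarity modules can transform embeddings before scoring. A hedged sketch of such a subclass (the class name and projection layers are hypothetical; forward would likely also need to apply the towers, which this sketch omits):

import typing as tp

import torch
from torch import nn


class ProjectedSimilarityModule(DistanceSimilarityModule):
    """Hypothetical module with learned per-tower projections."""

    def __init__(self, n_factors: int = 64, distance: str = "dot", **kwargs: tp.Any) -> None:
        super().__init__(distance=distance, **kwargs)
        self.session_proj = nn.Linear(n_factors, n_factors)
        self.item_proj = nn.Linear(n_factors, n_factors)

    def session_tower_forward(self, session_embs: torch.Tensor) -> torch.Tensor:
        return self.session_proj(session_embs)

    def item_tower_forward(self, item_embs: torch.Tensor) -> torch.Tensor:
        return self.item_proj(item_embs)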

def forward(
self,
session_embs: torch.Tensor,
4 changes: 2 additions & 2 deletions tests/models/nn/transformers/test_bert4rec.py

def _collate_fn_train(
self,
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
for i, (ses, ses_weights) in enumerate(batch):
for i, (ses, ses_weights, _) in enumerate(batch):
y[i, -self.n_last_targets] = ses[-self.n_last_targets]
yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets]
x[i, -len(ses) :] = ses