16 changes: 10 additions & 6 deletions rectools/models/nn/transformers/bert4rec.py
@@ -128,7 +128,7 @@ def _mask_session(

def _collate_fn_train(
self,
- batch: List[Tuple[List[int], List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> Dict[str, torch.Tensor]:
"""
Mask session elements to receive `x`.
@@ -141,7 +141,7 @@ def _collate_fn_train(
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
- for i, (ses, ses_weights) in enumerate(batch):
+ for i, (ses, ses_weights, _) in enumerate(batch):
masked_session, target = self._mask_session(ses)
x[i, -len(ses) :] = masked_session # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) :] = target # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,14 @@
)
return batch_dict
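
A toy trace of the filling above, assuming `_mask_session` (its body is not shown in this hunk) returns the masked inputs plus reconstruction targets that are non-zero only at masked positions. This is a hypothetical illustration, not the actual masking output:

```python
import numpy as np

session_max_len = 5
ses = [3, 4, 5]
# hypothetical _mask_session result: middle item replaced by a MASK id of 1
masked_session, target = [3, 1, 5], [0, 4, 0]

x = np.zeros(session_max_len)
y = np.zeros(session_max_len)
x[-len(ses) :] = masked_session  # x: [0, 0, 3, 1, 5]
y[-len(ses) :] = target          # y: [0, 0, 0, 4, 0]
```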

- def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+ def _collate_fn_val(
+     self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+ ) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # until only leave-one-strategy
yw = np.zeros((batch_size, 1)) # until only leave-one-strategy
- for i, (ses, ses_weights) in enumerate(batch):
+ for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
session = input_session.copy()

@@ -179,14 +181,16 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
)
return batch_dict

- def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+ def _collate_fn_recommend(
+     self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+ ) -> Dict[str, torch.Tensor]:
"""
Right truncation, left padding to `session_max_len`
During inference model will use (`session_max_len` - 1) interactions
and one extra "MASK" token will be added for making predictions.
"""
x = np.zeros((len(batch), self.session_max_len))
- for i, (ses, _) in enumerate(batch):
+ for i, (ses, _, _) in enumerate(batch):
session = ses.copy()
session = session + [self.extra_token_ids[MASKING_VALUE]]
x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
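To make the new batch shape concrete: each sample yielded by `SequenceDataset` is now a 3-tuple `(session, weights, extras)`, and the BERT4Rec collates unpack but deliberately ignore the third element (`_`). A minimal trace of the `_collate_fn_recommend` padding logic above, with invented ids (`mask_token_id = 1` is an assumption standing in for `self.extra_token_ids[MASKING_VALUE]`):

```python
import numpy as np
import torch

session_max_len = 5
mask_token_id = 1  # assumption: stand-in for self.extra_token_ids[MASKING_VALUE]

# Each sample is (session, weights, extras); extras is ignored by this collate
batch = [
    ([3, 4, 5], [1.0, 1.0, 1.0], {"datetime": ["2021-01-01", "2021-01-02", "2021-01-03"]}),
    ([7, 8], [1.0, 1.0], {"datetime": ["2021-02-01", "2021-02-02"]}),
]

x = np.zeros((len(batch), session_max_len))
for i, (ses, _, _) in enumerate(batch):
    session = ses + [mask_token_id]  # append "MASK" for prediction
    x[i, -len(ses) - 1 :] = session[-session_max_len:]  # right-truncate, left-pad

print(torch.LongTensor(x))
# tensor([[0, 3, 4, 5, 1],
#         [0, 0, 7, 8, 1]])
```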
64 changes: 47 additions & 17 deletions rectools/models/nn/transformers/data_preparator.py
@@ -46,23 +46,33 @@ class SequenceDataset(TorchDataset):
Weight of each interaction from the session.
"""

- def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+ def __init__(
+     self,
+     sessions: tp.List[tp.List[int]],
+     weights: tp.List[tp.List[float]],
+     extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+ ):
self.sessions = sessions
self.weights = weights
+ self.extras = extras

def __len__(self) -> int:
return len(self.sessions)

- def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+ def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]:
session = self.sessions[index] # [session_len]
weights = self.weights[index] # [session_len]
- return session, weights
+ extras = (
+     {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+ )
+ return session, weights, extras

@classmethod
def from_interactions(
cls,
interactions: pd.DataFrame,
sort_users: bool = False,
+ extra_cols: tp.Optional[tp.List[str]] = None,
) -> "SequenceDataset":
"""
Group interactions by user.
@@ -73,17 +83,18 @@ def from_interactions(
interactions : pd.DataFrame
User-item interactions.
"""
+ cols_to_agg = [col for col in interactions.columns if col != Columns.User]
sessions = (
interactions.sort_values(Columns.Datetime, kind="stable")
- .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+ .groupby(Columns.User, sort=sort_users)[cols_to_agg]
.agg(list)
)
- sessions, weights = (
+ sessions_items, weights = (
sessions[Columns.Item].to_list(),
sessions[Columns.Weight].to_list(),
)

- return cls(sessions=sessions, weights=weights)
+ extras = {col: sessions[col].to_list() for col in extra_cols} if extra_cols else None
+ return cls(sessions=sessions_items, weights=weights, extras=extras)
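
A usage sketch under invented data showing how `extra_cols` flows through `from_interactions` into the third element returned by `__getitem__` (the DataFrame contents are hypothetical):

```python
import pandas as pd
from rectools import Columns
from rectools.models.nn.transformers.data_preparator import SequenceDataset

interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 20],
        Columns.Item: [1, 2, 1],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-01"]),
    }
)
dataset = SequenceDataset.from_interactions(interactions, extra_cols=[Columns.Datetime])

session, weights, extras = dataset[0]  # user 10
# session == [1, 2], weights == [1.0, 1.0]
# extras  == {Columns.Datetime: [Timestamp("2021-01-01"), Timestamp("2021-01-02")]}
```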


class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes
@@ -133,6 +144,7 @@ def __init__(
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
+ extra_cols_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
@@ -148,6 +160,7 @@ def __init__(
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
+ self.extra_cols_kwargs = extra_cols_kwargs

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
"""Return internal item ids from processed dataset in sorted order."""
@@ -231,7 +244,10 @@ def process_dataset_train(self, dataset: Dataset) -> None:

# Prepare train dataset
# User features are dropped for now because model doesn't support them
- final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
+ keep_extra_cols = self.extra_cols_kwargs is not None
+ final_interactions = Interactions.from_raw(
+     interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
+ )
self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
self.item_id_map = self.train_dataset.item_id_map
self._init_extra_token_ids()
@@ -246,7 +262,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
val_interactions[Columns.Weight] = 0
val_interactions = pd.concat([val_interactions, val_targets], axis=0)
- self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
+ self.val_interactions = Interactions.from_raw(
+     val_interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
+ ).df
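
Note that the gate above derives `keep_extra_cols` from the presence of `extra_cols_kwargs`, not its content, so even an empty dict switches it on:

```python
# keep_extra_cols depends only on whether extra_cols_kwargs was provided
for extra_cols_kwargs in (None, {}, {"extra_cols": ["datetime"]}):
    print(extra_cols_kwargs, "->", extra_cols_kwargs is not None)
# None -> False
# {} -> True
# {'extra_cols': ['datetime']} -> True
```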

def _init_extra_token_ids(self) -> None:
extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
@@ -261,7 +279,9 @@ def get_dataloader_train(self) -> DataLoader:
DataLoader
Train dataloader.
"""
- sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df)
+ sequence_dataset = SequenceDataset.from_interactions(
+     self.train_dataset.interactions.df, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
+ )
train_dataloader = DataLoader(
sequence_dataset,
collate_fn=self._collate_fn_train,
@@ -283,7 +303,9 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]:
if self.val_interactions is None:
return None

- sequence_dataset = SequenceDataset.from_interactions(self.val_interactions)
+ sequence_dataset = SequenceDataset.from_interactions(
+     self.val_interactions, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
+ )
val_dataloader = DataLoader(
sequence_dataset,
collate_fn=self._collate_fn_val,
@@ -306,7 +328,9 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa
# User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations.
# Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix
# that will be returned by the net.
- sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True)
+ sequence_dataset = SequenceDataset.from_interactions(
+     interactions=dataset.interactions.df, sort_users=True, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
+ )
recommend_dataloader = DataLoader(
sequence_dataset,
batch_size=batch_size,
@@ -359,7 +383,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
if n_filtered > 0:
explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
warnings.warn(explanation)
- filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
+ keep_extra_cols = self.extra_cols_kwargs is not None
+ filtered_interactions = Interactions.from_raw(
+     interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=keep_extra_cols
+ )
filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

@@ -383,24 +410,27 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
"""
interactions = dataset.get_raw_interactions()
interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
- filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
+ keep_extra_cols = self.extra_cols_kwargs is not None
+ filtered_interactions = Interactions.from_raw(
+     interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols
+ )
filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
return filtered_dataset

def _collate_fn_train(
self,
- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_val(
self,
- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()

def _collate_fn_recommend(
self,
- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
raise NotImplementedError()
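
The dataloader methods above pass the kwargs through `self._ensure_kwargs_dict`, which is not part of this diff. A plausible reading, consistent with how it is called (it must tolerate `None` and yield something safe to `**`-unpack), is:

```python
import typing as tp

InitKwargs = tp.Dict[str, tp.Any]  # assumption: mirrors rectools' InitKwargs alias

# In the real class this is presumably a static or class-level helper
def _ensure_kwargs_dict(kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
    """Coerce an optional kwargs mapping to a dict that is safe to **-unpack."""
    return kwargs if kwargs is not None else {}
```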
18 changes: 11 additions & 7 deletions rectools/models/nn/transformers/sasrec.py
@@ -13,7 +13,7 @@
# limitations under the License.

import typing as tp
- from typing import Dict, List, Tuple
+ from typing import Dict

import numpy as np
import torch
@@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):

def _collate_fn_train(
self,
- batch: List[Tuple[List[int], List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> Dict[str, torch.Tensor]:
"""
Truncate each session from right to keep `session_max_len` items.
@@ -91,7 +91,7 @@ def _collate_fn_train(
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
- for i, (ses, ses_weights) in enumerate(batch):
+ for i, (ses, ses_weights, _) in enumerate(batch):
x[i, -len(ses) + 1 :] = ses[:-1] # ses: [session_len] -> x[i]: [session_max_len]
y[i, -len(ses) + 1 :] = ses[1:] # ses: [session_len] -> y[i]: [session_max_len]
yw[i, -len(ses) + 1 :] = ses_weights[1:] # ses_weights: [session_len] -> yw[i]: [session_max_len]
@@ -103,12 +103,14 @@ def _collate_fn_train(
)
return batch_dict

- def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+ def _collate_fn_val(
+     self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+ ) -> Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
yw = np.zeros((batch_size, 1)) # Only leave-one-strategy is supported for losses
- for i, (ses, ses_weights) in enumerate(batch):
+ for i, (ses, ses_weights, _) in enumerate(batch):
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]

# take only first target for leave-one-strategy
@@ -126,10 +128,12 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
)
return batch_dict

- def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+ def _collate_fn_recommend(
+     self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+ ) -> Dict[str, torch.Tensor]:
"""Right truncation, left padding to session_max_len"""
x = np.zeros((len(batch), self.session_max_len))
- for i, (ses, _) in enumerate(batch):
+ for i, (ses, _, _) in enumerate(batch):
x[i, -len(ses) :] = ses[-self.session_max_len :]
return {"x": torch.LongTensor(x)}

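For contrast with BERT4Rec's masking, SASRec's train collate above is a plain shift-by-one; a toy trace with invented ids:

```python
import numpy as np

session_max_len = 5
ses = [3, 4, 5, 6]
ses_weights = [1.0, 1.0, 1.0, 1.0]

x = np.zeros(session_max_len)
y = np.zeros(session_max_len)
yw = np.zeros(session_max_len)
x[-len(ses) + 1 :] = ses[:-1]          # inputs:  [0, 0, 3, 4, 5]
y[-len(ses) + 1 :] = ses[1:]           # targets: [0, 0, 4, 5, 6]
yw[-len(ses) + 1 :] = ses_weights[1:]  # weights: [0, 0, 1, 1, 1]
```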
6 changes: 5 additions & 1 deletion rectools/models/nn/transformers/similarity.py
@@ -62,7 +62,11 @@ class DistanceSimilarityModule(SimilarityModuleBase):
dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])

- def __init__(self, distance: str = "dot") -> None:
+ def __init__(
+     self,
+     distance: str = "dot",
+     **kwargs: tp.Any,
+ ) -> None:
super().__init__()
if distance not in self.dist_available:
raise ValueError("`dist` can only be either `dot` or `cosine`.")
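The widened `__init__` lets `DistanceSimilarityModule` absorb a shared kwargs bundle that other similarity module implementations might consume; unknown keys are simply ignored. A construction-only sketch (the `emb_dim` key is hypothetical):

```python
from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

common_kwargs = {"distance": "cosine", "emb_dim": 64}
sim = DistanceSimilarityModule(**common_kwargs)  # "emb_dim" lands in **kwargs, unused
```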
4 changes: 2 additions & 2 deletions tests/models/nn/transformers/test_bert4rec.py
@@ -640,13 +640,13 @@ def __init__(

def _collate_fn_train(
self,
- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
) -> tp.Dict[str, torch.Tensor]:
batch_size = len(batch)
x = np.zeros((batch_size, self.session_max_len))
y = np.zeros((batch_size, self.session_max_len))
yw = np.zeros((batch_size, self.session_max_len))
- for i, (ses, ses_weights) in enumerate(batch):
+ for i, (ses, ses_weights, _) in enumerate(batch):
y[i, -self.n_last_targets] = ses[-self.n_last_targets]
yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets]
x[i, -len(ses) :] = ses
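The test collate keeps only the last position as a target; a toy trace with `n_last_targets = 1` and invented ids:

```python
import numpy as np

session_max_len = 4
n_last_targets = 1
ses, ses_weights = [3, 4, 5], [1.0, 1.0, 1.0]

x = np.zeros(session_max_len)
y = np.zeros(session_max_len)
yw = np.zeros(session_max_len)
y[-n_last_targets] = ses[-n_last_targets]           # y:  [0, 0, 0, 5]
yw[-n_last_targets] = ses_weights[-n_last_targets]  # yw: [0, 0, 0, 1]
x[-len(ses) :] = ses                                # x:  [0, 3, 4, 5]
```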