Commit 1fdc000

Author: Spirina Majya Aleksandrovna
Commit message: add extra columns
Parent: ea266cd

File tree: 6 files changed, +136 -77 lines changed


rectools/models/nn/transformers/bert4rec.py
Lines changed: 10 additions & 6 deletions

@@ -128,7 +128,7 @@ def _mask_session(

     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
     ) -> Dict[str, torch.Tensor]:
         """
         Mask session elements to receive `x`.
@@ -141,7 +141,7 @@ def _collate_fn_train(
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             masked_session, target = self._mask_session(ses)
             x[i, -len(ses) :] = masked_session  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) :] = target  # ses: [session_len] -> y[i]: [session_max_len]
@@ -154,12 +154,14 @@ def _collate_fn_train(
         )
         return batch_dict

-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(
+        self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+    ) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # until only leave-one-strategy
         yw = np.zeros((batch_size, 1))  # until only leave-one-strategy
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
             session = input_session.copy()

@@ -179,14 +181,16 @@ def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[st
         )
         return batch_dict

-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(
+        self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+    ) -> Dict[str, torch.Tensor]:
         """
         Right truncation, left padding to `session_max_len`
         During inference model will use (`session_max_len` - 1) interactions
         and one extra "MASK" token will be added for making predictions.
         """
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             session = ses.copy()
             session = session + [self.extra_token_ids[MASKING_VALUE]]
             x[i, -len(ses) - 1 :] = session[-self.session_max_len :]
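Note: all three collate functions above now consume three-element samples and simply discard the extras dict. A minimal sketch of the new per-sample contract, with made-up item ids and a hypothetical "feature" key:

# Hypothetical batch under the new (items, weights, extras) contract.
# BERT4Rec's collate functions unpack the extras dict but ignore it for now ("_").
batch = [
    ([3, 7, 2], [1.0, 1.0, 1.0], {"feature": [10, 20, 30]}),
    ([5, 1], [1.0, 1.0], {"feature": [40, 50]}),
]
for i, (ses, ses_weights, _) in enumerate(batch):
    pass  # masking and left-padding happen here, as in _collate_fn_train above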

rectools/models/nn/transformers/data_preparator.py
Lines changed: 47 additions & 17 deletions

@@ -46,23 +46,33 @@ class SequenceDataset(TorchDataset):
         Weight of each interaction from the session.
     """

-    def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
+    def __init__(
+        self,
+        sessions: tp.List[tp.List[int]],
+        weights: tp.List[tp.List[float]],
+        extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
+    ):
         self.sessions = sessions
         self.weights = weights
+        self.extras = extras

     def __len__(self) -> int:
         return len(self.sessions)

-    def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
+    def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]:
         session = self.sessions[index]  # [session_len]
         weights = self.weights[index]  # [session_len]
-        return session, weights
+        extras = (
+            {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
+        )
+        return session, weights, extras

     @classmethod
     def from_interactions(
         cls,
         interactions: pd.DataFrame,
         sort_users: bool = False,
+        extra_cols: tp.Optional[tp.List[str]] = None,
     ) -> "SequenceDataset":
         """
         Group interactions by user.
@@ -73,17 +83,18 @@ def from_interactions(
         interactions : pd.DataFrame
             User-item interactions.
         """
+        cols_to_agg = [col for col in interactions.columns if col != Columns.User]
         sessions = (
             interactions.sort_values(Columns.Datetime, kind="stable")
-            .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
+            .groupby(Columns.User, sort=sort_users)[cols_to_agg]
             .agg(list)
         )
-        sessions, weights = (
+        sessions_items, weights = (
             sessions[Columns.Item].to_list(),
             sessions[Columns.Weight].to_list(),
         )
-
-        return cls(sessions=sessions, weights=weights)
+        extras = {col: sessions[col].to_list() for col in extra_cols} if extra_cols else None
+        return cls(sessions=sessions_items, weights=weights, extras=extras)


 class TransformerDataPreparatorBase:  # pylint: disable=too-many-instance-attributes
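To make the new extras plumbing concrete, here is a small usage sketch of SequenceDataset.from_interactions with extra_cols. The "feature" column and all values are invented for illustration; the import path follows this file's location:

import pandas as pd

from rectools import Columns
from rectools.models.nn.transformers.data_preparator import SequenceDataset

# Toy interactions frame; "feature" is a hypothetical extra column.
interactions = pd.DataFrame(
    {
        Columns.User: [1, 1, 2],
        Columns.Item: [10, 11, 10],
        Columns.Weight: [1.0, 1.0, 1.0],
        Columns.Datetime: pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-01"]),
        "feature": ["a", "b", "c"],
    }
)
dataset = SequenceDataset.from_interactions(interactions, extra_cols=["feature"])
items, weights, extras = dataset[0]
# items == [10, 11], weights == [1.0, 1.0], extras == {"feature": ["a", "b"]}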
@@ -133,6 +144,7 @@ def __init__(
133144
n_negatives: tp.Optional[int] = None,
134145
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
135146
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
147+
extra_cols_kwargs: tp.Optional[InitKwargs] = None,
136148
**kwargs: tp.Any,
137149
) -> None:
138150
self.item_id_map: IdMap
@@ -148,6 +160,7 @@ def __init__(
148160
self.shuffle_train = shuffle_train
149161
self.get_val_mask_func = get_val_mask_func
150162
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
163+
self.extra_cols_kwargs = extra_cols_kwargs
151164

152165
def get_known_items_sorted_internal_ids(self) -> np.ndarray:
153166
"""Return internal item ids from processed dataset in sorted order."""
@@ -231,7 +244,10 @@ def process_dataset_train(self, dataset: Dataset) -> None:
231244

232245
# Prepare train dataset
233246
# User features are dropped for now because model doesn't support them
234-
final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
247+
keep_extra_cols = self.extra_cols_kwargs is not None
248+
final_interactions = Interactions.from_raw(
249+
interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
250+
)
235251
self.train_dataset = Dataset(user_id_map, item_id_map, final_interactions, item_features=item_features)
236252
self.item_id_map = self.train_dataset.item_id_map
237253
self._init_extra_token_ids()
@@ -246,7 +262,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
246262
val_interactions = interactions[interactions[Columns.User].isin(val_targets[Columns.User].unique())].copy()
247263
val_interactions[Columns.Weight] = 0
248264
val_interactions = pd.concat([val_interactions, val_targets], axis=0)
249-
self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
265+
self.val_interactions = Interactions.from_raw(
266+
val_interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
267+
).df
250268

251269
def _init_extra_token_ids(self) -> None:
252270
extra_token_ids = self.item_id_map.convert_to_internal(self.item_extra_tokens)
@@ -261,7 +279,9 @@ def get_dataloader_train(self) -> DataLoader:
261279
DataLoader
262280
Train dataloader.
263281
"""
264-
sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df)
282+
sequence_dataset = SequenceDataset.from_interactions(
283+
self.train_dataset.interactions.df, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
284+
)
265285
train_dataloader = DataLoader(
266286
sequence_dataset,
267287
collate_fn=self._collate_fn_train,
@@ -283,7 +303,9 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]:
283303
if self.val_interactions is None:
284304
return None
285305

286-
sequence_dataset = SequenceDataset.from_interactions(self.val_interactions)
306+
sequence_dataset = SequenceDataset.from_interactions(
307+
self.val_interactions, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
308+
)
287309
val_dataloader = DataLoader(
288310
sequence_dataset,
289311
collate_fn=self._collate_fn_val,
@@ -306,7 +328,9 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa
306328
# User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations.
307329
# Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix
308330
# that will be returned by the net.
309-
sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True)
331+
sequence_dataset = SequenceDataset.from_interactions(
332+
interactions=dataset.interactions.df, sort_users=True, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
333+
)
310334
recommend_dataloader = DataLoader(
311335
sequence_dataset,
312336
batch_size=batch_size,
@@ -359,7 +383,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
359383
if n_filtered > 0:
360384
explanation = f"""{n_filtered} target users were considered cold because of missing known items"""
361385
warnings.warn(explanation)
362-
filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
386+
keep_extra_cols = self.extra_cols_kwargs is not None
387+
filtered_interactions = Interactions.from_raw(
388+
interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=keep_extra_cols
389+
)
363390
filtered_dataset = Dataset(rec_user_id_map, self.item_id_map, filtered_interactions)
364391
return filtered_dataset
365392

@@ -383,24 +410,27 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
383410
"""
384411
interactions = dataset.get_raw_interactions()
385412
interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
386-
filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
413+
keep_extra_cols = self.extra_cols_kwargs is not None
414+
filtered_interactions = Interactions.from_raw(
415+
interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols
416+
)
387417
filtered_dataset = Dataset(dataset.user_id_map, self.item_id_map, filtered_interactions)
388418
return filtered_dataset
389419

390420
def _collate_fn_train(
391421
self,
392-
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
422+
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
393423
) -> tp.Dict[str, torch.Tensor]:
394424
raise NotImplementedError()
395425

396426
def _collate_fn_val(
397427
self,
398-
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
428+
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
399429
) -> tp.Dict[str, torch.Tensor]:
400430
raise NotImplementedError()
401431

402432
def _collate_fn_recommend(
403433
self,
404-
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
434+
batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
405435
) -> tp.Dict[str, torch.Tensor]:
406436
raise NotImplementedError()
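Design note: extra_cols_kwargs is forwarded verbatim to SequenceDataset.from_interactions through _ensure_kwargs_dict, so the expected key is extra_cols; passing any non-None value also flips keep_extra_cols=True in the Interactions.from_raw calls above, which keeps the extra columns alive through id mapping. A hedged construction sketch using a concrete subclass; the "watch_time" column is made up and the other arguments are illustrative, not the full required signature:

from rectools.models.nn.transformers.sasrec import SASRecDataPreparator

# "watch_time" is a hypothetical extra interactions column.
preparator = SASRecDataPreparator(
    session_max_len=50,
    batch_size=128,
    dataloader_num_workers=0,
    extra_cols_kwargs={"extra_cols": ["watch_time"]},
)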

rectools/models/nn/transformers/sasrec.py
Lines changed: 11 additions & 7 deletions

@@ -13,7 +13,7 @@
 # limitations under the License.

 import typing as tp
-from typing import Dict, List, Tuple
+from typing import Dict

 import numpy as np
 import torch
@@ -80,7 +80,7 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):

     def _collate_fn_train(
         self,
-        batch: List[Tuple[List[int], List[float]]],
+        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
     ) -> Dict[str, torch.Tensor]:
         """
         Truncate each session from right to keep `session_max_len` items.
@@ -91,7 +91,7 @@ def _collate_fn_train(
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             x[i, -len(ses) + 1 :] = ses[:-1]  # ses: [session_len] -> x[i]: [session_max_len]
             y[i, -len(ses) + 1 :] = ses[1:]  # ses: [session_len] -> y[i]: [session_max_len]
             yw[i, -len(ses) + 1 :] = ses_weights[1:]  # ses_weights: [session_len] -> yw[i]: [session_max_len]
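For readers new to SASRec's collate layout, a tiny worked example of the shifted next-item targets above, assuming session_max_len = 5 and made-up item ids:

import numpy as np

session_max_len = 5
ses = [3, 7, 2, 9]  # toy item ids, len(ses) == 4
x = np.zeros(session_max_len)
y = np.zeros(session_max_len)
x[-len(ses) + 1 :] = ses[:-1]  # inputs, left-padded:  [0, 0, 3, 7, 2]
y[-len(ses) + 1 :] = ses[1:]   # next-item targets:    [0, 0, 7, 2, 9]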
@@ -103,12 +103,14 @@ def _collate_fn_train(
         )
         return batch_dict

-    def _collate_fn_val(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_val(
+        self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+    ) -> Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
         yw = np.zeros((batch_size, 1))  # Only leave-one-strategy is supported for losses
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]

             # take only first target for leave-one-strategy
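The weight == 0 filter above relies on the preparator's convention (see process_dataset_train in data_preparator.py): validation inputs get their weights zeroed while held-out targets keep real weights. A made-up example:

ses = [3, 7, 2, 9]
ses_weights = [0.0, 0.0, 0.0, 1.0]  # last item is the held-out validation target
input_session = [ses[idx] for idx, weight in enumerate(ses_weights) if weight == 0]
# input_session == [3, 7, 2]; the remaining item becomes the leave-one-out target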
@@ -126,10 +128,12 @@
         )
         return batch_dict

-    def _collate_fn_recommend(self, batch: List[Tuple[List[int], List[float]]]) -> Dict[str, torch.Tensor]:
+    def _collate_fn_recommend(
+        self, batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]]
+    ) -> Dict[str, torch.Tensor]:
         """Right truncation, left padding to session_max_len"""
         x = np.zeros((len(batch), self.session_max_len))
-        for i, (ses, _) in enumerate(batch):
+        for i, (ses, _, _) in enumerate(batch):
             x[i, -len(ses) :] = ses[-self.session_max_len :]
         return {"x": torch.LongTensor(x)}


rectools/models/nn/transformers/similarity.py
Lines changed: 5 additions & 1 deletion

@@ -62,7 +62,11 @@ class DistanceSimilarityModule(SimilarityModuleBase):
     dist_available: tp.List[str] = [Distance.DOT, Distance.COSINE]
     epsilon_cosine_dist: torch.Tensor = torch.tensor([1e-8])

-    def __init__(self, distance: str = "dot") -> None:
+    def __init__(
+        self,
+        distance: str = "dot",
+        **kwargs: tp.Any,
+    ) -> None:
         super().__init__()
         if distance not in self.dist_available:
             raise ValueError("`dist` can only be either `dot` or `cosine`.")
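The widened constructor lets the module absorb unused keyword arguments, which is handy when similarity modules are instantiated from a shared kwargs dict. A sketch; some_unused_kwarg is invented for the example:

from rectools.models.nn.transformers.similarity import DistanceSimilarityModule

# Extra kwargs are now accepted and ignored instead of raising TypeError.
module = DistanceSimilarityModule(distance="cosine", some_unused_kwarg=123)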

tests/models/nn/transformers/test_bert4rec.py
Lines changed: 2 additions & 2 deletions

@@ -640,13 +640,13 @@ def __init__(

     def _collate_fn_train(
         self,
-        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
+        batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
     ) -> tp.Dict[str, torch.Tensor]:
         batch_size = len(batch)
         x = np.zeros((batch_size, self.session_max_len))
         y = np.zeros((batch_size, self.session_max_len))
         yw = np.zeros((batch_size, self.session_max_len))
-        for i, (ses, ses_weights) in enumerate(batch):
+        for i, (ses, ses_weights, _) in enumerate(batch):
             y[i, -self.n_last_targets] = ses[-self.n_last_targets]
             yw[i, -self.n_last_targets] = ses_weights[-self.n_last_targets]
             x[i, -len(ses) :] = ses
