@@ -46,23 +46,33 @@ class SequenceDataset(TorchDataset):
4646 Weight of each interaction from the session.
4747 """
4848
49- def __init__(self, sessions: tp.List[tp.List[int]], weights: tp.List[tp.List[float]]):
49+ def __init__(
50+ self,
51+ sessions: tp.List[tp.List[int]],
52+ weights: tp.List[tp.List[float]],
53+ extras: tp.Optional[tp.Dict[str, tp.List[tp.Any]]] = None,
54+ ):
5055 self.sessions = sessions
5156 self.weights = weights
57+ self.extras = extras
5258
5359 def __len__ (self ) -> int :
5460 return len (self .sessions )
5561
56- def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float]]:
62+ def __getitem__(self, index: int) -> tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]:
5763 session = self.sessions[index]  # [session_len]
5864 weights = self.weights[index]  # [session_len]
59- return session, weights
65+ extras = (
66+ {feature_name: features[index] for feature_name, features in self.extras.items()} if self.extras else {}
67+ )
68+ return session, weights, extras
6069
6170 @classmethod
6271 def from_interactions (
6372 cls ,
6473 interactions : pd .DataFrame ,
6574 sort_users : bool = False ,
75+ extra_cols : tp .Optional [tp .List [str ]] = None ,
6676 ) -> "SequenceDataset" :
6777 """
6878 Group interactions by user.
@@ -73,17 +83,18 @@ def from_interactions(
7383 interactions : pd.DataFrame
7484 User-item interactions.
7585 """
86+ cols_to_agg = [col for col in interactions.columns if col != Columns.User]
7687 sessions = (
7788 interactions.sort_values(Columns.Datetime, kind="stable")
78- .groupby(Columns.User, sort=sort_users)[[Columns.Item, Columns.Weight]]
89+ .groupby(Columns.User, sort=sort_users)[cols_to_agg]
7990 .agg(list)
8091 )
81- sessions, weights = (
92+ sessions_items, weights = (
8293 sessions[Columns.Item].to_list(),
8394 sessions[Columns.Weight].to_list(),
8495 )
85-
86- return cls(sessions=sessions, weights=weights)
96+ extras = {col: sessions[col].to_list() for col in extra_cols} if extra_cols else None
97+ return cls(sessions=sessions_items, weights=weights, extras=extras)
8798
8899
89100class TransformerDataPreparatorBase : # pylint: disable=too-many-instance-attributes
@@ -133,6 +144,7 @@ def __init__(
133144 n_negatives : tp .Optional [int ] = None ,
134145 negative_sampler : tp .Optional [TransformerNegativeSamplerBase ] = None ,
135146 get_val_mask_func_kwargs : tp .Optional [InitKwargs ] = None ,
147+ extra_cols_kwargs : tp .Optional [InitKwargs ] = None ,
136148 ** kwargs : tp .Any ,
137149 ) -> None :
138150 self .item_id_map : IdMap
@@ -148,6 +160,7 @@ def __init__(
148160 self .shuffle_train = shuffle_train
149161 self .get_val_mask_func = get_val_mask_func
150162 self .get_val_mask_func_kwargs = get_val_mask_func_kwargs
163+ self .extra_cols_kwargs = extra_cols_kwargs
151164
152165 def get_known_items_sorted_internal_ids (self ) -> np .ndarray :
153166 """Return internal item ids from processed dataset in sorted order."""
@@ -231,7 +244,10 @@ def process_dataset_train(self, dataset: Dataset) -> None:
231244
232245 # Prepare train dataset
233246 # User features are dropped for now because model doesn't support them
234- final_interactions = Interactions.from_raw(interactions, user_id_map, item_id_map, keep_extra_cols=True)
247+ keep_extra_cols = self.extra_cols_kwargs is not None
248+ final_interactions = Interactions.from_raw(
249+ interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
250+ )
235251 self .train_dataset = Dataset (user_id_map , item_id_map , final_interactions , item_features = item_features )
236252 self .item_id_map = self .train_dataset .item_id_map
237253 self ._init_extra_token_ids ()
@@ -246,7 +262,9 @@ def process_dataset_train(self, dataset: Dataset) -> None:
246262 val_interactions = interactions [interactions [Columns .User ].isin (val_targets [Columns .User ].unique ())].copy ()
247263 val_interactions [Columns .Weight ] = 0
248264 val_interactions = pd .concat ([val_interactions , val_targets ], axis = 0 )
249- self.val_interactions = Interactions.from_raw(val_interactions, user_id_map, item_id_map).df
265+ self.val_interactions = Interactions.from_raw(
266+ val_interactions, user_id_map, item_id_map, keep_extra_cols=keep_extra_cols
267+ ).df
250268
251269 def _init_extra_token_ids (self ) -> None :
252270 extra_token_ids = self .item_id_map .convert_to_internal (self .item_extra_tokens )
@@ -261,7 +279,9 @@ def get_dataloader_train(self) -> DataLoader:
261279 DataLoader
262280 Train dataloader.
263281 """
264- sequence_dataset = SequenceDataset.from_interactions(self.train_dataset.interactions.df)
282+ sequence_dataset = SequenceDataset.from_interactions(
283+ self.train_dataset.interactions.df, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
284+ )
265285 train_dataloader = DataLoader (
266286 sequence_dataset ,
267287 collate_fn = self ._collate_fn_train ,
@@ -283,7 +303,9 @@ def get_dataloader_val(self) -> tp.Optional[DataLoader]:
283303 if self .val_interactions is None :
284304 return None
285305
286- sequence_dataset = SequenceDataset.from_interactions(self.val_interactions)
306+ sequence_dataset = SequenceDataset.from_interactions(
307+ self.val_interactions, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
308+ )
287309 val_dataloader = DataLoader (
288310 sequence_dataset ,
289311 collate_fn = self ._collate_fn_val ,
@@ -306,7 +328,9 @@ def get_dataloader_recommend(self, dataset: Dataset, batch_size: int) -> DataLoa
306328 # User ids here are internal user ids in dataset.interactions.df that was prepared for recommendations.
307329 # Sorting sessions by user ids will ensure that these ids will also be correct indexes in user embeddings matrix
308330 # that will be returned by the net.
309- sequence_dataset = SequenceDataset.from_interactions(interactions=dataset.interactions.df, sort_users=True)
331+ sequence_dataset = SequenceDataset.from_interactions(
332+ interactions=dataset.interactions.df, sort_users=True, **self._ensure_kwargs_dict(self.extra_cols_kwargs)
333+ )
310334 recommend_dataloader = DataLoader (
311335 sequence_dataset ,
312336 batch_size = batch_size ,
@@ -359,7 +383,10 @@ def transform_dataset_u2i(self, dataset: Dataset, users: ExternalIds) -> Dataset
359383 if n_filtered > 0 :
360384 explanation = f"""{ n_filtered } target users were considered cold because of missing known items"""
361385 warnings .warn (explanation )
362- filtered_interactions = Interactions.from_raw(interactions, rec_user_id_map, self.item_id_map)
386+ keep_extra_cols = self.extra_cols_kwargs is not None
387+ filtered_interactions = Interactions.from_raw(
388+ interactions, rec_user_id_map, self.item_id_map, keep_extra_cols=keep_extra_cols
389+ )
363390 filtered_dataset = Dataset (rec_user_id_map , self .item_id_map , filtered_interactions )
364391 return filtered_dataset
365392
@@ -383,24 +410,27 @@ def transform_dataset_i2i(self, dataset: Dataset) -> Dataset:
383410 """
384411 interactions = dataset .get_raw_interactions ()
385412 interactions = interactions[interactions[Columns.Item].isin(self.get_known_item_ids())]
386- filtered_interactions = Interactions.from_raw(interactions, dataset.user_id_map, self.item_id_map)
413+ keep_extra_cols = self.extra_cols_kwargs is not None
414+ filtered_interactions = Interactions.from_raw(
415+ interactions, dataset.user_id_map, self.item_id_map, keep_extra_cols
416+ )
387417 filtered_dataset = Dataset (dataset .user_id_map , self .item_id_map , filtered_interactions )
388418 return filtered_dataset
389419
390420 def _collate_fn_train (
391421 self ,
392- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
422+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
393423 ) -> tp .Dict [str , torch .Tensor ]:
394424 raise NotImplementedError ()
395425
396426 def _collate_fn_val (
397427 self ,
398- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
428+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
399429 ) -> tp .Dict [str , torch .Tensor ]:
400430 raise NotImplementedError ()
401431
402432 def _collate_fn_recommend (
403433 self ,
404- batch: tp.List[tp.Tuple[tp.List[int], tp.List[float]]],
434+ batch: tp.List[tp.Tuple[tp.List[int], tp.List[float], tp.Dict[str, tp.List[tp.Any]]]],
405435 ) -> tp .Dict [str , torch .Tensor ]:
406436 raise NotImplementedError ()
0 commit comments