Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions rectools/models/nn/transformers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
]


ValMaskCallable = Callable[[], np.ndarray]
ValMaskCallable = Callable[..., np.ndarray]

ValMaskCallableSerialized = tpe.Annotated[
ValMaskCallable,
Expand All @@ -173,7 +173,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

TrainerCallable = Callable[[], Trainer]
TrainerCallable = Callable[..., Trainer]

TrainerCallableSerialized = tpe.Annotated[
TrainerCallable,
Expand Down Expand Up @@ -220,6 +220,8 @@ class TransformerModelConfig(ModelConfig):
backbone_type: TransformerBackboneType = TransformerTorchBackbone
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None
data_preparator_kwargs: tp.Optional[InitKwargs] = None
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
Expand Down Expand Up @@ -280,6 +282,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
Expand Down Expand Up @@ -321,6 +325,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.backbone_type = backbone_type
self.get_val_mask_func = get_val_mask_func
self.get_trainer_func = get_trainer_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
self.get_trainer_func_kwargs = get_trainer_func_kwargs
self.data_preparator_kwargs = data_preparator_kwargs
self.transformer_layers_kwargs = transformer_layers_kwargs
self.item_net_constructor_kwargs = item_net_constructor_kwargs
Expand Down Expand Up @@ -354,6 +360,7 @@ def _init_data_preparator(self) -> None:
negative_sampler=self._init_negative_sampler() if requires_negatives else None,
n_negatives=self.n_negatives if requires_negatives else None,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
Expand All @@ -371,7 +378,7 @@ def _init_trainer(self) -> None:
devices=1,
)
else:
self._trainer = self.get_trainer_func()
self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs))

def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
return self.negative_sampler_type(
Expand Down
7 changes: 7 additions & 0 deletions rectools/models/nn/transformers/bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
mask_prob: float = 0.15,
shuffle_train: bool = True,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(
Expand All @@ -99,6 +100,7 @@ def __init__(
train_min_user_interactions=train_min_user_interactions,
shuffle_train=shuffle_train,
get_val_mask_func=get_val_mask_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
)
self.mask_prob = mask_prob

Expand Down Expand Up @@ -361,6 +363,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
Expand Down Expand Up @@ -411,6 +415,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
get_trainer_func_kwargs=get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_block_kwargs=item_net_block_kwargs,
Expand All @@ -433,6 +439,7 @@ def _init_data_preparator(self) -> None:
train_min_user_interactions=self.train_min_user_interactions,
mask_prob=self.mask_prob,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
14 changes: 11 additions & 3 deletions rectools/models/nn/transformers/data_preparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
from rectools.dataset import Dataset, Interactions
from rectools.dataset.features import DenseFeatures, Features, SparseFeatures
from rectools.dataset.identifiers import IdMap

from .constants import PADDING_VALUE
from .negative_sampler import TransformerNegativeSamplerBase

InitKwargs = tp.Dict[str, tp.Any]

class SequenceDataset(TorchDataset):
"""
Expand Down Expand Up @@ -127,6 +127,7 @@ def __init__(
get_val_mask_func: tp.Optional[tp.Callable] = None,
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
Expand All @@ -141,7 +142,7 @@ def __init__(
self.train_min_user_interactions = train_min_user_interactions
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func

self.get_val_mask_func_kwargs = get_val_mask_func_kwargs

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
    """Return internal item ids from the processed dataset, in sorted order.

    The first ``n_item_extra_tokens`` internal ids belong to extra
    (padding) tokens and are excluded from the result.
    """
    sorted_internal = self.item_id_map.get_sorted_internal()
    return sorted_internal[self.n_item_extra_tokens :]
Expand All @@ -150,6 +151,13 @@ def get_known_item_ids(self) -> np.ndarray:
"""Return external item ids from processed dataset in sorted order."""
return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :]

@staticmethod
def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
kwargs = {}
if actual_kwargs is not None:
kwargs = actual_kwargs
return kwargs

@property
def n_item_extra_tokens(self) -> int:
"""Return number of padding elements"""
Expand Down Expand Up @@ -194,7 +202,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
# Exclude val interaction targets from train if needed
interactions = raw_interactions
if self.get_val_mask_func is not None:
val_mask = self.get_val_mask_func(raw_interactions)
val_mask = self.get_val_mask_func(raw_interactions, **self._get_kwargs(self.get_val_mask_func_kwargs))
interactions = raw_interactions[~val_mask]
interactions.reset_index(drop=True, inplace=True)

Expand Down
4 changes: 4 additions & 0 deletions rectools/models/nn/transformers/sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
Expand Down Expand Up @@ -485,6 +487,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs = get_val_mask_func_kwargs,
get_trainer_func_kwargs = get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_constructor_kwargs=item_net_constructor_kwargs,
Expand Down
2 changes: 2 additions & 0 deletions tests/models/nn/transformers/test_bert4rec.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,8 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
"mask_prob": 0.15,
"get_val_mask_func": leave_one_out_mask,
"get_trainer_func": None,
"get_val_mask_func_kwargs": None,
"get_trainer_func_kwargs": None,
"data_preparator_kwargs": None,
"transformer_layers_kwargs": None,
"item_net_constructor_kwargs": None,
Expand Down
2 changes: 2 additions & 0 deletions tests/models/nn/transformers/test_sasrec.py
Original file line number Diff line number Diff line change
Expand Up @@ -968,6 +968,8 @@ def initial_config(self) -> tp.Dict[str, tp.Any]:
"backbone_type": TransformerTorchBackbone,
"get_val_mask_func": leave_one_out_mask,
"get_trainer_func": None,
"get_val_mask_func_kwargs": None,
"get_trainer_func_kwargs": None,
"data_preparator_kwargs": None,
"transformer_layers_kwargs": None,
"item_net_constructor_kwargs": None,
Expand Down
Loading