6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- `get_val_mask_func_kwargs` and `get_trainer_func_kwargs` parameters of transformer models for passing additional keyword arguments to `get_val_mask_func` and `get_trainer_func`


## [0.13.0] - 10.04.2025

### Added
17 changes: 11 additions & 6 deletions rectools/models/nn/transformers/base.py
@@ -38,7 +38,7 @@
ItemNetConstructorBase,
SumOfEmbeddingsConstructor,
)
from .data_preparator import TransformerDataPreparatorBase
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .lightning import TransformerLightningModule, TransformerLightningModuleBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
@@ -50,8 +50,6 @@
from .similarity import DistanceSimilarityModule, SimilarityModuleBase
from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone

InitKwargs = tp.Dict[str, tp.Any]

# #### -------------- Transformer Model Config -------------- #### #


@@ -161,7 +159,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
]


ValMaskCallable = Callable[[], np.ndarray]
ValMaskCallable = Callable[..., np.ndarray]

ValMaskCallableSerialized = tpe.Annotated[
ValMaskCallable,
@@ -173,7 +171,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]:
),
]

TrainerCallable = Callable[[], Trainer]
TrainerCallable = Callable[..., Trainer]

TrainerCallableSerialized = tpe.Annotated[
TrainerCallable,
@@ -220,6 +218,8 @@ class TransformerModelConfig(ModelConfig):
backbone_type: TransformerBackboneType = TransformerTorchBackbone
get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None
get_trainer_func: tp.Optional[TrainerCallableSerialized] = None
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None
data_preparator_kwargs: tp.Optional[InitKwargs] = None
transformer_layers_kwargs: tp.Optional[InitKwargs] = None
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None
@@ -280,6 +280,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
data_preparator_kwargs: tp.Optional[InitKwargs] = None,
transformer_layers_kwargs: tp.Optional[InitKwargs] = None,
item_net_constructor_kwargs: tp.Optional[InitKwargs] = None,
@@ -321,6 +323,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
self.backbone_type = backbone_type
self.get_val_mask_func = get_val_mask_func
self.get_trainer_func = get_trainer_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs
self.get_trainer_func_kwargs = get_trainer_func_kwargs
self.data_preparator_kwargs = data_preparator_kwargs
self.transformer_layers_kwargs = transformer_layers_kwargs
self.item_net_constructor_kwargs = item_net_constructor_kwargs
@@ -354,6 +358,7 @@ def _init_data_preparator(self) -> None:
negative_sampler=self._init_negative_sampler() if requires_negatives else None,
n_negatives=self.n_negatives if requires_negatives else None,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
@@ -371,7 +376,7 @@ def _init_trainer(self) -> None:
devices=1,
)
else:
self._trainer = self.get_trainer_func()
self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs))

def _init_negative_sampler(self) -> TransformerNegativeSamplerBase:
return self.negative_sampler_type(
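With the callable types relaxed to `Callable[..., ...]`, `get_trainer_func` can now take its configuration from `get_trainer_func_kwargs`, which `_init_trainer` forwards on the call above. A minimal sketch of such a factory, assuming the `pytorch_lightning.Trainer` already used by this module; `make_trainer` and its parameters are illustrative, not part of the library:

```python
from pytorch_lightning import Trainer


def make_trainer(max_epochs: int = 3, accelerator: str = "cpu") -> Trainer:
    # Keyword arguments here are filled from `get_trainer_func_kwargs`
    return Trainer(max_epochs=max_epochs, accelerator=accelerator, devices=1)


# With get_trainer_func=make_trainer and get_trainer_func_kwargs={"max_epochs": 5},
# _init_trainer above effectively calls:
trainer = make_trainer(max_epochs=5)
```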
19 changes: 17 additions & 2 deletions rectools/models/nn/transformers/bert4rec.py
@@ -27,7 +27,6 @@
SumOfEmbeddingsConstructor,
)
from .base import (
InitKwargs,
TrainerCallable,
TransformerDataPreparatorType,
TransformerLightningModule,
@@ -37,7 +36,7 @@
ValMaskCallable,
)
from .constants import MASKING_VALUE, PADDING_VALUE
from .data_preparator import TransformerDataPreparatorBase
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
@@ -72,6 +71,9 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase):
Negative sampler.
mask_prob : float, default 0.15
Probability of masking an item in interactions sequence.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_val_mask_func`.
Make sure all dict values have JSON serializable types.
"""

train_session_max_len_addition: int = 0
@@ -88,6 +90,7 @@ def __init__(
mask_prob: float = 0.15,
shuffle_train: bool = True,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
super().__init__(
@@ -99,6 +102,7 @@ def __init__(
train_min_user_interactions=train_min_user_interactions,
shuffle_train=shuffle_train,
get_val_mask_func=get_val_mask_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
)
self.mask_prob = mask_prob

@@ -301,6 +305,12 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_torch_device` attribute.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_val_mask_func`.
Make sure all dict values have JSON serializable types.
get_trainer_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_trainer_func`.
Make sure all dict values have JSON serializable types.
data_preparator_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `data_preparator_type` initialization.
Make sure all dict values have JSON serializable types.
@@ -361,6 +371,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
@@ -411,6 +423,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
get_trainer_func_kwargs=get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_block_kwargs=item_net_block_kwargs,
@@ -433,6 +447,7 @@ def _init_data_preparator(self) -> None:
train_min_user_interactions=self.train_min_user_interactions,
mask_prob=self.mask_prob,
get_val_mask_func=self.get_val_mask_func,
get_val_mask_func_kwargs=self.get_val_mask_func_kwargs,
shuffle_train=True,
**self._get_kwargs(self.data_preparator_kwargs),
)
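On the BERT4Rec side, a hedged usage sketch of the new parameters: `last_n_per_user_mask` and its `n_last` argument are hypothetical helpers, and the standard `rectools.Columns` names are assumed for the interactions frame.

```python
import numpy as np
import pandas as pd

from rectools import Columns
from rectools.models.nn.transformers.bert4rec import BERT4RecModel


def last_n_per_user_mask(interactions: pd.DataFrame, n_last: int = 1) -> np.ndarray:
    # Mark each user's `n_last` most recent interactions as validation targets
    ranks = interactions.groupby(Columns.User)[Columns.Datetime].rank(method="first", ascending=False)
    return (ranks <= n_last).to_numpy()


model = BERT4RecModel(
    get_val_mask_func=last_n_per_user_mask,
    get_val_mask_func_kwargs={"n_last": 2},  # forwarded as last_n_per_user_mask(interactions, n_last=2)
)
```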
19 changes: 18 additions & 1 deletion rectools/models/nn/transformers/data_preparator.py
@@ -31,6 +31,8 @@
from .constants import PADDING_VALUE
from .negative_sampler import TransformerNegativeSamplerBase

InitKwargs = tp.Dict[str, tp.Any]


class SequenceDataset(TorchDataset):
"""
@@ -84,6 +86,9 @@ def from_interactions(
return cls(sessions=sessions, weights=weights)


# pylint: disable=too-many-instance-attributes


class TransformerDataPreparatorBase:
"""
Base class for data preparator. To change train/recommend dataset processing, train/recommend dataloaders inherit
@@ -109,6 +114,9 @@ class TransformerDataPreparatorBase:
Number of negatives for BCE, gBCE and sampled_softmax losses.
negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
Negative sampler.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_val_mask_func`.
Make sure all dict values have JSON serializable types.
"""

# We sometimes need data preparators to add +1 to actual session_max_len
@@ -127,6 +135,7 @@ def __init__(
get_val_mask_func: tp.Optional[tp.Callable] = None,
n_negatives: tp.Optional[int] = None,
negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
**kwargs: tp.Any,
) -> None:
self.item_id_map: IdMap
@@ -141,6 +150,7 @@ def __init__(
self.train_min_user_interactions = train_min_user_interactions
self.shuffle_train = shuffle_train
self.get_val_mask_func = get_val_mask_func
self.get_val_mask_func_kwargs = get_val_mask_func_kwargs

def get_known_items_sorted_internal_ids(self) -> np.ndarray:
"""Return internal item ids from processed dataset in sorted order."""
@@ -150,6 +160,13 @@ def get_known_item_ids(self) -> np.ndarray:
"""Return external item ids from processed dataset in sorted order."""
return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :]

@staticmethod
def _get_kwargs(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs:
"""Return the passed kwargs, or an empty dict if ``None`` was given."""
return actual_kwargs if actual_kwargs is not None else {}

@property
def n_item_extra_tokens(self) -> int:
"""Return number of padding elements"""
@@ -194,7 +211,7 @@ def process_dataset_train(self, dataset: Dataset) -> None:
# Exclude val interaction targets from train if needed
interactions = raw_interactions
if self.get_val_mask_func is not None:
val_mask = self.get_val_mask_func(raw_interactions)
val_mask = self.get_val_mask_func(raw_interactions, **self._get_kwargs(self.get_val_mask_func_kwargs))
interactions = raw_interactions[~val_mask]
interactions.reset_index(drop=True, inplace=True)

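To make the forwarding in `process_dataset_train` concrete, here is a standalone sketch of the resulting train/validation split on a toy interactions frame; `by_datetime_threshold` and its `threshold` parameter are hypothetical:

```python
import numpy as np
import pandas as pd


def by_datetime_threshold(interactions: pd.DataFrame, threshold: str) -> np.ndarray:
    # True for interactions to hold out as validation targets
    return (interactions["datetime"] >= pd.Timestamp(threshold)).to_numpy()


raw_interactions = pd.DataFrame(
    {
        "user_id": [1, 1, 2],
        "item_id": [10, 11, 12],
        "weight": [1.0, 1.0, 1.0],
        "datetime": pd.to_datetime(["2025-03-01", "2025-04-01", "2025-04-02"]),
    }
)

# With get_val_mask_func=by_datetime_threshold and get_val_mask_func_kwargs={"threshold": "2025-04-01"},
# the preparator now effectively runs:
val_mask = by_datetime_threshold(raw_interactions, threshold="2025-04-01")
train_interactions = raw_interactions[~val_mask].reset_index(drop=True)
print(train_interactions)  # only the 2025-03-01 interaction is kept for training
```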
16 changes: 14 additions & 2 deletions rectools/models/nn/transformers/sasrec.py
@@ -27,7 +27,6 @@
SumOfEmbeddingsConstructor,
)
from .base import (
InitKwargs,
TrainerCallable,
TransformerDataPreparatorType,
TransformerLayersType,
@@ -37,7 +36,7 @@
TransformerModelConfig,
ValMaskCallable,
)
from .data_preparator import TransformerDataPreparatorBase
from .data_preparator import InitKwargs, TransformerDataPreparatorBase
from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase
from .net_blocks import (
LearnableInversePositionalEncoding,
@@ -72,6 +71,9 @@ class SASRecDataPreparator(TransformerDataPreparatorBase):
Number of negatives for BCE, gBCE and sampled_softmax losses.
negative_sampler: optional(TransformerNegativeSamplerBase), default ``None``
Negative sampler.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_val_mask_func`.
Make sure all dict values have JSON serializable types.
"""

train_session_max_len_addition: int = 1
@@ -379,6 +381,12 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]):
When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise.
If you want to change this parameter after model is initialized,
you can manually assign new value to model `recommend_torch_device` attribute.
get_val_mask_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_val_mask_func`.
Make sure all dict values have JSON serializable types.
get_trainer_func_kwargs: optional(InitKwargs), default ``None``
Additional keyword arguments to pass to `get_trainer_func`.
Make sure all dict values have JSON serializable types.
data_preparator_kwargs: optional(dict), default ``None``
Additional keyword arguments to pass during `data_preparator_type` initialization.
Make sure all dict values have JSON serializable types.
@@ -438,6 +446,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone,
get_val_mask_func: tp.Optional[ValMaskCallable] = None,
get_trainer_func: tp.Optional[TrainerCallable] = None,
get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None,
get_trainer_func_kwargs: tp.Optional[InitKwargs] = None,
recommend_batch_size: int = 256,
recommend_torch_device: tp.Optional[str] = None,
recommend_use_torch_ranking: bool = True,
@@ -485,6 +495,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals
backbone_type=backbone_type,
get_val_mask_func=get_val_mask_func,
get_trainer_func=get_trainer_func,
get_val_mask_func_kwargs=get_val_mask_func_kwargs,
get_trainer_func_kwargs=get_trainer_func_kwargs,
data_preparator_kwargs=data_preparator_kwargs,
transformer_layers_kwargs=transformer_layers_kwargs,
item_net_constructor_kwargs=item_net_constructor_kwargs,
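The SASRec additions mirror the BERT4Rec ones. One point worth illustrating is why the docstrings insist on JSON-serializable values: both kwargs dicts are ordinary init parameters and travel with the model config. A sketch under the assumption that the standard `get_config` entry point is used; `cpu_trainer` is a hypothetical helper:

```python
from pytorch_lightning import Trainer

from rectools.models.nn.transformers.sasrec import SASRecModel


def cpu_trainer(max_epochs: int = 3) -> Trainer:
    # Filled from `get_trainer_func_kwargs` when the model builds its trainer
    return Trainer(max_epochs=max_epochs, accelerator="cpu", devices=1)


model = SASRecModel(
    get_trainer_func=cpu_trainer,
    get_trainer_func_kwargs={"max_epochs": 5},  # keep values JSON serializable
)

# Both *_kwargs dicts are ordinary init parameters, so they are stored in the model config.
config = model.get_config()
```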