diff --git a/CHANGELOG.md b/CHANGELOG.md index 736244e5..e0cc54fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227)) - `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273)) - `map_location` and `model_params_update` arguments for the function `load_from_checkpoint` for Transformer-based models. Use `map_location` to explicitly specify the computing new device and `model_params_update` to update original model parameters (e.g. remove training-specific parameters that are not needed anymore) ([#281](https://github.com/MobileTeleSystems/RecTools/pull/281)) +- `get_val_mask_func_kwargs` and `get_trainer_func_kwargs` arguments for Transformer-based models to allow keyword arguments in custom functions used for model training ([#280](https://github.com/MobileTeleSystems/RecTools/pull/280)) ## [0.13.0] - 10.04.2025 diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py index c2699978..14982cb9 100644 --- a/rectools/models/nn/transformers/base.py +++ b/rectools/models/nn/transformers/base.py @@ -38,7 +38,7 @@ ItemNetConstructorBase, SumOfEmbeddingsConstructor, ) -from .data_preparator import TransformerDataPreparatorBase +from .data_preparator import InitKwargs, TransformerDataPreparatorBase from .lightning import TransformerLightningModule, TransformerLightningModuleBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( @@ -50,8 +50,6 @@ from .similarity import DistanceSimilarityModule, SimilarityModuleBase from .torch_backbone import TransformerBackboneBase, TransformerTorchBackbone -InitKwargs = tp.Dict[str, tp.Any] - # #### -------------- Transformer Model Config -------------- #### # @@ -161,7 +159,7 @@ def 
_serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: ] -ValMaskCallable = Callable[[], np.ndarray] +ValMaskCallable = Callable[..., np.ndarray] ValMaskCallableSerialized = tpe.Annotated[ ValMaskCallable, @@ -173,7 +171,7 @@ def _serialize_type_sequence(obj: tp.Sequence[tp.Type]) -> tp.Tuple[str, ...]: ), ] -TrainerCallable = Callable[[], Trainer] +TrainerCallable = Callable[..., Trainer] TrainerCallableSerialized = tpe.Annotated[ TrainerCallable, @@ -220,6 +218,8 @@ class TransformerModelConfig(ModelConfig): backbone_type: TransformerBackboneType = TransformerTorchBackbone get_val_mask_func: tp.Optional[ValMaskCallableSerialized] = None get_trainer_func: tp.Optional[TrainerCallableSerialized] = None + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None + get_trainer_func_kwargs: tp.Optional[InitKwargs] = None data_preparator_kwargs: tp.Optional[InitKwargs] = None transformer_layers_kwargs: tp.Optional[InitKwargs] = None item_net_constructor_kwargs: tp.Optional[InitKwargs] = None @@ -280,6 +280,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, + get_trainer_func_kwargs: tp.Optional[InitKwargs] = None, data_preparator_kwargs: tp.Optional[InitKwargs] = None, transformer_layers_kwargs: tp.Optional[InitKwargs] = None, item_net_constructor_kwargs: tp.Optional[InitKwargs] = None, @@ -321,6 +323,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals self.backbone_type = backbone_type self.get_val_mask_func = get_val_mask_func self.get_trainer_func = get_trainer_func + self.get_val_mask_func_kwargs = get_val_mask_func_kwargs + self.get_trainer_func_kwargs = get_trainer_func_kwargs self.data_preparator_kwargs = data_preparator_kwargs self.transformer_layers_kwargs = 
transformer_layers_kwargs self.item_net_constructor_kwargs = item_net_constructor_kwargs @@ -354,6 +358,7 @@ def _init_data_preparator(self) -> None: negative_sampler=self._init_negative_sampler() if requires_negatives else None, n_negatives=self.n_negatives if requires_negatives else None, get_val_mask_func=self.get_val_mask_func, + get_val_mask_func_kwargs=self.get_val_mask_func_kwargs, **self._get_kwargs(self.data_preparator_kwargs), ) @@ -370,7 +375,7 @@ def _init_trainer(self) -> None: devices=1, ) else: - self._trainer = self.get_trainer_func() + self._trainer = self.get_trainer_func(**self._get_kwargs(self.get_trainer_func_kwargs)) def _init_negative_sampler(self) -> TransformerNegativeSamplerBase: return self.negative_sampler_type( diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py index 23ca8a9a..8e31d6ff 100644 --- a/rectools/models/nn/transformers/bert4rec.py +++ b/rectools/models/nn/transformers/bert4rec.py @@ -27,7 +27,6 @@ SumOfEmbeddingsConstructor, ) from .base import ( - InitKwargs, TrainerCallable, TransformerDataPreparatorType, TransformerLightningModule, @@ -37,7 +36,7 @@ ValMaskCallable, ) from .constants import MASKING_VALUE, PADDING_VALUE -from .data_preparator import TransformerDataPreparatorBase +from .data_preparator import InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -72,6 +71,9 @@ class BERT4RecDataPreparator(TransformerDataPreparatorBase): Negative sampler. mask_prob : float, default 0.15 Probability of masking an item in interactions sequence. + get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_val_mask_func. + Make sure all dict values have JSON serializable types. 
""" train_session_max_len_addition: int = 0 @@ -88,6 +90,7 @@ def __init__( mask_prob: float = 0.15, get_val_mask_func: tp.Optional[ValMaskCallable] = None, shuffle_train: bool = True, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, **kwargs: tp.Any, ) -> None: super().__init__( @@ -99,6 +102,7 @@ def __init__( train_min_user_interactions=train_min_user_interactions, shuffle_train=shuffle_train, get_val_mask_func=get_val_mask_func, + get_val_mask_func_kwargs=get_val_mask_func_kwargs, ) self.mask_prob = mask_prob @@ -301,6 +305,12 @@ class BERT4RecModel(TransformerModelBase[BERT4RecModelConfig]): When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise. If you want to change this parameter after model is initialized, you can manually assign new value to model `recommend_torch_device` attribute. + get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_val_mask_func. + Make sure all dict values have JSON serializable types. + get_trainer_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_trainer_func. + Make sure all dict values have JSON serializable types. data_preparator_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `data_preparator_type` initialization. Make sure all dict values have JSON serializable types. 
@@ -361,6 +371,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, + get_trainer_func_kwargs: tp.Optional[InitKwargs] = None, recommend_batch_size: int = 256, recommend_torch_device: tp.Optional[str] = None, recommend_use_torch_ranking: bool = True, @@ -411,6 +423,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, + get_val_mask_func_kwargs=get_val_mask_func_kwargs, + get_trainer_func_kwargs=get_trainer_func_kwargs, data_preparator_kwargs=data_preparator_kwargs, transformer_layers_kwargs=transformer_layers_kwargs, item_net_block_kwargs=item_net_block_kwargs, @@ -433,5 +447,6 @@ def _init_data_preparator(self) -> None: train_min_user_interactions=self.train_min_user_interactions, mask_prob=self.mask_prob, get_val_mask_func=self.get_val_mask_func, + get_val_mask_func_kwargs=self.get_val_mask_func_kwargs, **self._get_kwargs(self.data_preparator_kwargs), ) diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py index 8787e950..b13ec87d 100644 --- a/rectools/models/nn/transformers/data_preparator.py +++ b/rectools/models/nn/transformers/data_preparator.py @@ -31,6 +31,8 @@ from .constants import PADDING_VALUE from .negative_sampler import TransformerNegativeSamplerBase +InitKwargs = tp.Dict[str, tp.Any] + class SequenceDataset(TorchDataset): """ @@ -84,7 +86,7 @@ def from_interactions( return cls(sessions=sessions, weights=weights) -class TransformerDataPreparatorBase: +class TransformerDataPreparatorBase: # pylint: disable=too-many-instance-attributes """ Base class for data preparator. 
To change train/recommend dataset processing, train/recommend dataloaders inherit from this class and pass your custom data preparator to your model parameters. @@ -109,6 +111,9 @@ class TransformerDataPreparatorBase: Number of negatives for BCE, gBCE and sampled_softmax losses. negative_sampler: optional(TransformerNegativeSamplerBase), default ``None`` Negative sampler. + get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_val_mask_func. + Make sure all dict values have JSON serializable types. """ # We sometimes need data preparators to add +1 to actual session_max_len @@ -127,6 +132,7 @@ def __init__( shuffle_train: bool = True, n_negatives: tp.Optional[int] = None, negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, **kwargs: tp.Any, ) -> None: self.item_id_map: IdMap @@ -141,6 +147,7 @@ def __init__( self.train_min_user_interactions = train_min_user_interactions self.shuffle_train = shuffle_train self.get_val_mask_func = get_val_mask_func + self.get_val_mask_func_kwargs = get_val_mask_func_kwargs def get_known_items_sorted_internal_ids(self) -> np.ndarray: """Return internal item ids from processed dataset in sorted order.""" @@ -150,6 +157,13 @@ def get_known_item_ids(self) -> np.ndarray: """Return external item ids from processed dataset in sorted order.""" return self.item_id_map.get_external_sorted_by_internal()[self.n_item_extra_tokens :] + @staticmethod + def _ensure_kwargs_dict(actual_kwargs: tp.Optional[InitKwargs]) -> InitKwargs: + kwargs = {} + if actual_kwargs is not None: + kwargs = actual_kwargs + return kwargs + @property def n_item_extra_tokens(self) -> int: """Return number of padding elements""" @@ -194,7 +208,9 @@ def process_dataset_train(self, dataset: Dataset) -> None: # Exclude val interaction targets from train if needed interactions = raw_interactions if self.get_val_mask_func is not None: - val_mask = 
self.get_val_mask_func(raw_interactions) + val_mask = self.get_val_mask_func( + raw_interactions, **self._ensure_kwargs_dict(self.get_val_mask_func_kwargs) + ) interactions = raw_interactions[~val_mask] interactions.reset_index(drop=True, inplace=True) diff --git a/rectools/models/nn/transformers/sasrec.py b/rectools/models/nn/transformers/sasrec.py index eebd9e73..b8350f72 100644 --- a/rectools/models/nn/transformers/sasrec.py +++ b/rectools/models/nn/transformers/sasrec.py @@ -27,7 +27,6 @@ SumOfEmbeddingsConstructor, ) from .base import ( - InitKwargs, TrainerCallable, TransformerDataPreparatorType, TransformerLayersType, @@ -37,7 +36,7 @@ TransformerModelConfig, ValMaskCallable, ) -from .data_preparator import TransformerDataPreparatorBase +from .data_preparator import InitKwargs, TransformerDataPreparatorBase from .negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from .net_blocks import ( LearnableInversePositionalEncoding, @@ -72,6 +71,9 @@ class SASRecDataPreparator(TransformerDataPreparatorBase): Number of negatives for BCE, gBCE and sampled_softmax losses. negative_sampler: optional(TransformerNegativeSamplerBase), default ``None`` Negative sampler. + get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_val_mask_func. + Make sure all dict values have JSON serializable types. """ train_session_max_len_addition: int = 1 @@ -379,6 +381,12 @@ class SASRecModel(TransformerModelBase[SASRecModelConfig]): When set to ``None``, "cuda" will be used if it is available, "cpu" otherwise. If you want to change this parameter after model is initialized, you can manually assign new value to model `recommend_torch_device` attribute. + get_val_mask_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_val_mask_func. + Make sure all dict values have JSON serializable types. 
+ get_trainer_func_kwargs: optional(InitKwargs), default ``None`` + Additional keyword arguments for the get_trainer_func. + Make sure all dict values have JSON serializable types. data_preparator_kwargs: optional(dict), default ``None`` Additional keyword arguments to pass during `data_preparator_type` initialization. Make sure all dict values have JSON serializable types. @@ -438,6 +446,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals backbone_type: tp.Type[TransformerBackboneBase] = TransformerTorchBackbone, get_val_mask_func: tp.Optional[ValMaskCallable] = None, get_trainer_func: tp.Optional[TrainerCallable] = None, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, + get_trainer_func_kwargs: tp.Optional[InitKwargs] = None, recommend_batch_size: int = 256, recommend_torch_device: tp.Optional[str] = None, recommend_use_torch_ranking: bool = True, @@ -485,6 +495,8 @@ def __init__( # pylint: disable=too-many-arguments, too-many-locals backbone_type=backbone_type, get_val_mask_func=get_val_mask_func, get_trainer_func=get_trainer_func, + get_val_mask_func_kwargs=get_val_mask_func_kwargs, + get_trainer_func_kwargs=get_trainer_func_kwargs, data_preparator_kwargs=data_preparator_kwargs, transformer_layers_kwargs=transformer_layers_kwargs, item_net_constructor_kwargs=item_net_constructor_kwargs, diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 6f0a0c34..140389aa 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+# pylint: disable=too-many-lines import typing as tp from functools import partial @@ -33,6 +34,7 @@ TransformerLightningModule, ) from rectools.models.nn.transformers.bert4rec import MASKING_VALUE, BERT4RecDataPreparator, ValMaskCallable +from rectools.models.nn.transformers.data_preparator import InitKwargs from rectools.models.nn.transformers.negative_sampler import CatalogUniformSampler, TransformerNegativeSamplerBase from rectools.models.nn.transformers.similarity import DistanceSimilarityModule from rectools.models.nn.transformers.torch_backbone import TransformerTorchBackbone @@ -114,6 +116,34 @@ def get_trainer() -> Trainer: return get_trainer + @pytest.fixture + def get_custom_trainer_func(self) -> TrainerCallable: + def get_trainer_func(max_epochs: int, accelerator: str) -> Trainer: + return Trainer( + max_epochs=max_epochs, + min_epochs=2, + deterministic=True, + accelerator=accelerator, + enable_checkpointing=False, + devices=1, + ) + + return get_trainer_func + + @pytest.fixture + def get_custom_val_mask_func(self) -> ValMaskCallable: + def get_val_mask_func(interactions: pd.DataFrame, val_users: tp.List[int]) -> np.ndarray: + rank = ( + interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") + .groupby(Columns.User, sort=False) + .cumcount() + + 1 + ) + val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1) + return val_mask.values + + return get_val_mask_func + @pytest.mark.parametrize( "accelerator,n_devices,recommend_torch_device", [ @@ -549,7 +579,36 @@ def test_recommend_for_cold_user_with_hot_item( actual, ) - def test_customized_happy_path(self, dataset_devices: Dataset, get_trainer_func: TrainerCallable) -> None: + @pytest.mark.parametrize( + "get_custom_trainer_func_kwargs, get_custom_val_mask_func_kwargs", + ( + pytest.param( + { + "max_epochs": 2, + "accelerator": "cpu", + }, + {"val_users": [30, 40]}, + id="cpu_config", + ), + pytest.param( + { + "max_epochs": 3, + "accelerator": "gpu", + }, + 
{"val_users": [20, 30]}, + marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available"), + id="gpu_config", + ), + ), + ) + def test_customized_happy_path( + self, + dataset_devices: Dataset, + get_custom_trainer_func: TrainerCallable, + get_custom_val_mask_func: ValMaskCallable, + get_custom_trainer_func_kwargs: InitKwargs, + get_custom_val_mask_func_kwargs: InitKwargs, + ) -> None: class NextActionDataPreparator(BERT4RecDataPreparator): def __init__( self, @@ -562,6 +621,7 @@ def __init__( negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None, shuffle_train: bool = True, get_val_mask_func: tp.Optional[ValMaskCallable] = None, + get_val_mask_func_kwargs: tp.Optional[InitKwargs] = None, n_last_targets: int = 1, # custom kwarg ) -> None: super().__init__( @@ -572,7 +632,8 @@ def __init__( train_min_user_interactions=train_min_user_interactions, negative_sampler=negative_sampler, shuffle_train=shuffle_train, - get_val_mask_func=get_val_mask_func, + get_val_mask_func=get_custom_val_mask_func, + get_val_mask_func_kwargs=get_custom_val_mask_func_kwargs, mask_prob=mask_prob, ) self.n_last_targets = n_last_targets @@ -607,7 +668,8 @@ def _collate_fn_train( epochs=2, deterministic=True, item_net_block_types=(IdEmbeddingsItemNet,), - get_trainer_func=get_trainer_func, + get_trainer_func=get_custom_trainer_func, + get_trainer_func_kwargs=get_custom_trainer_func_kwargs, data_preparator_type=NextActionDataPreparator, data_preparator_kwargs={"n_last_targets": 1}, similarity_module_type=DistanceSimilarityModule, @@ -815,6 +877,63 @@ def test_get_dataloader_val( for key, value in actual.items(): assert torch.equal(value, val_batch[key]) + @pytest.mark.parametrize( + "val_batch, val_users", + ( + ( + { + "x": torch.tensor([[0, 2, 4, 1]]), + "y": torch.tensor([[3]]), + "yw": torch.tensor([[1.0]]), + "negatives": torch.tensor([[[5, 2]]]), + }, + [10, 30], + ), + ( + { + "x": torch.tensor([[0, 2, 4, 1]]), + "y": torch.tensor([[3]]), + "yw": 
torch.tensor([[1.0]]), + "negatives": torch.tensor([[[5, 2]]]), + }, + [30], + ), + ), + ) + def test_get_dataloader_val_with_kwargs( + self, + dataset: Dataset, + val_batch: tp.Dict[tp.Any, tp.Any], + val_users: tp.List, + ) -> None: + + def get_custom_val_mask_func(interactions: pd.DataFrame, val_users: tp.List[int]) -> np.ndarray: + rank = ( + interactions.sort_values(Columns.Datetime, ascending=False, kind="stable") + .groupby(Columns.User, sort=False) + .cumcount() + + 1 + ) + val_mask = (interactions[Columns.User].isin(val_users)) & (rank <= 1) + return val_mask.values + + get_custom_val_mask_func_kwargs = {"val_users": val_users} + data_preparator_val_mask = BERT4RecDataPreparator( + session_max_len=4, + n_negatives=2, + train_min_user_interactions=2, + mask_prob=0.5, + batch_size=4, + dataloader_num_workers=0, + get_val_mask_func=get_custom_val_mask_func, + get_val_mask_func_kwargs=get_custom_val_mask_func_kwargs, + ) + data_preparator_val_mask.process_dataset_train(dataset) + dataloader = data_preparator_val_mask.get_dataloader_val() + actual = next(iter(dataloader)) # type: ignore + for key, value in actual.items(): + assert torch.equal(value, val_batch[key]) + class TestBERT4RecModelConfiguration: def setup_method(self) -> None: @@ -859,6 +978,8 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "mask_prob": 0.15, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": None, + "get_val_mask_func_kwargs": None, + "get_trainer_func_kwargs": None, "data_preparator_kwargs": None, "transformer_layers_kwargs": None, "item_net_constructor_kwargs": None, diff --git a/tests/models/nn/transformers/test_sasrec.py b/tests/models/nn/transformers/test_sasrec.py index 24438cc4..d1bd5911 100644 --- a/tests/models/nn/transformers/test_sasrec.py +++ b/tests/models/nn/transformers/test_sasrec.py @@ -968,6 +968,8 @@ def initial_config(self) -> tp.Dict[str, tp.Any]: "backbone_type": TransformerTorchBackbone, "get_val_mask_func": leave_one_out_mask, "get_trainer_func": 
None, + "get_val_mask_func_kwargs": None, + "get_trainer_func_kwargs": None, "data_preparator_kwargs": None, "transformer_layers_kwargs": None, "item_net_constructor_kwargs": None,