diff --git a/CHANGELOG.md b/CHANGELOG.md
index e82e59a0..1c1e525b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,8 +8,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 
 ### Added
-
 - Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
+- `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273))
 
 ## [0.13.0] - 10.04.2025
 
diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 75029699..735bf4b6 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -354,7 +354,6 @@ def _init_data_preparator(self) -> None:
             negative_sampler=self._init_negative_sampler() if requires_negatives else None,
             n_negatives=self.n_negatives if requires_negatives else None,
             get_val_mask_func=self.get_val_mask_func,
-            shuffle_train=True,
             **self._get_kwargs(self.data_preparator_kwargs),
         )
 
@@ -457,14 +456,8 @@ def _init_lightning_model(
             **self._get_kwargs(self.lightning_module_kwargs),
         )
 
-    def _fit(
-        self,
-        dataset: Dataset,
-    ) -> None:
+    def _build_model_from_dataset(self, dataset: Dataset) -> None:
         self.data_preparator.process_dataset_train(dataset)
-        train_dataloader = self.data_preparator.get_dataloader_train()
-        val_dataloader = self.data_preparator.get_dataloader_val()
-
         item_model = self._construct_item_net(self.data_preparator.train_dataset)
         torch_model = self._init_torch_model(item_model)
 
@@ -478,6 +471,13 @@ def _fit(
             model_config=model_config,
         )
 
+    def _fit(
+        self,
+        dataset: Dataset,
+    ) -> None:
+        self._build_model_from_dataset(dataset)
+        train_dataloader = self.data_preparator.get_dataloader_train()
+        val_dataloader = self.data_preparator.get_dataloader_val()
         self.fit_trainer = deepcopy(self._trainer)
         self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)
 
@@ -491,6 +491,27 @@ def _custom_transform_dataset_i2i(
     ) -> Dataset:
         return self.data_preparator.transform_dataset_i2i(dataset)
 
+    def _fit_partial(
+        self,
+        dataset: Dataset,
+        min_epochs: int,
+        max_epochs: int,
+    ) -> None:
+        if not self.is_fitted:
+            self._build_model_from_dataset(dataset)
+            self.fit_trainer = deepcopy(self._trainer)
+        elif self.fit_trainer is None:
+            self.data_preparator.process_dataset_train(dataset)
+            self.fit_trainer = deepcopy(self._trainer)
+
+        train_dataloader = self.data_preparator.get_dataloader_train()
+        val_dataloader = self.data_preparator.get_dataloader_val()
+
+        self.lightning_model.train()
+        self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
+        self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs
+        self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)
+
     def _recommend_u2i(
         self,
         user_ids: InternalIdsArray,
@@ -570,6 +591,7 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
             item_external_ids=item_external_ids,
             model_config=model_config,
         )
+        loaded.lightning_model.is_fitted = True
         loaded.lightning_model.load_state_dict(checkpoint["state_dict"])
         return loaded
 
diff --git a/rectools/models/nn/transformers/bert4rec.py b/rectools/models/nn/transformers/bert4rec.py
index 7aef9099..23ca8a9a 100644
--- a/rectools/models/nn/transformers/bert4rec.py
+++ b/rectools/models/nn/transformers/bert4rec.py
@@ -86,8 +86,8 @@ def __init__(
         train_min_user_interactions: int,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         mask_prob: float = 0.15,
-        shuffle_train: bool = True,
         get_val_mask_func: tp.Optional[ValMaskCallable] = None,
+        shuffle_train: bool = True,
         **kwargs: tp.Any,
     ) -> None:
         super().__init__(
@@ -433,6 +433,5 @@ def _init_data_preparator(self) -> None:
             train_min_user_interactions=self.train_min_user_interactions,
             mask_prob=self.mask_prob,
             get_val_mask_func=self.get_val_mask_func,
-            shuffle_train=True,
             **self._get_kwargs(self.data_preparator_kwargs),
         )
diff --git a/rectools/models/nn/transformers/data_preparator.py b/rectools/models/nn/transformers/data_preparator.py
index 62ebb23f..8787e950 100644
--- a/rectools/models/nn/transformers/data_preparator.py
+++ b/rectools/models/nn/transformers/data_preparator.py
@@ -122,9 +122,9 @@ def __init__(
         session_max_len: int,
         batch_size: int,
         dataloader_num_workers: int,
-        shuffle_train: bool = True,
         train_min_user_interactions: int = 2,
         get_val_mask_func: tp.Optional[tp.Callable] = None,
+        shuffle_train: bool = True,
         n_negatives: tp.Optional[int] = None,
         negative_sampler: tp.Optional[TransformerNegativeSamplerBase] = None,
         **kwargs: tp.Any,
diff --git a/rectools/models/nn/transformers/lightning.py b/rectools/models/nn/transformers/lightning.py
index 8fbcd43c..df97f882 100644
--- a/rectools/models/nn/transformers/lightning.py
+++ b/rectools/models/nn/transformers/lightning.py
@@ -102,6 +102,8 @@ def __init__(
         self.verbose = verbose
         self.train_loss_name = train_loss_name
         self.val_loss_name = val_loss_name
+        self.is_fitted = False
+        self.optimizer: tp.Optional[torch.optim.Adam] = None
         self.item_embs: torch.Tensor
         self.save_hyperparameters(ignore=["torch_model", "data_preparator"])
 
@@ -207,8 +209,9 @@ def _calc_sampled_softmax_loss(self, logits: torch.Tensor, y: torch.Tensor, w: t
 
     def configure_optimizers(self) -> torch.optim.Adam:
         """Choose what optimizers and learning-rate schedulers to use in optimization"""
-        optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas)
-        return optimizer
+        if self.optimizer is None:
+            self.optimizer = torch.optim.Adam(self.torch_model.parameters(), lr=self.lr, betas=self.adam_betas)
+        return self.optimizer
 
     def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         """Training step."""
@@ -286,7 +289,8 @@ class TransformerLightningModule(TransformerLightningModuleBase):
 
     def on_train_start(self) -> None:
         """Initialize parameters with values from Xavier normal distribution."""
-        self._xavier_normal_init()
+        if not self.is_fitted:
+            self._xavier_normal_init()
 
     def get_batch_logits(self, batch: tp.Dict[str, torch.Tensor]) -> torch.Tensor:
         """Get bacth logits."""
@@ -310,6 +314,10 @@ def training_step(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> to
         self.log(self.train_loss_name, loss, on_step=False, on_epoch=True, prog_bar=self.verbose > 0)
         return loss
 
+    def on_train_end(self) -> None:
+        """Save fitted state."""
+        self.is_fitted = True
+
     def _calc_custom_loss(self, batch: tp.Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor:
         raise ValueError(f"loss {self.loss} is not supported")
 
diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py
index 3b3b5110..c2ceba08 100644
--- a/tests/models/nn/transformers/test_base.py
+++ b/tests/models/nn/transformers/test_base.py
@@ -14,6 +14,7 @@
 
 import os
 import typing as tp
+from copy import deepcopy
 from tempfile import NamedTemporaryFile
 
 import pandas as pd
@@ -27,6 +28,7 @@
 from rectools.dataset import Dataset
 from rectools.models import BERT4RecModel, SASRecModel, load_model
 from rectools.models.nn.transformers.base import TransformerModelBase
+from rectools.models.nn.transformers.lightning import TransformerLightningModule
 from tests.models.data import INTERACTIONS
 from tests.models.utils import assert_save_load_do_not_change_model
 
@@ -306,6 +308,69 @@ def test_log_metrics(
         actual_columns = list(pd.read_csv(metrics_path).columns)
         assert actual_columns == expected_columns
 
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_fit_partial(
+        self,
+        dataset: Dataset,
+        model_cls: tp.Type[TransformerModelBase],
+    ) -> None:
+
+        class FixSeedLightningModule(TransformerLightningModule):
+            def on_train_epoch_start(self) -> None:
+                seed_everything(32, workers=True)
+
+        seed_everything(32, workers=True)
+        model_1 = model_cls.from_config(
+            {
+                "epochs": 3,
+                "data_preparator_kwargs": {"shuffle_train": False},
+                "get_trainer_func": custom_trainer,
+                "lightning_module_type": FixSeedLightningModule,
+            }
+        )
+        model_1.fit(dataset)
+
+        seed_everything(32, workers=True)
+        model_2 = model_cls.from_config(
+            {
+                "data_preparator_kwargs": {"shuffle_train": False},
+                "get_trainer_func": custom_trainer,
+                "lightning_module_type": FixSeedLightningModule,
+            }
+        )
+        model_2.fit_partial(dataset, min_epochs=2, max_epochs=2)
+        model_2.fit_partial(dataset, min_epochs=1, max_epochs=1)
+
+        self._assert_same_reco(model_1, model_2, dataset)
+
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_fit_partial_from_checkpoint(
+        self,
+        dataset: Dataset,
+        model_cls: tp.Type[TransformerModelBase],
+    ) -> None:
+        fit_partial_model = model_cls.from_config(
+            {"data_preparator_kwargs": {"shuffle_train": False}, "get_trainer_func": custom_trainer_ckpt}
+        )
+        fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+
+        assert fit_partial_model.fit_trainer is not None
+        if fit_partial_model.fit_trainer.log_dir is None:
+            raise ValueError("No log dir")
+        ckpt_path = os.path.join(fit_partial_model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
+        assert os.path.isfile(ckpt_path)
+        recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)
+
+        seed_everything(32, workers=True)
+        fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer)  # pylint: disable=protected-access
+        fit_partial_model.lightning_model.optimizer = None
+        fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+
+        seed_everything(32, workers=True)
+        recovered_fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+
+        self._assert_same_reco(fit_partial_model, recovered_fit_partial_model, dataset)
+
     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
     def test_raises_when_incorrect_similarity_dist(
         self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py
index e1c9fb4b..68068475 100644
--- a/tests/models/nn/transformers/test_bert4rec.py
+++ b/tests/models/nn/transformers/test_bert4rec.py
@@ -700,7 +700,6 @@ def data_preparator(self) -> BERT4RecDataPreparator:
             batch_size=4,
             dataloader_num_workers=0,
             train_min_user_interactions=2,
-            shuffle_train=True,
             mask_prob=0.5,
         )
 
@@ -771,7 +770,6 @@ def test_get_dataloader_train_for_masked_session_with_random_replacement(
             batch_size=14,
             dataloader_num_workers=0,
             train_min_user_interactions=2,
-            shuffle_train=True,
             mask_prob=0.5,
         )
         data_preparator.process_dataset_train(dataset_one_session)
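
A short usage note, not part of the patch: the tests above exercise the new `_fit_partial` implementation through the public `fit_partial(dataset, min_epochs=..., max_epochs=...)` call, which resumes training of an already built model, counts epochs on top of the trainer's current epoch, and reuses the cached optimizer. Below is a minimal sketch of that flow; the toy interactions dataframe and config values are illustrative assumptions, while the `from_config` / `fit_partial` / `recommend` usage follows the tests in this diff and the existing RecTools API.

```python
import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import SASRecModel

# Toy interactions (assumed for illustration only): three users, three events each.
interactions = pd.DataFrame(
    {
        Columns.User: [10, 10, 10, 20, 20, 20, 30, 30, 30],
        Columns.Item: [1, 2, 3, 2, 3, 4, 1, 3, 4],
        Columns.Weight: 1,
        Columns.Datetime: pd.to_datetime("2025-01-01"),
    }
)
dataset = Dataset.construct(interactions)

# Illustrative config; any transformer model config works the same way.
model = SASRecModel.from_config({"session_max_len": 4})

# First call builds the model and trains for two epochs; the second call resumes
# training of the same model for one more epoch instead of re-initializing weights.
model.fit_partial(dataset, min_epochs=2, max_epochs=2)
model.fit_partial(dataset, min_epochs=1, max_epochs=1)

reco = model.recommend(users=[10, 20, 30], dataset=dataset, k=3, filter_viewed=False)
```

As the checkpoint test shows, a model restored via `load_from_checkpoint` now has `lightning_model.is_fitted` set, so `on_train_start` skips Xavier re-initialization and a subsequent `fit_partial` continues training from the loaded weights.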