diff --git a/CHANGELOG.md b/CHANGELOG.md
index 80742792..29e4d6c6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 ### Added
 - `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
+- Support for re-saving transformer models multiple times and for loading trainer state ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))
 
 ### Fixed
 - [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))
diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 14982cb9..fcd97735 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -24,6 +24,7 @@
 import typing_extensions as tpe
 from pydantic import BeforeValidator, PlainSerializer
 from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader, TensorDataset
 
 from rectools import ExternalIds
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
@@ -505,16 +506,25 @@ def _fit_partial(
         if not self.is_fitted:
             self._build_model_from_dataset(dataset)
             self.fit_trainer = deepcopy(self._trainer)
-        elif self.fit_trainer is None:
+        else:
+            # It is assumed that the dataset is the same as in `fit` or in the first call to `fit_partial`.
+            # Passing a new dataset is currently not supported due to difficulties with
+            # handling id maps and item (user) features.
             self.data_preparator.process_dataset_train(dataset)
-            self.fit_trainer = deepcopy(self._trainer)
+
+        if self.fit_trainer is None:
+            raise RuntimeError("expected to have fit_trainer set")
 
         train_dataloader = self.data_preparator.get_dataloader_train()
         val_dataloader = self.data_preparator.get_dataloader_val()
 
         self.lightning_model.train()
-        self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
-        self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs
+
+        # If the checkpoint comes from a `ModelCheckpoint` callback (saved at the end of an epoch),
+        # its epoch value equals the number of completed epochs - 1 (the epoch is not yet marked as
+        # finished at checkpoint time), so instead of `fit_trainer.current_epoch` we use the count of ready epochs.
+        current_epoch = self.fit_trainer.fit_loop.epoch_progress.current.ready
+        self.fit_trainer.fit_loop.max_epochs = current_epoch + max_epochs
+        self.fit_trainer.fit_loop.min_epochs = current_epoch + min_epochs
+
         self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)
 
     def _recommend_u2i(
@@ -574,8 +584,25 @@ def _get_config(self) -> TransformerModelConfig_T:
         return self.config_class(**params)
 
     @classmethod
-    def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
-        """Create model from loaded Lightning checkpoint."""
+    def _model_from_checkpoint(
+        cls, checkpoint: tp.Dict[str, tp.Any], ckpt_path: tp.Optional[tp.Union[str, Path]] = None
+    ) -> tpe.Self:
+        """
+        Create model from loaded Lightning checkpoint.
+
+        Parameters
+        ----------
+        checkpoint : tp.Dict[str, tp.Any]
+            Loaded checkpoint object (PyTorch Lightning / torch format).
+        ckpt_path : tp.Union[str, Path], optional
+            Path to the checkpoint file.
+            If specified, it must point to the file that `checkpoint` was loaded from.
+            If not specified, `checkpoint` is saved to a temporary file.
+
+        Returns
+        -------
+        Model instance.
+        """
         model_config = checkpoint["hyper_parameters"]["model_config"]
         loaded = cls.from_config(model_config)
         loaded.is_fitted = True
@@ -596,20 +623,36 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
             item_external_ids=item_external_ids,
             model_config=model_config,
         )
+
+        try:
+            temp_file = None
+            actual_ckpt_path = ckpt_path
+            if actual_ckpt_path is None:
+                temp_file = NamedTemporaryFile()  # pylint: disable=consider-using-with
+                actual_ckpt_path = temp_file.name
+                torch.save(checkpoint, actual_ckpt_path)
+
+            loaded.fit_trainer = deepcopy(loaded._trainer)
+            # use a stub dataset only to restore trainer state from the checkpoint
+            loaded.fit_trainer.fit(
+                loaded.lightning_model,
+                ckpt_path=actual_ckpt_path,
+                train_dataloaders=DataLoader(TensorDataset(torch.Tensor())),
+            )
+
+        finally:
+            if temp_file is not None:
+                temp_file.close()
+
         loaded.lightning_model.is_fitted = True
-        loaded.lightning_model.load_state_dict(checkpoint["state_dict"])
         return loaded
 
     def __getstate__(self) -> object:
         if self.is_fitted:
             if self.fit_trainer is None:
-                explanation = """
-                    Model is fitted but has no `fit_trainer`. Most likely it was just loaded from the
-                    checkpoint. Model that was loaded from checkpoint cannot be saved without being
-                    fitted again.
-                """
-                raise RuntimeError(explanation)
+                raise RuntimeError("Fitted model is expected to have `fit_trainer` set")
+
             with NamedTemporaryFile() as f:
                 self.fit_trainer.save_checkpoint(f.name)
                 checkpoint = Path(f.name).read_bytes()
@@ -658,7 +701,7 @@ def load_from_checkpoint(
         prev_config_flatten = make_dict_flat(prev_model_config)
         prev_config_flatten.update(model_params_update)
         checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
-        loaded = cls._model_from_checkpoint(checkpoint)
+        loaded = cls._model_from_checkpoint(checkpoint, ckpt_path=checkpoint_path)
         return loaded
 
     def load_weights_from_checkpoint(self, checkpoint_path: tp.Union[str, Path]) -> None:
diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py
index da3df2c7..eb68a9e2 100644
--- a/tests/models/nn/transformers/test_base.py
+++ b/tests/models/nn/transformers/test_base.py
@@ -19,10 +19,12 @@
 import pandas as pd
 import pytest
+import pytorch_lightning as pl
 import torch
 from pytest import FixtureRequest
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.loggers import CSVLogger
+from torch import nn
 
 from rectools import Columns
 from rectools.dataset import Dataset
@@ -35,6 +37,35 @@
 from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
 
 
+def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
+    assert type(model_a) is type(model_b), "different types"
+
+    with torch.no_grad():
+        for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
+            assert apn == bpn, "different parameter name"
+            assert torch.isclose(apv, bpv).all(), "different parameter value"
+
+
+def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
+    """Assert that pl modules are equal in terms of weights and trainer state"""
+    assert_torch_models_equal(model_a, model_b)
+
+    trainer_a = model_a.trainer
+    trainer_b = model_b.trainer
+
+    assert_pl_trainers_equal(trainer_a, trainer_b)
+
+
+def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
+    """Assert that pl trainers are equal in terms of optimizer state"""
+    assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"
+
+    for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
+        # Check optimizer class
+        assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
+        assert opt_a.state_dict() == opt_b.state_dict(), "Optimizer state dicts differ"
+
+
 class TestTransformerModelBase:
     def setup_method(self) -> None:
         torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(
         self._assert_same_reco(model, recovered_model, dataset)
 
-    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
-    def test_raises_when_save_model_loaded_from_checkpoint(
-        self,
-        model_cls: tp.Type[TransformerModelBase],
-        dataset: Dataset,
-    ) -> None:
-        model = model_cls.from_config(
-            {
-                "deterministic": True,
-                "get_trainer_func": custom_trainer_ckpt,
-            }
-        )
-        model.fit(dataset)
-        assert model.fit_trainer is not None
-        if model.fit_trainer.log_dir is None:
-            raise ValueError("No log dir")
-        ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
-        recovered_model = model_cls.load_from_checkpoint(ckpt_path)
-        with pytest.raises(RuntimeError):
-            with NamedTemporaryFile() as f:
-                recovered_model.save(f.name)
-
     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
     def test_load_weights_from_checkpoint(
         self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
         recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)
 
         seed_everything(32, workers=True)
-        fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer)  # pylint: disable=protected-access
-        fit_partial_model.lightning_model.optimizer = None
         fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)
 
         seed_everything(32, workers=True)
@@ -410,3 +417,108 @@ def test_raises_when_incorrect_similarity_dist(
         with pytest.raises(ValueError):
             model = model_cls.from_config(model_config)
             model.fit(dataset=dataset)
+
+    @pytest.mark.parametrize("fit", (True, False))
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    @pytest.mark.parametrize("default_trainer", (True, False))
+    def test_resaving(
+        self,
+        model_cls: tp.Type[TransformerModelBase],
+        dataset: Dataset,
+        default_trainer: bool,
+        fit: bool,
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        if not default_trainer:
+            config["get_trainer_func"] = custom_trainer
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        if fit:
+            model.fit(dataset)
+
+        with NamedTemporaryFile() as f:
+            model.save(f.name)
+            recovered_model = model_cls.load(f.name)
+
+        with NamedTemporaryFile() as f:
+            recovered_model.save(f.name)
+            second_recovered_model = model_cls.load(f.name)
+
+        assert isinstance(recovered_model, model_cls)
+
+        original_model_config = model.get_config()
+        second_recovered_model_config = second_recovered_model.get_config()
+        assert second_recovered_model_config == original_model_config
+
+        if fit:
+            assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)
+
+    # Check that the trainer keeps its state across multiple `fit_partial` calls
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_fit_partial_multiple_times(
+        self,
+        dataset: Dataset,
+        model_cls: tp.Type[TransformerModelBase],
+    ) -> None:
+        class FixSeedLightningModule(TransformerLightningModule):
+            def on_train_epoch_start(self) -> None:
+                seed_everything(32, workers=True)
+
+        seed_everything(32, workers=True)
+        model = model_cls.from_config(
+            {
"epochs": 3, + "data_preparator_kwargs": {"shuffle_train": False}, + "get_trainer_func": custom_trainer, + "lightning_module_type": FixSeedLightningModule, + } + ) + model.fit_partial(dataset, min_epochs=1, max_epochs=1) + t1 = deepcopy(model.fit_trainer) + model.fit_partial( + Dataset.construct(pd.DataFrame(columns=Columns.Interactions)), + min_epochs=1, + max_epochs=1, + ) + t2 = deepcopy(model.fit_trainer) + + # Since for the second we are fitting on an empty dataset, + # the trainer state should be kept exactly the same as after the first fit + # to prove that fit_partial does not change trainer state before proceeding to training." + assert t1 is not None + assert t2 is not None + assert_pl_trainers_equal(t1, t2) + + @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) + def test_raises_when_fit_trainer_is_none_on_save_trained_model( + self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset + ) -> None: + config: tp.Dict[str, tp.Any] = {"deterministic": True} + model = model_cls.from_config(config) + + seed_everything(32, workers=True) + model.fit(dataset) + model.fit_trainer = None + + with NamedTemporaryFile() as f: + with pytest.raises(RuntimeError): + model.save(f.name) + + @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) + def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model( + self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset + ) -> None: + config: tp.Dict[str, tp.Any] = {"deterministic": True} + model = model_cls.from_config(config) + + seed_everything(32, workers=True) + model.fit(dataset) + model.fit_trainer = None + + with pytest.raises(RuntimeError): + model.fit_partial( + dataset, + min_epochs=1, + max_epochs=1, + )