diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1c1e525b..736244e5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ### Added
 - Python 3.13 support ([#227](https://github.com/MobileTeleSystems/RecTools/pull/227))
 - `fit_partial` implementation for transformer-based models ([#273](https://github.com/MobileTeleSystems/RecTools/pull/273))
+- `map_location` and `model_params_update` arguments for the `load_from_checkpoint` function of transformer-based models. Use `map_location` to explicitly specify the target computing device and `model_params_update` to update the original model parameters (e.g. to remove training-specific parameters that are no longer needed) ([#281](https://github.com/MobileTeleSystems/RecTools/pull/281))
 
 ## [0.13.0] - 10.04.2025
 
diff --git a/rectools/models/nn/transformers/base.py b/rectools/models/nn/transformers/base.py
index 735bf4b6..c2699978 100644
--- a/rectools/models/nn/transformers/base.py
+++ b/rectools/models/nn/transformers/base.py
@@ -29,7 +29,7 @@
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
 from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
 from rectools.types import InternalIdsArray
-from rectools.utils.misc import get_class_or_function_full_path, import_object
+from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict
 
 from ..item_net import (
     CatFeaturesItemNet,
@@ -623,20 +623,36 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
         self.__dict__.update(loaded.__dict__)
 
     @classmethod
-    def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
-        """
-        Load model from Lightning checkpoint path.
+    def load_from_checkpoint(
+        cls,
+        checkpoint_path: tp.Union[str, Path],
+        map_location: tp.Optional[tp.Union[str, torch.device]] = None,
+        model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
+    ) -> tpe.Self:
+        """Load model from Lightning checkpoint path.
 
         Parameters
         ----------
         checkpoint_path: Union[str, Path]
             Path to checkpoint location.
-
+        map_location: Union[str, torch.device], optional
+            Target device to load the checkpoint on (e.g. 'cpu', 'cuda:0').
+            If None, the device the checkpoint was saved on will be used.
+        model_params_update: Dict[str, Any], optional
+            Contains custom values for checkpoint['hyper_parameters']['model_config'].
+            Has to be flattened with the 'dot' reducer before being passed.
+            You can use this argument to remove training-specific parameters
+            that are no longer needed, e.g. 'get_trainer_func'.
 
         Returns
        -------
         Model instance.
""" - checkpoint = torch.load(checkpoint_path, weights_only=False) + checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False) + if model_params_update: + prev_model_config = checkpoint["hyper_parameters"]["model_config"] + prev_config_flatten = make_dict_flat(prev_model_config) + prev_config_flatten.update(model_params_update) + checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten) loaded = cls._model_from_checkpoint(checkpoint) return loaded diff --git a/tests/models/nn/transformers/test_base.py b/tests/models/nn/transformers/test_base.py index c2ceba08..da3df2c7 100644 --- a/tests/models/nn/transformers/test_base.py +++ b/tests/models/nn/transformers/test_base.py @@ -154,10 +154,37 @@ def test_save_load_for_fitted_model( @pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features")) @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel)) + @pytest.mark.parametrize( + "map_location", + ( + "cpu", + pytest.param( + "cuda:0", + marks=pytest.mark.skipif(torch.cuda.is_available() is False, reason="GPU is not available"), + ), + None, + ), + ) + @pytest.mark.parametrize( + "model_params_update", + ( + { + "get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask", + "get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer", + }, + { + "get_val_mask_func": None, + "get_trainer_func": None, + }, + None, + ), + ) def test_load_from_checkpoint( self, model_cls: tp.Type[TransformerModelBase], test_dataset: str, + map_location: tp.Optional[tp.Union[str, torch.device]], + model_params_update: tp.Optional[tp.Dict[str, tp.Any]], request: FixtureRequest, ) -> None: @@ -175,7 +202,9 @@ def test_load_from_checkpoint( raise ValueError("No log dir") ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt") assert os.path.isfile(ckpt_path) - recovered_model = model_cls.load_from_checkpoint(ckpt_path) + recovered_model = model_cls.load_from_checkpoint( + ckpt_path, map_location=map_location, model_params_update=model_params_update + ) assert isinstance(recovered_model, model_cls) self._assert_same_reco(model, recovered_model, dataset) diff --git a/tests/models/nn/transformers/test_bert4rec.py b/tests/models/nn/transformers/test_bert4rec.py index 68068475..6f0a0c34 100644 --- a/tests/models/nn/transformers/test_bert4rec.py +++ b/tests/models/nn/transformers/test_bert4rec.py @@ -451,6 +451,7 @@ def test_i2i( whitelist: tp.Optional[np.ndarray], expected: pd.DataFrame, ) -> None: + model = BERT4RecModel( n_factors=32, n_blocks=2,