7 changes: 6 additions & 1 deletion CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `map_location` and `model_params_update` arguments for the `load_from_checkpoint` method of Transformer-based models. Use `map_location` to explicitly specify the target device and `model_params_update` to update the original model parameters (e.g. to remove training-specific parameters that are no longer needed); see the usage sketch after this file's diff

## [0.13.0] - 10.04.2025

### Added
@@ -14,7 +20,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `TransformerBackboneBase`, `backbone_type` and `backbone_kwargs` parameters to transformer-based models ([#277](https://github.com/MobileTeleSystems/RecTools/pull/277))
- `sampled_softmax` loss option for transformer models ([#274](https://github.com/MobileTeleSystems/RecTools/pull/274))


## [0.12.0] - 24.02.2025

### Added
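A minimal usage sketch of the new arguments described in the changelog entry above. The checkpoint path is hypothetical and the choice of `SASRecModel` (assumed importable from `rectools.models`) is illustrative; only the `map_location` and `model_params_update` arguments come from this PR.

```python
from rectools.models import SASRecModel

model = SASRecModel.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/last_epoch.ckpt",  # hypothetical path
    # Load weights onto CPU even if the checkpoint was saved on GPU.
    map_location="cpu",
    # Keys are dot-flattened paths into the saved model config;
    # here training-only callables are dropped before restoring the model.
    model_params_update={"get_trainer_func": None, "get_val_mask_func": None},
)
```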
28 changes: 22 additions & 6 deletions rectools/models/nn/transformers/base.py
@@ -29,7 +29,7 @@
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
from rectools.models.base import ErrorBehaviour, InternalRecoTriplet, ModelBase, ModelConfig
from rectools.types import InternalIdsArray
from rectools.utils.misc import get_class_or_function_full_path, import_object
from rectools.utils.misc import get_class_or_function_full_path, import_object, make_dict_flat, unflatten_dict

from ..item_net import (
CatFeaturesItemNet,
@@ -601,20 +601,36 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
self.__dict__.update(loaded.__dict__)

@classmethod
def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
"""
Load model from Lightning checkpoint path.
def load_from_checkpoint(
cls,
checkpoint_path: tp.Union[str, Path],
map_location: tp.Optional[tp.Union[str, torch.device]] = None,
model_params_update: tp.Optional[tp.Dict[str, tp.Any]] = None,
) -> tpe.Self:
"""Load model from Lightning checkpoint path.

Parameters
----------
checkpoint_path: Union[str, Path]
Path to checkpoint location.
map_location: Union[str, torch.device], optional
Target device to load the checkpoint onto (e.g. 'cpu', 'cuda:0').
If None, the device the checkpoint was saved on will be used.
model_params_update: Dict[str, Any], optional
Custom values to override in checkpoint['hyper_parameters']['model_config'].
Keys have to be flattened with a 'dot' reducer before being passed.
You can use this argument to remove training-specific parameters that are no longer needed, e.g. 'get_trainer_func'.

Returns
-------
Model instance.
"""
checkpoint = torch.load(checkpoint_path, weights_only=False)
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
prev_model_config = checkpoint["hyper_parameters"]["model_config"]
if model_params_update:
prev_config_flatten = make_dict_flat(prev_model_config)
prev_config_flatten.update(model_params_update)
checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
loaded = cls._model_from_checkpoint(checkpoint)
return loaded

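The docstring requires `model_params_update` keys to be dot-flattened. The standalone sketch below only illustrates the expected flatten / update / unflatten round trip; it does not reproduce the actual `make_dict_flat` and `unflatten_dict` helpers from `rectools.utils.misc`, whose signatures may differ, and the example config keys are hypothetical.

```python
import typing as tp


def flatten(d: tp.Dict[str, tp.Any], prefix: str = "") -> tp.Dict[str, tp.Any]:
    """Flatten nested dicts into 'a.b.c' keys (illustrative 'dot' reducer)."""
    flat: tp.Dict[str, tp.Any] = {}
    for key, value in d.items():
        full_key = f"{prefix}{key}"
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=f"{full_key}."))
        else:
            flat[full_key] = value
    return flat


def unflatten(flat: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:
    """Rebuild nested dicts from 'a.b.c' keys."""
    nested: tp.Dict[str, tp.Any] = {}
    for full_key, value in flat.items():
        *parents, leaf = full_key.split(".")
        node = nested
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = value
    return nested


# Hypothetical saved model config and update.
model_config = {"n_factors": 64, "data_preparator_kwargs": {"shuffle_train": True}}
flat_config = flatten(model_config)
flat_config.update({"data_preparator_kwargs.shuffle_train": False, "get_trainer_func": None})
print(unflatten(flat_config))
# {'n_factors': 64, 'data_preparator_kwargs': {'shuffle_train': False}, 'get_trainer_func': None}
```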
31 changes: 30 additions & 1 deletion tests/models/nn/transformers/test_base.py
@@ -152,10 +152,37 @@ def test_save_load_for_fitted_model(

@pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize(
"map_location",
(
"cpu",
pytest.param(
"cuda:0",
marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="GPU is not available"),
),
None,
),
)
@pytest.mark.parametrize(
"model_params_update",
(
{
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
},
{
"get_val_mask_func": None,
"get_trainer_func": None,
},
None,
),
)
def test_load_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
test_dataset: str,
map_location: tp.Optional[tp.Union[str, torch.device]],
model_params_update: tp.Optional[tp.Dict[str, tp.Any]],
request: FixtureRequest,
) -> None:

@@ -173,7 +200,9 @@ def test_load_from_checkpoint(
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
assert os.path.isfile(ckpt_path)
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
recovered_model = model_cls.load_from_checkpoint(
ckpt_path, map_location=map_location, model_params_update=model_params_update
)
assert isinstance(recovered_model, model_cls)

self._assert_same_reco(model, recovered_model, dataset)
1 change: 1 addition & 0 deletions tests/models/nn/transformers/test_bert4rec.py
@@ -451,6 +451,7 @@ def test_i2i(
whitelist: tp.Optional[np.ndarray],
expected: pd.DataFrame,
) -> None:

model = BERT4RecModel(
n_factors=32,
n_blocks=2,