6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Added
- `map_location` and `config_update` parameters for the `load_from_checkpoint` method. Use `map_location` to explicitly specify the target device and `config_update` (flattened with the "dot" reducer) to override values in the stored config


## [0.13.0] - 10.04.2025

### Added
1 change: 1 addition & 0 deletions pyproject.toml
@@ -85,6 +85,7 @@ ipywidgets = {version = ">=7.7,<8.2", optional = true}
plotly = {version="^5.22.0", optional = true}
nbformat = {version = ">=4.2.0", optional = true}
cupy-cuda12x = {version = "^13.3.0", python = "<3.13", optional = true}
flatten-dict = "^0.4.2"


[tool.poetry.extras]
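For context on the new dependency: `flatten-dict` converts a nested dict into a flat one keyed by joined paths, and back. A minimal sketch (the config keys are illustrative) of the "dot" reducer/splitter pair used in this PR:

```python
from flatten_dict import flatten, unflatten

nested = {"model_config": {"n_blocks": 2, "session_max_len": 50}}  # illustrative keys

flat = flatten(nested, reducer="dot")
# {"model_config.n_blocks": 2, "model_config.session_max_len": 50}

assert unflatten(flat, splitter="dot") == nested  # lossless round-trip
```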
23 changes: 19 additions & 4 deletions rectools/models/nn/transformers/base.py
@@ -22,6 +22,7 @@
import numpy as np
import torch
import typing_extensions as tpe
from flatten_dict import flatten, unflatten
from pydantic import BeforeValidator, PlainSerializer
from pytorch_lightning import Trainer

@@ -601,20 +602,34 @@ def __setstate__(self, state: tp.Dict[str, tp.Any]) -> None:
self.__dict__.update(loaded.__dict__)

@classmethod
def load_from_checkpoint(cls, checkpoint_path: tp.Union[str, Path]) -> tpe.Self:
"""
Load model from Lightning checkpoint path.
def load_from_checkpoint(
cls,
checkpoint_path: tp.Union[str, Path],
map_location: tp.Union[str, torch.device, None] = None,
config_update: tp.Dict[str, tp.Any] = {},
) -> tpe.Self:
"""Load model from Lightning checkpoint path.

Parameters
----------
checkpoint_path: Union[str, Path]
Path to checkpoint location.
map_location: Union[str, torch.device, None], default None
Target device to load the checkpoint (e.g., 'cpu', 'cuda:0').
If None, will use the device the checkpoint was saved on.
config_update: Dict[str, Any], default ``{}``
Custom values to override in checkpoint ``hyper_parameters``.
The dict must be flattened with the "dot" reducer before being passed.

Returns
-------
Model instance.
"""
checkpoint = torch.load(checkpoint_path, weights_only=False)
checkpoint = torch.load(checkpoint_path, map_location=map_location, weights_only=False)
prev_config = checkpoint["hyper_parameters"]
prev_config_flatten = flatten(prev_config, reducer="dot")
prev_config_flatten.update(config_update)
checkpoint["hyper_parameters"] = unflatten(prev_config_flatten, splitter="dot")
loaded = cls._model_from_checkpoint(checkpoint)
return loaded

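A hedged usage sketch of the extended signature. The checkpoint path, the model class import and the overridden key are illustrative (they mirror the test below); actual keys must match whatever is stored in `checkpoint["hyper_parameters"]`.

```python
from flatten_dict import flatten

from rectools.models import SASRecModel  # assuming this public import path

# Overrides for checkpoint["hyper_parameters"], flattened with the "dot" reducer
config_update = flatten(
    {"model_config": {"get_trainer_func": "my_package.utils.custom_trainer"}},  # illustrative value
    reducer="dot",
)

model = SASRecModel.load_from_checkpoint(
    "lightning_logs/version_0/checkpoints/last_epoch.ckpt",  # hypothetical path
    map_location="cpu",  # force loading onto CPU
    config_update=config_update,
)
```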
20 changes: 19 additions & 1 deletion tests/models/nn/transformers/test_base.py
@@ -19,6 +19,7 @@
import pandas as pd
import pytest
import torch
from flatten_dict import flatten
from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
@@ -152,10 +153,24 @@ def test_save_load_for_fitted_model(

@pytest.mark.parametrize("test_dataset", ("dataset", "dataset_item_features"))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize("map_location", ("cpu", torch.device("cuda:0"), None))
@pytest.mark.parametrize(
"config_update",
(
{
"model_config": {
"get_val_mask_func": "tests.models.nn.transformers.utils.leave_one_out_mask",
"get_trainer_func": "tests.models.nn.transformers.utils.custom_trainer",
}
},
),
)
def test_load_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
test_dataset: str,
map_location: tp.Union[str, torch.device, None],
config_update: tp.Dict[str, tp.Any],
request: FixtureRequest,
) -> None:

@@ -173,7 +188,10 @@ def test_load_from_checkpoint(
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
assert os.path.isfile(ckpt_path)
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
config_update_flatten = flatten(config_update, reducer="dot")
recovered_model = model_cls.load_from_checkpoint(
ckpt_path, map_location=map_location, config_update=config_update_flatten
)
assert isinstance(recovered_model, model_cls)

self._assert_same_reco(model, recovered_model, dataset)