1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
- Allow resaving transformer models multiple times. Load train state on model loading ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))

## [0.14.0] - 16.05.2025

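A minimal usage sketch of the behaviour described by the new changelog entry (the model class, file names, and the `dataset` object are illustrative assumptions, not part of this PR):

```python
from rectools.models import SASRecModel

model = SASRecModel.from_config({"deterministic": True})
model.fit(dataset)  # `dataset` is an existing rectools Dataset, assumed here

# A fitted model can now be saved, loaded, and saved again.
model.save("model_v1.pkl")
loaded = SASRecModel.load("model_v1.pkl")
loaded.save("model_v2.pkl")

# Train state (optimizer state, epoch counters) is restored on loading,
# so incremental training can continue where it stopped.
loaded.fit_partial(dataset, min_epochs=1, max_epochs=1)
```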
40 changes: 29 additions & 11 deletions rectools/models/nn/transformers/base.py
@@ -24,6 +24,7 @@
import typing_extensions as tpe
from pydantic import BeforeValidator, PlainSerializer
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, TensorDataset

from rectools import ExternalIds
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
@@ -505,16 +506,25 @@ def _fit_partial(
if not self.is_fitted:
self._build_model_from_dataset(dataset)
self.fit_trainer = deepcopy(self._trainer)
elif self.fit_trainer is None:
else:
# it is assumed that the dataset is the same as in `fit` or as in the first call to `fit_partial`
# currently new datasets are not supported due to difficulties with
# handling id maps and item (user) features
self.data_preparator.process_dataset_train(dataset)
self.fit_trainer = deepcopy(self._trainer)
if self.fit_trainer is None:
raise RuntimeError("expected to have fit_trainer set")

train_dataloader = self.data_preparator.get_dataloader_train()
val_dataloader = self.data_preparator.get_dataloader_val()

self.lightning_model.train()
self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs

# if the checkpoint comes from the ModelCheckpoint callback (and was saved at the end of an epoch),
# its epoch value equals the number of trained epochs - 1 (the epoch is not yet ended at checkpoint time),
# so instead of `fit_trainer.current_epoch` we use the count of ready epochs
current_epoch = self.fit_trainer.fit_loop.epoch_progress.current.ready
self.fit_trainer.fit_loop.max_epochs = current_epoch + max_epochs
self.fit_trainer.fit_loop.min_epochs = current_epoch + min_epochs
self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)

def _recommend_u2i(
@@ -596,20 +606,28 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
item_external_ids=item_external_ids,
model_config=model_config,
)

# save the checkpoint to a temp file so it can be used by the trainer
with NamedTemporaryFile() as f:
torch.save(checkpoint, f.name)
fit_trainer = deepcopy(loaded._trainer)
loaded.fit_trainer = fit_trainer
# use a stub dataset to restore the trainer state
loaded.fit_trainer.fit(
loaded.lightning_model,
ckpt_path=f.name,
train_dataloaders=DataLoader(TensorDataset(torch.Tensor())),
)

loaded.lightning_model.is_fitted = True
loaded.lightning_model.load_state_dict(checkpoint["state_dict"])

return loaded

def __getstate__(self) -> object:
if self.is_fitted:
if self.fit_trainer is None:
explanation = """
Model is fitted but has no `fit_trainer`. Most likely it was just loaded from the
checkpoint. Model that was loaded from checkpoint cannot be saved without being
fitted again.
"""
raise RuntimeError(explanation)
raise RuntimeError("expected to have fit_trainer set")

with NamedTemporaryFile() as f:
self.fit_trainer.save_checkpoint(f.name)
checkpoint = Path(f.name).read_bytes()
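To make the epoch bookkeeping in `_fit_partial` above concrete, here is a toy illustration with assumed numbers (not code from the PR): a checkpoint written by `ModelCheckpoint` at the end of an epoch still reports `current_epoch` as the index of the open epoch, while `fit_loop.epoch_progress.current.ready` counts the epochs that have actually completed, so the latter is the right base when extending `max_epochs`.

```python
# Assumed scenario: training ran for 3 full epochs before the checkpoint was written.
current_epoch = 2   # trainer.current_epoch at checkpoint time (the 3rd epoch is not "ended" yet)
ready_epochs = 3    # fit_loop.epoch_progress.current.ready: epochs actually completed
extra_epochs = 2    # max_epochs passed to fit_partial

# Basing the new limit on current_epoch would train only 4 epochs in total;
# basing it on the count of ready epochs gives the intended 5.
assert current_epoch + extra_epochs == 4
assert ready_epochs + extra_epochs == 5
```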
157 changes: 133 additions & 24 deletions tests/models/nn/transformers/test_base.py
@@ -19,10 +19,12 @@

import pandas as pd
import pytest
import pytorch_lightning as pl
import torch
from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from torch import nn

from rectools import Columns
from rectools.dataset import Dataset
@@ -35,6 +37,35 @@
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask


def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
assert type(model_a) is type(model_b), "different types"

with torch.no_grad():
for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
assert apn == bpn, "different parameter name"
assert torch.isclose(apv, bpv).all(), "different parameter value"


def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
"""Assert pl modules are equal in terms of weights and trainer"""
assert_torch_models_equal(model_a, model_b)

trainer_a = model_a.trainer
trainer_b = model_b.trainer

assert_pl_trainers_equal(trainer_a, trainer_b)


def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
"""Assert pl trainers are equal in terms of optimizers state"""
assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"

for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
# Check optimizer class
assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
assert opt_a.state_dict() == opt_b.state_dict(), "optimizers state dict differs"


class TestTransformerModelBase:
def setup_method(self) -> None:
torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(

self._assert_same_reco(model, recovered_model, dataset)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_save_model_loaded_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
) -> None:
model = model_cls.from_config(
{
"deterministic": True,
"get_trainer_func": custom_trainer_ckpt,
}
)
model.fit(dataset)
assert model.fit_trainer is not None
if model.fit_trainer.log_dir is None:
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
with pytest.raises(RuntimeError):
with NamedTemporaryFile() as f:
recovered_model.save(f.name)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_load_weights_from_checkpoint(
self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)

seed_everything(32, workers=True)
fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer) # pylint: disable=protected-access
fit_partial_model.lightning_model.optimizer = None
fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)

seed_everything(32, workers=True)
@@ -410,3 +417,105 @@ def test_raises_when_incorrect_similarity_dist(
with pytest.raises(ValueError):
model = model_cls.from_config(model_config)
model.fit(dataset=dataset)

@pytest.mark.parametrize("fit", (True, False))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize("default_trainer", (True, False))
def test_resaving(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
default_trainer: bool,
fit: bool,
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
if not default_trainer:
config["get_trainer_func"] = custom_trainer
model = model_cls.from_config(config)

seed_everything(32, workers=True)
if fit:
model.fit(dataset)

with NamedTemporaryFile() as f:
model.save(f.name)
recovered_model = model_cls.load(f.name)

with NamedTemporaryFile() as f:
recovered_model.save(f.name)
second_recovered_model = model_cls.load(f.name)

assert isinstance(recovered_model, model_cls)

original_model_config = model.get_config()
second_recovered_model_config = second_recovered_model.get_config()
assert second_recovered_model_config == original_model_config

if fit:
assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)

# check that the trainer keeps its state across multiple calls to fit_partial
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_fit_partial_multiple_times(
self,
dataset: Dataset,
model_cls: tp.Type[TransformerModelBase],
) -> None:
class FixSeedLightningModule(TransformerLightningModule):
def on_train_epoch_start(self) -> None:
seed_everything(32, workers=True)

seed_everything(32, workers=True)
model = model_cls.from_config(
{
"epochs": 3,
"data_preparator_kwargs": {"shuffle_train": False},
"get_trainer_func": custom_trainer,
"lightning_module_type": FixSeedLightningModule,
}
)
model.fit_partial(dataset, min_epochs=1, max_epochs=1)
t1 = deepcopy(model.fit_trainer)
model.fit_partial(
Dataset.construct(pd.DataFrame(columns=Columns.Interactions)),
min_epochs=1,
max_epochs=1,
)
t2 = deepcopy(model.fit_trainer)

assert t1 is not None
assert t2 is not None
assert_pl_trainers_equal(t1, t2)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_fit_trainer_is_none_on_save_trained_model(
self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
model = model_cls.from_config(config)

seed_everything(32, workers=True)
model.fit(dataset)
model.fit_trainer = None

with NamedTemporaryFile() as f:
with pytest.raises(RuntimeError):
model.save(f.name)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model(
self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
model = model_cls.from_config(config)

seed_everything(32, workers=True)
model.fit(dataset)
model.fit_trainer = None

with pytest.raises(RuntimeError):
model.fit_partial(
dataset,
min_epochs=1,
max_epochs=1,
)