1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
### Added
- `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
- Support for resaving transformer models multiple times and loading trainer state ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))

### Fixed
- [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))
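For context, a minimal usage sketch of the resaving support referenced above (hypothetical example: it assumes `dataset` is an existing rectools `Dataset`, that `SASRecModel` is importable from `rectools.models`, and follows the method names used in the tests of this PR):

from tempfile import NamedTemporaryFile

from rectools.models import SASRecModel

model = SASRecModel.from_config({"deterministic": True})
model.fit(dataset)

# a fitted model can now be saved, loaded and saved again without refitting
with NamedTemporaryFile() as f1:
    model.save(f1.name)
    recovered = SASRecModel.load(f1.name)

with NamedTemporaryFile() as f2:
    recovered.save(f2.name)
    recovered_again = SASRecModel.load(f2.name)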
71 changes: 57 additions & 14 deletions rectools/models/nn/transformers/base.py
@@ -24,6 +24,7 @@
import typing_extensions as tpe
from pydantic import BeforeValidator, PlainSerializer
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader, TensorDataset

from rectools import ExternalIds
from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
@@ -505,16 +506,25 @@ def _fit_partial(
if not self.is_fitted:
self._build_model_from_dataset(dataset)
self.fit_trainer = deepcopy(self._trainer)
elif self.fit_trainer is None:
else:
# it is assumed that the dataset is the same as in `fit` or in the first call to `fit_partial`;
# new datasets are currently not supported due to difficulties with
# handling id maps and item (user) features
self.data_preparator.process_dataset_train(dataset)
self.fit_trainer = deepcopy(self._trainer)
if self.fit_trainer is None:
raise RuntimeError("expected to have fit_trainer set")

train_dataloader = self.data_preparator.get_dataloader_train()
val_dataloader = self.data_preparator.get_dataloader_val()

self.lightning_model.train()
self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs

# if the checkpoint comes from the ModelCheckpoint callback (saved at the end of an epoch),
# its epoch value equals the number of completed data epochs - 1 (the epoch is not yet finished at checkpoint time),
# so we use the count of ready epochs instead of `fit_trainer.current_epoch`
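# (illustrative numbers: a ModelCheckpoint checkpoint saved after 3 completed epochs
# reports current_epoch == 2, while epoch_progress.current.ready == 3)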
current_epoch = self.fit_trainer.fit_loop.epoch_progress.current.ready
self.fit_trainer.fit_loop.max_epochs = current_epoch + max_epochs
self.fit_trainer.fit_loop.min_epochs = current_epoch + min_epochs
self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)

def _recommend_u2i(
@@ -574,8 +584,25 @@ def _get_config(self) -> TransformerModelConfig_T:
return self.config_class(**params)

@classmethod
def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
"""Create model from loaded Lightning checkpoint."""
def _model_from_checkpoint(
cls, checkpoint: tp.Dict[str, tp.Any], ckpt_path: tp.Optional[tp.Union[str, Path]] = None
) -> tpe.Self:
"""
Create model from loaded Lightning checkpoint.

Parameters
----------
checkpoint: Dict[str, tp.Any]
Loaded checkpoint object (PyTorch Lightning / torch format).
ckpt_path: Union[str, Path], optional
Path to the checkpoint file.
If specified, it must point to the file that the `checkpoint` arg was loaded from.
If not specified, `checkpoint` is saved to a temporary file.

Returns
-------
Model instance.
"""
model_config = checkpoint["hyper_parameters"]["model_config"]
loaded = cls.from_config(model_config)
loaded.is_fitted = True
@@ -596,20 +623,36 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
item_external_ids=item_external_ids,
model_config=model_config,
)

try:
temp_file = None
actual_ckpt_path = ckpt_path
if actual_ckpt_path is None:
temp_file = NamedTemporaryFile() # pylint: disable=consider-using-with
actual_ckpt_path = temp_file.name
torch.save(checkpoint, actual_ckpt_path)

loaded.fit_trainer = deepcopy(loaded._trainer)
# use a stub dataset to restore the trainer state
loaded.fit_trainer.fit(
loaded.lightning_model,
ckpt_path=actual_ckpt_path,
train_dataloaders=DataLoader(TensorDataset(torch.Tensor())),
)

finally:
if temp_file is not None:
temp_file.close()

loaded.lightning_model.is_fitted = True
loaded.lightning_model.load_state_dict(checkpoint["state_dict"])

return loaded

def __getstate__(self) -> object:
if self.is_fitted:
if self.fit_trainer is None:
explanation = """
Model is fitted but has no `fit_trainer`. Most likely it was just loaded from the
checkpoint. Model that was loaded from checkpoint cannot be saved without being
fitted again.
"""
raise RuntimeError(explanation)
raise RuntimeError("Fitted model is expected to have `fit_trainer` set")

with NamedTemporaryFile() as f:
self.fit_trainer.save_checkpoint(f.name)
checkpoint = Path(f.name).read_bytes()
@@ -658,7 +701,7 @@ def load_from_checkpoint(
prev_config_flatten = make_dict_flat(prev_model_config)
prev_config_flatten.update(model_params_update)
checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
loaded = cls._model_from_checkpoint(checkpoint)
loaded = cls._model_from_checkpoint(checkpoint, ckpt_path=checkpoint_path)
return loaded

def load_weights_from_checkpoint(self, checkpoint_path: tp.Union[str, Path]) -> None:
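The changes above also make it possible to continue training from a Lightning checkpoint. A hedged sketch (hypothetical usage: `ckpt_path` is assumed to point to a checkpoint written by the ModelCheckpoint callback during a previous fit, and `dataset` to an existing rectools `Dataset`):

from rectools.models import SASRecModel

# restores model weights and, via a stub dataloader fit, the trainer/optimizer state
restored = SASRecModel.load_from_checkpoint(ckpt_path)

# fit_partial resumes from the count of completed ("ready") epochs in the restored trainer,
# so exactly one additional epoch is trained on top of what the checkpoint contains
restored.fit_partial(dataset, min_epochs=1, max_epochs=1)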
160 changes: 136 additions & 24 deletions tests/models/nn/transformers/test_base.py
@@ -19,10 +19,12 @@

import pandas as pd
import pytest
import pytorch_lightning as pl
import torch
from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
from torch import nn

from rectools import Columns
from rectools.dataset import Dataset
@@ -35,6 +37,35 @@
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask


def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
assert type(model_a) is type(model_b), "different types"

with torch.no_grad():
for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
assert apn == bpn, "different parameter name"
assert torch.isclose(apv, bpv).all(), "different parameter value"


def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
"""Assert pl modules are equal in terms of weights and trainer"""
assert_torch_models_equal(model_a, model_b)

trainer_a = model_a.trainer
trainer_b = model_b.trainer

assert_pl_trainers_equal(trainer_a, trainer_b)


def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
"""Assert pl trainers are equal in terms of optimizers state"""
assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"

for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
# Check optimizer class
assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
assert opt_a.state_dict() == opt_b.state_dict(), "optimizers state dict differs"


class TestTransformerModelBase:
def setup_method(self) -> None:
torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(

self._assert_same_reco(model, recovered_model, dataset)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_save_model_loaded_from_checkpoint(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
) -> None:
model = model_cls.from_config(
{
"deterministic": True,
"get_trainer_func": custom_trainer_ckpt,
}
)
model.fit(dataset)
assert model.fit_trainer is not None
if model.fit_trainer.log_dir is None:
raise ValueError("No log dir")
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
recovered_model = model_cls.load_from_checkpoint(ckpt_path)
with pytest.raises(RuntimeError):
with NamedTemporaryFile() as f:
recovered_model.save(f.name)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_load_weights_from_checkpoint(
self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)

seed_everything(32, workers=True)
fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer) # pylint: disable=protected-access
fit_partial_model.lightning_model.optimizer = None
fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)

seed_everything(32, workers=True)
@@ -410,3 +417,108 @@ def test_raises_when_incorrect_similarity_dist(
with pytest.raises(ValueError):
model = model_cls.from_config(model_config)
model.fit(dataset=dataset)

@pytest.mark.parametrize("fit", (True, False))
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
@pytest.mark.parametrize("default_trainer", (True, False))
def test_resaving(
self,
model_cls: tp.Type[TransformerModelBase],
dataset: Dataset,
default_trainer: bool,
fit: bool,
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
if not default_trainer:
config["get_trainer_func"] = custom_trainer
model = model_cls.from_config(config)

seed_everything(32, workers=True)
if fit:
model.fit(dataset)

with NamedTemporaryFile() as f:
model.save(f.name)
recovered_model = model_cls.load(f.name)

with NamedTemporaryFile() as f:
recovered_model.save(f.name)
second_recovered_model = model_cls.load(f.name)

assert isinstance(recovered_model, model_cls)

original_model_config = model.get_config()
second_recovered_model_config = second_recovered_model.get_config()
assert second_recovered_model_config == original_model_config

if fit:
assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)

# check that the trainer keeps its state across multiple calls to fit_partial
@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_fit_partial_multiple_times(
self,
dataset: Dataset,
model_cls: tp.Type[TransformerModelBase],
) -> None:
class FixSeedLightningModule(TransformerLightningModule):
def on_train_epoch_start(self) -> None:
seed_everything(32, workers=True)

seed_everything(32, workers=True)
model = model_cls.from_config(
{
"epochs": 3,
"data_preparator_kwargs": {"shuffle_train": False},
"get_trainer_func": custom_trainer,
"lightning_module_type": FixSeedLightningModule,
}
)
model.fit_partial(dataset, min_epochs=1, max_epochs=1)
t1 = deepcopy(model.fit_trainer)
model.fit_partial(
Dataset.construct(pd.DataFrame(columns=Columns.Interactions)),
min_epochs=1,
max_epochs=1,
)
t2 = deepcopy(model.fit_trainer)

# Since the second call fits on an empty dataset,
# the trainer state should stay exactly the same as after the first fit,
# proving that fit_partial does not change the trainer state before training starts.
assert t1 is not None
assert t2 is not None
assert_pl_trainers_equal(t1, t2)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_fit_trainer_is_none_on_save_trained_model(
self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
model = model_cls.from_config(config)

seed_everything(32, workers=True)
model.fit(dataset)
model.fit_trainer = None

with NamedTemporaryFile() as f:
with pytest.raises(RuntimeError):
model.save(f.name)

@pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model(
self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
) -> None:
config: tp.Dict[str, tp.Any] = {"deterministic": True}
model = model_cls.from_config(config)

seed_everything(32, workers=True)
model.fit(dataset)
model.fit_trainer = None

with pytest.raises(RuntimeError):
model.fit_partial(
dataset,
min_epochs=1,
max_epochs=1,
)