Skip to content

Commit a2efbc3

Browse files
authored
Feature/models saving (#206)
- Added models saving and loading Closes #100
1 parent 11727d3 commit a2efbc3

File tree

14 files changed

+202
-3
lines changed

14 files changed

+202
-3
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
## Unreleased
1010

1111
### Added
12-
- `from_config`, `get_config` and `get_params` methods to all models except neural-net-based([#170](https://github.com/MobileTeleSystems/RecTools/pull/170))
12+
- `from_config`, `get_config` and `get_params` methods to all models except neural-net-based ([#170](https://github.com/MobileTeleSystems/RecTools/pull/170))
1313
- Optional `epochs` argument to `ImplicitALSWrapperModel.fit` method ([#203](https://github.com/MobileTeleSystems/RecTools/pull/203))
14+
- `save` and `load` methods to all of the models ([#206](https://github.com/MobileTeleSystems/RecTools/pull/206))
1415

1516

1617
## [0.8.0] - 28.08.2024

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ See [recommender baselines extended tutorial](https://github.com/MobileTeleSyste
150150
[Contributing guide](CONTRIBUTING.rst)
151151

152152
To install all requirements
153-
- you must have `python>=3.8` and `poetry>=1.5.0` installed
153+
- you must have `python3` and `poetry` installed
154154
- make sure you have no active virtual environments (deactivate conda `base` if applicable)
155155
- run
156156
```

rectools/models/base.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,10 @@
1414

1515
"""Base model."""
1616

17+
import pickle
1718
import typing as tp
1819
import warnings
20+
from pathlib import Path
1921

2022
import numpy as np
2123
import pandas as pd
@@ -42,6 +44,10 @@
4244

4345
RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet)
4446

47+
FileLike = tp.Union[str, Path, tp.IO[bytes]]
48+
49+
PICKLE_PROTOCOL = 5
50+
4551

4652
def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]:
4753
if rs is None or isinstance(rs, int):
@@ -191,6 +197,85 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:
191197
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
192198
raise NotImplementedError()
193199

200+
def save(self, f: FileLike) -> int:
201+
"""
202+
Save model to file.
203+
204+
Parameters
205+
----------
206+
f : str or Path or file-like object
207+
Path to file or file-like object.
208+
209+
Returns
210+
-------
211+
int
212+
Number of bytes written.
213+
"""
214+
data = self.dumps()
215+
216+
if isinstance(f, (str, Path)):
217+
return Path(f).write_bytes(data)
218+
219+
return f.write(data)
220+
221+
def dumps(self) -> bytes:
222+
"""
223+
Serialize model to bytes.
224+
225+
Returns
226+
-------
227+
bytes
228+
Serialized model.
229+
"""
230+
return pickle.dumps(self, protocol=PICKLE_PROTOCOL)
231+
232+
@classmethod
233+
def load(cls, f: FileLike) -> tpe.Self:
234+
"""
235+
Load model from file.
236+
237+
Parameters
238+
----------
239+
f : str or Path or file-like object
240+
Path to file or file-like object.
241+
242+
Returns
243+
-------
244+
model
245+
Model instance.
246+
"""
247+
if isinstance(f, (str, Path)):
248+
data = Path(f).read_bytes()
249+
else:
250+
data = f.read()
251+
252+
return cls.loads(data)
253+
254+
@classmethod
255+
def loads(cls, data: bytes) -> tpe.Self:
256+
"""
257+
Load model from bytes.
258+
259+
Parameters
260+
----------
261+
data : bytes
262+
Serialized model.
263+
264+
Returns
265+
-------
266+
model
267+
Model instance.
268+
269+
Raises
270+
------
271+
TypeError
272+
If loaded object is not a direct instance of model class.
273+
"""
274+
loaded = pickle.loads(data)
275+
if loaded.__class__ is not cls:
276+
raise TypeError(f"Loaded object is not a direct instance of `{cls.__name__}`")
277+
return loaded
278+
194279
def fit(self: T, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> T:
195280
"""
196281
Fit model.

tests/models/test_base.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
import typing as tp
1818
import warnings
1919
from datetime import timedelta
20+
from pathlib import Path
21+
from tempfile import NamedTemporaryFile, TemporaryFile
2022

2123
import numpy as np
2224
import pandas as pd
@@ -539,6 +541,42 @@ class MyModelWithoutConfig(ModelBase):
539541
MyModelWithoutConfig().get_config()
540542

541543

544+
class MyModel(ModelBase):
545+
def __init__(self, x: int = 10, verbose: int = 0):
546+
super().__init__(verbose=verbose)
547+
self.x = x
548+
549+
550+
class TestSavingAndLoading:
551+
552+
@pytest.fixture()
553+
def model(self) -> MyModel:
554+
return MyModel()
555+
556+
def test_save_and_load_to_file(self, model: MyModel) -> None:
557+
with TemporaryFile() as f:
558+
model.save(f)
559+
f.seek(0)
560+
loaded_model = MyModel.load(f)
561+
assert isinstance(loaded_model, MyModel)
562+
assert loaded_model.__dict__ == model.__dict__
563+
564+
@pytest.mark.parametrize("use_str", (False, True))
565+
def test_save_and_load_from_path(self, model: MyModel, use_str: bool) -> None:
566+
with NamedTemporaryFile() as f:
567+
path: tp.Union[Path, str] = Path(f.name) if not use_str else f.name
568+
model.save(path)
569+
loaded_model = MyModel.load(path)
570+
assert isinstance(loaded_model, MyModel)
571+
assert loaded_model.__dict__ == model.__dict__
572+
573+
def test_load_fails_on_incorrect_model_type(self, model: MyModel) -> None:
574+
with NamedTemporaryFile() as f:
575+
model.save(f.name)
576+
with pytest.raises(TypeError, match="Loaded object is not a direct instance of `ModelBase`"):
577+
ModelBase.load(f.name)
578+
579+
542580
class TestFixedColdRecoModelMixin:
543581
def test_cold_reco_works(self) -> None:
544582
class ColdRecoModel(FixedColdRecoModelMixin, ModelBase):

tests/models/test_dssm.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from rectools.models import DSSMModel
2626
from rectools.models.dssm import DSSM
2727
from rectools.models.vector import ImplicitRanker
28-
from tests.models.utils import assert_second_fit_refits_model
28+
from tests.models.utils import assert_dumps_loads_do_not_change_model, assert_second_fit_refits_model
2929

3030
from .data import INTERACTIONS
3131

@@ -338,3 +338,8 @@ def test_raises_when_no_features_in_dataset(self, dataset: Dataset, exclude_feat
338338
def test_second_fit_refits_model(self, dataset: Dataset) -> None:
339339
model = DSSMModel(deterministic=True)
340340
assert_second_fit_refits_model(model, dataset, pre_fit_callback=self._seed_everything)
341+
342+
def test_dumps_loads(self, dataset: Dataset) -> None:
343+
model = DSSMModel()
344+
model.fit(dataset)
345+
assert_dumps_loads_do_not_change_model(model, dataset, check_configs=False)

tests/models/test_ease.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .data import DATASET, INTERACTIONS
2626
from .utils import (
2727
assert_default_config_and_default_model_params_are_the_same,
28+
assert_dumps_loads_do_not_change_model,
2829
assert_get_config_and_from_config_compatibility,
2930
assert_second_fit_refits_model,
3031
)
@@ -225,6 +226,11 @@ def test_i2i_with_warm_and_cold_items(self, item_features: tp.Optional[pd.DataFr
225226
k=2,
226227
)
227228

229+
def test_dumps_loads(self, dataset: Dataset) -> None:
230+
model = EASEModel()
231+
model.fit(dataset)
232+
assert_dumps_loads_do_not_change_model(model, dataset)
233+
228234

229235
class TestEASEModelConfiguration:
230236
def test_from_config(self) -> None:

tests/models/test_implicit_als.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from .data import DATASET
3939
from .utils import (
4040
assert_default_config_and_default_model_params_are_the_same,
41+
assert_dumps_loads_do_not_change_model,
4142
assert_get_config_and_from_config_compatibility,
4243
assert_second_fit_refits_model,
4344
)
@@ -397,6 +398,11 @@ def test_per_epoch_fitting_consistent_with_regular_fitting(
397398
assert np.allclose(get_users_vectors(model_1.model), get_users_vectors(model_2.model))
398399
assert np.allclose(get_items_vectors(model_1.model), get_items_vectors(model_2.model))
399400

401+
def test_dumps_loads(self, use_gpu: bool, dataset: Dataset) -> None:
402+
model = ImplicitALSWrapperModel(model=AlternatingLeastSquares(use_gpu=use_gpu))
403+
model.fit(dataset)
404+
assert_dumps_loads_do_not_change_model(model, dataset)
405+
400406

401407
class CustomALS(CPUAlternatingLeastSquares):
402408
pass

tests/models/test_implicit_knn.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .data import DATASET, INTERACTIONS
2727
from .utils import (
2828
assert_default_config_and_default_model_params_are_the_same,
29+
assert_dumps_loads_do_not_change_model,
2930
assert_get_config_and_from_config_compatibility,
3031
assert_second_fit_refits_model,
3132
)
@@ -251,6 +252,11 @@ def test_base_class(self, dataset: Dataset) -> None:
251252
).astype({Columns.Score: np.float32})
252253
pd.testing.assert_frame_equal(actual, expected, atol=0.001)
253254

255+
def test_dumps_loads(self, dataset: Dataset) -> None:
256+
model = ImplicitItemKNNWrapperModel(model=TFIDFRecommender())
257+
model.fit(dataset)
258+
assert_dumps_loads_do_not_change_model(model, dataset)
259+
254260

255261
class CustomKNN(ItemItemRecommender):
256262
pass

tests/models/test_lightfm.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
from tests.models.data import DATASET
3434
from tests.models.utils import (
3535
assert_default_config_and_default_model_params_are_the_same,
36+
assert_dumps_loads_do_not_change_model,
3637
assert_get_config_and_from_config_compatibility,
3738
assert_second_fit_refits_model,
3839
)
@@ -340,6 +341,11 @@ def _get_items_factors(self, dataset: Dataset) -> Factors:
340341
filter_viewed=False,
341342
)
342343

344+
def test_dumps_loads(self, dataset: Dataset) -> None:
345+
model = LightFMWrapperModel(LightFM())
346+
model.fit(dataset)
347+
assert_dumps_loads_do_not_change_model(model, dataset)
348+
343349

344350
class CustomLightFM(LightFM):
345351
pass

tests/models/test_popular.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from rectools.models.popular import Popularity
2626
from tests.models.utils import (
2727
assert_default_config_and_default_model_params_are_the_same,
28+
assert_dumps_loads_do_not_change_model,
2829
assert_get_config_and_from_config_compatibility,
2930
assert_second_fit_refits_model,
3031
)
@@ -220,6 +221,11 @@ def test_second_fit_refits_model(self, dataset: Dataset) -> None:
220221
model = PopularModel()
221222
assert_second_fit_refits_model(model, dataset)
222223

224+
def test_dumps_loads(self, dataset: Dataset) -> None:
225+
model = PopularModel()
226+
model.fit(dataset)
227+
assert_dumps_loads_do_not_change_model(model, dataset)
228+
223229

224230
class TestPopularModelConfiguration:
225231
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)