Skip to content

Commit 0d1df86

Browse files
authored
Model configs (#170)
- Added configs to models Сloses #108 #157
1 parent 0be5e15 commit 0d1df86

28 files changed

+1948
-163
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
66
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
77

88

9+
## Unreleased
10+
11+
### Added
12+
- `from_config`, `get_config` and `get_params` methods to all models except neural-net-based([#170](https://github.com/MobileTeleSystems/RecTools/pull/170))
13+
14+
915
## [0.8.0] - 28.08.2024
1016

1117
### Added

poetry.lock

Lines changed: 141 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,12 @@ tqdm = "^4.27.0"
6868
implicit = "^0.7.1"
6969
attrs = ">=19.1.0,<24.0.0"
7070
typeguard = "^4.1.0"
71-
71+
pydantic = "^2.8.2"
72+
pydantic-core = "^2.20.1"
73+
typing-extensions = "^4.12.2"
7274

7375
# The latest released version of lightfm is 1.17 and it's not compatible with PEP-517 installers (like latest poetry versions).
74-
rectools-lightfm = {version="1.17.1", python = "<3.12", optional = true}
76+
rectools-lightfm = {version="1.17.2", python = "<3.12", optional = true}
7577

7678
nmslib = {version = "^2.0.4", python = "<3.11", optional = true}
7779
# nmslib officialy doens't support Python 3.11 and 3.12. Use https://github.com/metabrainz/nmslib-metabrainz instead

rectools/dataset/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def get_user_item_matrix(
200200
include_weights: bool = True,
201201
include_warm_users: bool = False,
202202
include_warm_items: bool = False,
203+
dtype: tp.Type = np.float32,
203204
) -> sparse.csr_matrix:
204205
"""
205206
Construct user-item CSR matrix based on `interactions` attribute.
@@ -224,7 +225,7 @@ def get_user_item_matrix(
224225
csr_matrix
225226
Resized user-item CSR matrix
226227
"""
227-
matrix = self.interactions.get_user_item_matrix(include_weights)
228+
matrix = self.interactions.get_user_item_matrix(include_weights, dtype)
228229
n_rows = self.user_id_map.size if include_warm_users else matrix.shape[0]
229230
n_columns = self.item_id_map.size if include_warm_items else matrix.shape[1]
230231
matrix.resize(n_rows, n_columns)

rectools/dataset/interactions.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
"""Structure for saving user-item interactions."""
1616

17+
import typing as tp
18+
1719
import attr
1820
import numpy as np
1921
import pandas as pd
@@ -121,7 +123,7 @@ def from_raw(
121123

122124
return cls(df)
123125

124-
def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matrix:
126+
def get_user_item_matrix(self, include_weights: bool = True, dtype: tp.Type = np.float32) -> sparse.csr_matrix:
125127
"""
126128
Form a user-item CSR matrix based on interactions data.
127129
@@ -142,7 +144,7 @@ def get_user_item_matrix(self, include_weights: bool = True) -> sparse.csr_matri
142144

143145
csr = sparse.csr_matrix(
144146
(
145-
values.astype(np.float32),
147+
values.astype(dtype),
146148
(
147149
self.df[Columns.User].values,
148150
self.df[Columns.Item].values,

rectools/models/base.py

Lines changed: 139 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919

2020
import numpy as np
2121
import pandas as pd
22+
import typing_extensions as tpe
23+
from pydantic import PlainSerializer
24+
from pydantic_core import PydanticSerializationError
2225

2326
from rectools import Columns, ExternalIds, InternalIds
2427
from rectools.dataset import Dataset
2528
from rectools.dataset.identifiers import IdMap
2629
from rectools.exceptions import NotFittedError
2730
from rectools.types import ExternalIdsArray, InternalIdsArray
31+
from rectools.utils.config import BaseConfig
32+
from rectools.utils.misc import make_dict_flat
2833

2934
T = tp.TypeVar("T", bound="ModelBase")
3035
ScoresArray = np.ndarray
@@ -38,7 +43,30 @@
3843
RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet)
3944

4045

41-
class ModelBase:
46+
def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]:
47+
if rs is None or isinstance(rs, int):
48+
return rs
49+
50+
# NOBUG: We can add serialization using get/set_state, but it's not human readable
51+
raise TypeError("`random_state` must be ``None`` or have ``int`` type to convert it to simple type")
52+
53+
54+
RandomState = tpe.Annotated[
55+
tp.Union[None, int, np.random.RandomState],
56+
PlainSerializer(func=_serialize_random_state, when_used="json"),
57+
]
58+
59+
60+
class ModelConfig(BaseConfig):
61+
"""Base model config."""
62+
63+
verbose: int = 0
64+
65+
66+
ModelConfig_T = tp.TypeVar("ModelConfig_T", bound=ModelConfig)
67+
68+
69+
class ModelBase(tp.Generic[ModelConfig_T]):
4270
"""
4371
Base model class.
4472
@@ -49,10 +77,120 @@ class ModelBase:
4977
recommends_for_warm: bool = False
5078
recommends_for_cold: bool = False
5179

80+
config_class: tp.Type[ModelConfig_T]
81+
5282
def __init__(self, *args: tp.Any, verbose: int = 0, **kwargs: tp.Any) -> None:
5383
self.is_fitted = False
5484
self.verbose = verbose
5585

86+
@tp.overload
87+
def get_config( # noqa: D102
88+
self, mode: tp.Literal["pydantic"], simple_types: bool = False
89+
) -> ModelConfig_T: # pragma: no cover
90+
...
91+
92+
@tp.overload
93+
def get_config( # noqa: D102
94+
self, mode: tp.Literal["dict"] = "dict", simple_types: bool = False
95+
) -> tp.Dict[str, tp.Any]: # pragma: no cover
96+
...
97+
98+
def get_config(
99+
self, mode: tp.Literal["pydantic", "dict"] = "dict", simple_types: bool = False
100+
) -> tp.Union[ModelConfig_T, tp.Dict[str, tp.Any]]:
101+
"""
102+
Return model config.
103+
104+
Parameters
105+
----------
106+
mode : {'pydantic', 'dict'}, default 'dict'
107+
Format of returning config.
108+
simple_types : bool, default False
109+
If True, return config with JSON serializable types.
110+
Only works for `mode='dict'`.
111+
112+
Returns
113+
-------
114+
Pydantic model or dict
115+
Model config.
116+
117+
Raises
118+
------
119+
ValueError
120+
If `mode` is not 'object' or 'dict', or if `simple_types` is ``True`` and format is not 'dict'.
121+
"""
122+
config = self._get_config()
123+
if mode == "pydantic":
124+
if simple_types:
125+
raise ValueError("`simple_types` is not compatible with `mode='pydantic'`")
126+
return config
127+
128+
pydantic_mode = "json" if simple_types else "python"
129+
try:
130+
config_dict = config.model_dump(mode=pydantic_mode)
131+
except PydanticSerializationError as e:
132+
if e.__cause__ is not None:
133+
raise e.__cause__
134+
raise e
135+
136+
if mode == "dict":
137+
return config_dict
138+
139+
raise ValueError(f"Unknown mode: {mode}")
140+
141+
def _get_config(self) -> ModelConfig_T:
142+
raise NotImplementedError(f"`get_config` method is not implemented for `{self.__class__.__name__}` model")
143+
144+
def get_params(self, simple_types: bool = False, sep: str = ".") -> tp.Dict[str, tp.Any]:
145+
"""
146+
Return model parameters.
147+
Same as `get_config` but returns flat dict.
148+
149+
Parameters
150+
----------
151+
simple_types : bool, default False
152+
If True, return config with JSON serializable types.
153+
sep : str, default "."
154+
Separator for nested keys.
155+
156+
Returns
157+
-------
158+
dict
159+
Model parameters.
160+
"""
161+
config_dict = self.get_config(mode="dict", simple_types=simple_types)
162+
config_flat = make_dict_flat(config_dict, sep=sep) # NOBUG: We're not handling lists for now
163+
return config_flat
164+
165+
@classmethod
166+
def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self:
167+
"""
168+
Create model from config.
169+
170+
Parameters
171+
----------
172+
config : dict or ModelConfig
173+
Model config.
174+
175+
Returns
176+
-------
177+
Model instance.
178+
"""
179+
try:
180+
config_cls = cls.config_class
181+
except AttributeError:
182+
raise NotImplementedError(f"`from_config` method is not implemented for `{cls.__name__}` model.") from None
183+
184+
if not isinstance(config, config_cls):
185+
config_obj = cls.config_class.model_validate(config)
186+
else:
187+
config_obj = config
188+
return cls._from_config(config_obj)
189+
190+
@classmethod
191+
def _from_config(cls, config: ModelConfig_T) -> tpe.Self:
192+
raise NotImplementedError()
193+
56194
def fit(self: T, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> T:
57195
"""
58196
Fit model.

rectools/models/ease.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,26 @@
1717
import typing as tp
1818

1919
import numpy as np
20+
import typing_extensions as tpe
2021
from scipy import sparse
2122

2223
from rectools import InternalIds
2324
from rectools.dataset import Dataset
25+
from rectools.models.base import ModelConfig
2426
from rectools.types import InternalIdsArray
2527

2628
from .base import ModelBase, Scores
2729
from .rank import Distance, ImplicitRanker
2830

2931

30-
class EASEModel(ModelBase):
32+
class EASEModelConfig(ModelConfig):
33+
"""Config for `EASE` model."""
34+
35+
regularization: float = 500.0
36+
num_threads: int = 1
37+
38+
39+
class EASEModel(ModelBase[EASEModelConfig]):
3140
"""
3241
Embarrassingly Shallow Autoencoders for Sparse Data model.
3342
@@ -51,17 +60,27 @@ class EASEModel(ModelBase):
5160
recommends_for_warm = False
5261
recommends_for_cold = False
5362

63+
config_class = EASEModelConfig
64+
5465
def __init__(
5566
self,
5667
regularization: float = 500.0,
5768
num_threads: int = 1,
5869
verbose: int = 0,
5970
):
71+
6072
super().__init__(verbose=verbose)
6173
self.weight: np.ndarray
6274
self.regularization = regularization
6375
self.num_threads = num_threads
6476

77+
def _get_config(self) -> EASEModelConfig:
78+
return EASEModelConfig(regularization=self.regularization, num_threads=self.num_threads, verbose=self.verbose)
79+
80+
@classmethod
81+
def _from_config(cls, config: EASEModelConfig) -> tpe.Self:
82+
return cls(regularization=config.regularization, num_threads=config.num_threads, verbose=config.verbose)
83+
6584
def _fit(self, dataset: Dataset) -> None: # type: ignore
6685
ui_csr = dataset.get_user_item_matrix(include_weights=True)
6786

0 commit comments

Comments
 (0)