1919
2020import numpy as np
2121import pandas as pd
22+ import typing_extensions as tpe
23+ from pydantic import PlainSerializer
24+ from pydantic_core import PydanticSerializationError
2225
2326from rectools import Columns , ExternalIds , InternalIds
2427from rectools .dataset import Dataset
2528from rectools .dataset .identifiers import IdMap
2629from rectools .exceptions import NotFittedError
2730from rectools .types import ExternalIdsArray , InternalIdsArray
31+ from rectools .utils .config import BaseConfig
32+ from rectools .utils .misc import make_dict_flat
2833
2934T = tp .TypeVar ("T" , bound = "ModelBase" )
3035ScoresArray = np .ndarray
3843RecoTriplet_T = tp .TypeVar ("RecoTriplet_T" , InternalRecoTriplet , SemiInternalRecoTriplet , ExternalRecoTriplet )
3944
4045
41- class ModelBase :
46+ def _serialize_random_state (rs : tp .Optional [tp .Union [None , int , np .random .RandomState ]]) -> tp .Union [None , int ]:
47+ if rs is None or isinstance (rs , int ):
48+ return rs
49+
50+ # NOBUG: We can add serialization using get/set_state, but it's not human readable
51+ raise TypeError ("`random_state` must be ``None`` or have ``int`` type to convert it to simple type" )
52+
53+
54+ RandomState = tpe .Annotated [
55+ tp .Union [None , int , np .random .RandomState ],
56+ PlainSerializer (func = _serialize_random_state , when_used = "json" ),
57+ ]
58+
59+
60+ class ModelConfig (BaseConfig ):
61+ """Base model config."""
62+
63+ verbose : int = 0
64+
65+
66+ ModelConfig_T = tp .TypeVar ("ModelConfig_T" , bound = ModelConfig )
67+
68+
69+ class ModelBase (tp .Generic [ModelConfig_T ]):
4270 """
4371 Base model class.
4472
@@ -49,10 +77,120 @@ class ModelBase:
4977 recommends_for_warm : bool = False
5078 recommends_for_cold : bool = False
5179
80+ config_class : tp .Type [ModelConfig_T ]
81+
5282 def __init__ (self , * args : tp .Any , verbose : int = 0 , ** kwargs : tp .Any ) -> None :
5383 self .is_fitted = False
5484 self .verbose = verbose
5585
86+ @tp .overload
87+ def get_config ( # noqa: D102
88+ self , mode : tp .Literal ["pydantic" ], simple_types : bool = False
89+ ) -> ModelConfig_T : # pragma: no cover
90+ ...
91+
92+ @tp .overload
93+ def get_config ( # noqa: D102
94+ self , mode : tp .Literal ["dict" ] = "dict" , simple_types : bool = False
95+ ) -> tp .Dict [str , tp .Any ]: # pragma: no cover
96+ ...
97+
98+ def get_config (
99+ self , mode : tp .Literal ["pydantic" , "dict" ] = "dict" , simple_types : bool = False
100+ ) -> tp .Union [ModelConfig_T , tp .Dict [str , tp .Any ]]:
101+ """
102+ Return model config.
103+
104+ Parameters
105+ ----------
106+ mode : {'pydantic', 'dict'}, default 'dict'
107+ Format of returning config.
108+ simple_types : bool, default False
109+ If True, return config with JSON serializable types.
110+ Only works for `mode='dict'`.
111+
112+ Returns
113+ -------
114+ Pydantic model or dict
115+ Model config.
116+
117+ Raises
118+ ------
119+ ValueError
120+ If `mode` is not 'object' or 'dict', or if `simple_types` is ``True`` and format is not 'dict'.
121+ """
122+ config = self ._get_config ()
123+ if mode == "pydantic" :
124+ if simple_types :
125+ raise ValueError ("`simple_types` is not compatible with `mode='pydantic'`" )
126+ return config
127+
128+ pydantic_mode = "json" if simple_types else "python"
129+ try :
130+ config_dict = config .model_dump (mode = pydantic_mode )
131+ except PydanticSerializationError as e :
132+ if e .__cause__ is not None :
133+ raise e .__cause__
134+ raise e
135+
136+ if mode == "dict" :
137+ return config_dict
138+
139+ raise ValueError (f"Unknown mode: { mode } " )
140+
141+ def _get_config (self ) -> ModelConfig_T :
142+ raise NotImplementedError (f"`get_config` method is not implemented for `{ self .__class__ .__name__ } ` model" )
143+
144+ def get_params (self , simple_types : bool = False , sep : str = "." ) -> tp .Dict [str , tp .Any ]:
145+ """
146+ Return model parameters.
147+ Same as `get_config` but returns flat dict.
148+
149+ Parameters
150+ ----------
151+ simple_types : bool, default False
152+ If True, return config with JSON serializable types.
153+ sep : str, default "."
154+ Separator for nested keys.
155+
156+ Returns
157+ -------
158+ dict
159+ Model parameters.
160+ """
161+ config_dict = self .get_config (mode = "dict" , simple_types = simple_types )
162+ config_flat = make_dict_flat (config_dict , sep = sep ) # NOBUG: We're not handling lists for now
163+ return config_flat
164+
165+ @classmethod
166+ def from_config (cls , config : tp .Union [dict , ModelConfig_T ]) -> tpe .Self :
167+ """
168+ Create model from config.
169+
170+ Parameters
171+ ----------
172+ config : dict or ModelConfig
173+ Model config.
174+
175+ Returns
176+ -------
177+ Model instance.
178+ """
179+ try :
180+ config_cls = cls .config_class
181+ except AttributeError :
182+ raise NotImplementedError (f"`from_config` method is not implemented for `{ cls .__name__ } ` model." ) from None
183+
184+ if not isinstance (config , config_cls ):
185+ config_obj = cls .config_class .model_validate (config )
186+ else :
187+ config_obj = config
188+ return cls ._from_config (config_obj )
189+
190+ @classmethod
191+ def _from_config (cls , config : ModelConfig_T ) -> tpe .Self :
192+ raise NotImplementedError ()
193+
56194 def fit (self : T , dataset : Dataset , * args : tp .Any , ** kwargs : tp .Any ) -> T :
57195 """
58196 Fit model.
0 commit comments