|
14 | 14 |
|
15 | 15 | """Base model.""" |
16 | 16 |
|
| 17 | +import pickle |
17 | 18 | import typing as tp |
18 | 19 | import warnings |
| 20 | +from pathlib import Path |
19 | 21 |
|
20 | 22 | import numpy as np |
21 | 23 | import pandas as pd |
|
42 | 44 |
|
43 | 45 | RecoTriplet_T = tp.TypeVar("RecoTriplet_T", InternalRecoTriplet, SemiInternalRecoTriplet, ExternalRecoTriplet) |
44 | 46 |
|
| 47 | +FileLike = tp.Union[str, Path, tp.IO[bytes]] |
| 48 | + |
| 49 | +PICKLE_PROTOCOL = 5 |
| 50 | + |
45 | 51 |
|
46 | 52 | def _serialize_random_state(rs: tp.Optional[tp.Union[None, int, np.random.RandomState]]) -> tp.Union[None, int]: |
47 | 53 | if rs is None or isinstance(rs, int): |
@@ -191,6 +197,85 @@ def from_config(cls, config: tp.Union[dict, ModelConfig_T]) -> tpe.Self: |
191 | 197 | def _from_config(cls, config: ModelConfig_T) -> tpe.Self: |
192 | 198 | raise NotImplementedError() |
193 | 199 |
|
| 200 | + def save(self, f: FileLike) -> int: |
| 201 | + """ |
| 202 | + Save model to file. |
| 203 | +
|
| 204 | + Parameters |
| 205 | + ---------- |
| 206 | + f : str or Path or file-like object |
| 207 | + Path to file or file-like object. |
| 208 | +
|
| 209 | + Returns |
| 210 | + ------- |
| 211 | + int |
| 212 | + Number of bytes written. |
| 213 | + """ |
| 214 | + data = self.dumps() |
| 215 | + |
| 216 | + if isinstance(f, (str, Path)): |
| 217 | + return Path(f).write_bytes(data) |
| 218 | + |
| 219 | + return f.write(data) |
| 220 | + |
| 221 | + def dumps(self) -> bytes: |
| 222 | + """ |
| 223 | + Serialize model to bytes. |
| 224 | +
|
| 225 | + Returns |
| 226 | + ------- |
| 227 | + bytes |
| 228 | + Serialized model. |
| 229 | + """ |
| 230 | + return pickle.dumps(self, protocol=PICKLE_PROTOCOL) |
| 231 | + |
| 232 | + @classmethod |
| 233 | + def load(cls, f: FileLike) -> tpe.Self: |
| 234 | + """ |
| 235 | + Load model from file. |
| 236 | +
|
| 237 | + Parameters |
| 238 | + ---------- |
| 239 | + f : str or Path or file-like object |
| 240 | + Path to file or file-like object. |
| 241 | +
|
| 242 | + Returns |
| 243 | + ------- |
| 244 | + model |
| 245 | + Model instance. |
| 246 | + """ |
| 247 | + if isinstance(f, (str, Path)): |
| 248 | + data = Path(f).read_bytes() |
| 249 | + else: |
| 250 | + data = f.read() |
| 251 | + |
| 252 | + return cls.loads(data) |
| 253 | + |
| 254 | + @classmethod |
| 255 | + def loads(cls, data: bytes) -> tpe.Self: |
| 256 | + """ |
| 257 | + Load model from bytes. |
| 258 | +
|
| 259 | + Parameters |
| 260 | + ---------- |
| 261 | + data : bytes |
| 262 | + Serialized model. |
| 263 | +
|
| 264 | + Returns |
| 265 | + ------- |
| 266 | + model |
| 267 | + Model instance. |
| 268 | +
|
| 269 | + Raises |
| 270 | + ------ |
| 271 | + TypeError |
| 272 | + If loaded object is not a direct instance of model class. |
| 273 | + """ |
| 274 | + loaded = pickle.loads(data) |
| 275 | + if loaded.__class__ is not cls: |
| 276 | + raise TypeError(f"Loaded object is not a direct instance of `{cls.__name__}`") |
| 277 | + return loaded |
| 278 | + |
194 | 279 | def fit(self: T, dataset: Dataset, *args: tp.Any, **kwargs: tp.Any) -> T: |
195 | 280 | """ |
196 | 281 | Fit model. |
|
0 commit comments