Skip to content

Commit 6ff1fb2

Browse files
authored
Merge pull request #183 from devwums/fix/issue-140-model-persistence
feat: persist scalper-detection model to disk and reload on startup
2 parents cb58213 + e5a1a59 commit 6ff1fb2

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

src/model_persistence.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,39 @@
1+
"""Utilities to persist and reload the scalper-detection model from disk."""
2+
import logging
3+
import os
4+
from typing import Any, Optional
5+
6+
# Default location of the serialized model; the SCALPER_MODEL_PATH environment
# variable overrides it, otherwise a relative "scalper_model.pkl" is used.
MODEL_PATH = os.environ.get("SCALPER_MODEL_PATH", "scalper_model.pkl")
7+
8+
# Module-level logger, registered under the "veritix" logger namespace.
logger = logging.getLogger("veritix.model_persistence")
9+
10+
11+
def save_model(pipeline: Any, path: str = MODEL_PATH) -> None:
    """Serialize *pipeline* to *path* using joblib.

    The model is first dumped to a temporary sibling file and then moved
    onto *path* with ``os.replace``. An interrupted or failed dump therefore
    never leaves a truncated file at *path*, which ``load_model`` would
    otherwise find and fail to read.

    Args:
        pipeline: The model object to persist.
        path: Destination file; defaults to ``MODEL_PATH``.

    Raises:
        Any exception raised by ``joblib.dump`` or ``os.replace``. The
        temporary file is removed (best effort) before the error propagates.
    """
    import joblib  # type: ignore[import-untyped]

    # Keep the temp file next to the target so the final rename stays on the
    # same filesystem; the PID suffix avoids clashes between processes.
    tmp_path = f"{path}.tmp.{os.getpid()}"
    try:
        joblib.dump(pipeline, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Best-effort cleanup; the original error is the one that matters.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
    logger.info("Scalper model saved to %s", path)
17+
18+
19+
def load_model(path: str = MODEL_PATH) -> Optional[Any]:
    """Return the model stored at *path*, or ``None`` when no file is there.

    Args:
        path: File to read the model from; defaults to ``MODEL_PATH``.

    Returns:
        The object deserialized by ``joblib.load``, or ``None`` if *path*
        does not exist.

    Note:
        joblib loading is pickle-based, so *path* should only ever point at
        a file written by a trusted source.
    """
    import joblib  # type: ignore[import-untyped]

    if os.path.exists(path):
        model = joblib.load(path)
        logger.info("Scalper model loaded from %s", path)
        return model

    logger.info("No saved model found at %s; will train from scratch", path)
    return None
29+
30+
31+
def get_or_train_model(path: str = MODEL_PATH) -> Any:
    """Fetch the model saved at *path*, training and saving one if none exists.

    Args:
        path: Location of the persisted model; defaults to ``MODEL_PATH``.

    Returns:
        The model returned by ``load_model`` when it is not ``None``;
        otherwise a new pipeline from
        ``src.utils.train_logistic_regression_pipeline``, which is written
        to *path* via ``save_model`` before being returned.
    """
    from src.utils import train_logistic_regression_pipeline

    persisted = load_model(path)
    if persisted is not None:
        return persisted

    # Nothing on disk yet: train once and persist for the next startup.
    trained = train_logistic_regression_pipeline()
    save_model(trained, path)
    return trained

0 commit comments

Comments
 (0)