|
| 1 | +"""Model prediction confidence threshold with an 'uncertain' classification band. |
| 2 | +
|
| 3 | +Wraps a fitted sklearn pipeline and maps raw probabilities to three labels: |
| 4 | + - 'scalper' — probability >= high_threshold |
| 5 | + - 'uncertain' — probability is between low_threshold and high_threshold |
| 6 | + - 'not_scalper' — probability < low_threshold |
| 7 | +""" |
| 8 | +from typing import Any, Dict |
| 9 | + |
| 10 | +# Default thresholds — tune these based on desired precision/recall trade-off |
| 11 | +DEFAULT_HIGH_THRESHOLD = 0.70 # above this -> scalper |
| 12 | +DEFAULT_LOW_THRESHOLD = 0.40 # below this -> not_scalper; between -> uncertain |
| 13 | + |
| 14 | + |
| 15 | +def classify_with_confidence( |
| 16 | + pipeline: Any, |
| 17 | + features: Any, |
| 18 | + high_threshold: float = DEFAULT_HIGH_THRESHOLD, |
| 19 | + low_threshold: float = DEFAULT_LOW_THRESHOLD, |
| 20 | +) -> Dict[str, Any]: |
| 21 | + """Run the pipeline on *features* and return a labelled prediction dict. |
| 22 | +
|
| 23 | + Parameters |
| 24 | + ---------- |
| 25 | + pipeline: |
| 26 | + A fitted sklearn Pipeline exposing ``predict_proba``. |
| 27 | + features: |
| 28 | + A 2-D array-like of shape (1, n_features). |
| 29 | + high_threshold: |
| 30 | + Probability at or above which the prediction is 'scalper'. |
| 31 | + low_threshold: |
| 32 | + Probability below which the prediction is 'not_scalper'. |
| 33 | +
|
| 34 | + Returns |
| 35 | + ------- |
| 36 | + dict with keys: |
| 37 | + - ``label`` : 'scalper' | 'uncertain' | 'not_scalper' |
| 38 | + - ``probability`` : float, the model's scalper-class probability |
| 39 | + """ |
| 40 | + probability = float(pipeline.predict_proba(features)[0, 1]) |
| 41 | + |
| 42 | + if probability >= high_threshold: |
| 43 | + label = "scalper" |
| 44 | + elif probability < low_threshold: |
| 45 | + label = "not_scalper" |
| 46 | + else: |
| 47 | + label = "uncertain" |
| 48 | + |
| 49 | + return {"label": label, "probability": probability} |
0 commit comments