Skip to content

Commit 7d53178

Browse files
authored
Merge pull request #187 from Leothosine/fix/issue-144-prediction-confidence-threshold
feat: add model prediction confidence threshold and uncertain classification band
2 parents e9f6be9 + 109bd4f commit 7d53178

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

src/prediction_confidence.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
"""Model prediction confidence threshold with an 'uncertain' classification band.
2+
3+
Wraps a fitted sklearn pipeline and maps raw probabilities to three labels:
4+
- 'scalper' — probability >= high_threshold
5+
- 'uncertain' — low_threshold <= probability < high_threshold
6+
- 'not_scalper' — probability < low_threshold
7+
"""
8+
from typing import Any, Dict
9+
10+
# Default thresholds — tune these based on desired precision/recall trade-off.
DEFAULT_HIGH_THRESHOLD = 0.70  # at or above this -> scalper
DEFAULT_LOW_THRESHOLD = 0.40  # below this -> not_scalper; in between -> uncertain


def classify_with_confidence(
    pipeline: Any,
    features: Any,
    high_threshold: float = DEFAULT_HIGH_THRESHOLD,
    low_threshold: float = DEFAULT_LOW_THRESHOLD,
) -> Dict[str, Any]:
    """Run the pipeline on *features* and return a labelled prediction dict.

    The scalper-class probability is mapped onto three bands:

    - ``probability >= high_threshold``                  -> 'scalper'
    - ``low_threshold <= probability < high_threshold``  -> 'uncertain'
    - ``probability < low_threshold``                    -> 'not_scalper'

    Parameters
    ----------
    pipeline:
        A fitted sklearn Pipeline exposing ``predict_proba``.
    features:
        A 2-D array-like of shape (1, n_features). Only the first row of the
        ``predict_proba`` output is read.
    high_threshold:
        Probability at or above which the prediction is 'scalper'.
    low_threshold:
        Probability below which the prediction is 'not_scalper'. Must not
        exceed ``high_threshold``.

    Returns
    -------
    dict with keys:
        - ``label``       : 'scalper' | 'uncertain' | 'not_scalper'
        - ``probability`` : float, the model's scalper-class probability
          (column 1 of the ``predict_proba`` output)

    Raises
    ------
    ValueError
        If ``low_threshold`` is greater than ``high_threshold`` (or either
        threshold is NaN), since the bands would then be ill-defined.
    """
    # Written as `not (low <= high)` rather than `low > high` so that NaN
    # thresholds, for which every comparison is False, are rejected as well.
    if not low_threshold <= high_threshold:
        raise ValueError(
            "low_threshold must be <= high_threshold "
            f"(got low_threshold={low_threshold!r}, "
            f"high_threshold={high_threshold!r})"
        )

    # Row 0 is the single sample; column 1 is taken as the scalper class.
    probability = float(pipeline.predict_proba(features)[0, 1])

    if probability >= high_threshold:
        label = "scalper"
    elif probability < low_threshold:
        label = "not_scalper"
    else:
        # Also reached when the probability is NaN: both comparisons above
        # are False for NaN.
        label = "uncertain"

    return {"label": label, "probability": probability}

0 commit comments

Comments
 (0)