-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patheval_sample.py
More file actions
138 lines (116 loc) · 5.7 KB
/
eval_sample.py
File metadata and controls
138 lines (116 loc) · 5.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
from __future__ import annotations
import argparse
import sys
from pathlib import Path
import pandas as pd
from csv_io import TicketCsvError, canonicalize_ticket_columns, read_tickets_csv, rename_prediction_columns
from eval_metrics import compact_overlap_ratio, normalize_text, token_set_f1
def _norm_status(x: object) -> str:
if x is None or (isinstance(x, float) and pd.isna(x)):
return ""
s = str(x).strip().lower()
if s in {"replied", "reply"}:
return "replied"
if s in {"escalated", "escalate"}:
return "escalated"
return s
def main() -> None:
    """Score a predictions CSV against the labeled sample CSV and write a report.

    Reads ``--sample`` and ``--pred``, inner-joins them on
    Issue+Subject+Company, then prints:

    * exact-match accuracy for status, request type and product area;
    * normalized-exact, token-F1 and compact-overlap scores for the
      response, plus normalized-exact and token-F1 for the justification
      when both the gold and the predicted justification columns exist;
    * with ``--routing-detail``, a per-row gold-vs-predicted routing listing.

    Finally the side-by-side gold/predicted columns are written to
    ``--report`` (parent directories are created as needed).

    Exits with status 2 when loading or canonicalizing an input raises
    ``FileNotFoundError`` or ``TicketCsvError``, or when no rows match.
    """
    tickets_dir = Path("..") / "support_tickets"

    ap = argparse.ArgumentParser(
        description="Compare model predictions to labeled sample_support_tickets.csv (routing + response fuzzy metrics).",
    )
    ap.add_argument("--sample", type=str, default=str(tickets_dir / "sample_support_tickets.csv"))
    ap.add_argument("--pred", type=str, default=str(tickets_dir / "output.csv"))
    ap.add_argument("--report", type=str, default=str(tickets_dir / "sample_eval_report.csv"))
    ap.add_argument(
        "--routing-detail",
        action="store_true",
        help="Print per-row gold vs predicted routing (status, request_type, product_area).",
    )
    args = ap.parse_args()

    try:
        sample = read_tickets_csv(args.sample, label="--sample")
        pred = read_tickets_csv(args.pred, label="--pred")
        sample = canonicalize_ticket_columns(sample)
        pred = canonicalize_ticket_columns(pred)
    except (FileNotFoundError, TicketCsvError) as e:
        print(f"error: {e}", file=sys.stderr)
        sys.exit(2)

    pred = rename_prediction_columns(pred)

    key_cols = ["Issue", "Subject", "Company"]
    merged = sample.merge(pred, on=key_cols, how="inner").copy()
    print(f"sample rows: {len(sample)}")
    print(f"pred rows: {len(pred)}")
    print(f"matched: {len(merged)} (exact match on Issue+Subject+Company)")
    if len(merged) == 0:
        print(
            "error: no rows matched on Issue+Subject+Company between --sample and --pred. "
            "Regenerate --pred from the same inputs or fix CSV keys.",
            file=sys.stderr,
        )
        sys.exit(2)

    # Plain column assignment replaces the column outright, so the normalized
    # strings are stored whatever dtype the column had before.
    merged["Status"] = merged["Status"].map(_norm_status)
    merged["Pred_Status"] = merged["Pred_Status"].map(_norm_status)

    def exact_acc(gold: str, pred_col: str) -> float:
        """Fraction of matched rows whose gold and predicted cells are equal as strings."""
        g = merged[gold].fillna("").astype(str)
        p = merged[pred_col].fillna("").astype(str)
        return float((g == p).mean())

    def mean(scores: list) -> float:
        """Arithmetic mean of ``scores``; an empty list yields 0 instead of dividing by zero."""
        return sum(scores) / max(1, len(scores))

    print("\nExact match accuracy (on matched rows):")
    print(f"- status: {exact_acc('Status', 'Pred_Status'):.2%}")
    print(f"- request_type: {exact_acc('Request Type', 'Pred_Request Type'):.2%}")
    print(f"- product_area: {exact_acc('Product Area', 'Pred_Product Area'):.2%}")

    print("\nAnswer columns (same rows; normalized exact + fuzzy):")
    # Missing (NaN) cells are blanked once, up front, so the exact and the
    # fuzzy metrics see the same text; otherwise str() would turn a missing
    # answer into the literal word "nan" for the fuzzy scorers.
    gold_resp = merged["Response"].fillna("")
    pred_resp = merged["Pred_Response"].fillna("")

    # Justification metrics need BOTH columns; one flag keeps every
    # justification step below in agreement about whether they are present.
    has_justification = "Justification" in merged.columns and "Pred_Justification" in merged.columns
    if has_justification:
        gold_just = merged["Justification"].fillna("")
        pred_just = merged["Pred_Justification"].fillna("")

    ge = gold_resp.map(normalize_text)
    pe = pred_resp.map(normalize_text)
    print(f"- response (norm exact): {float((ge == pe).mean()):.2%}")
    if has_justification:
        je = gold_just.map(normalize_text)
        pje = pred_just.map(normalize_text)
        print(f"- justification (norm exact): {float((je == pje).mean()):.2%}")

    f1_r = [token_set_f1(str(a), str(b)) for a, b in zip(gold_resp, pred_resp)]
    print(f"- response (token F1 mean): {mean(f1_r):.3f}")
    if has_justification:
        f1_j = [token_set_f1(str(a), str(b)) for a, b in zip(gold_just, pred_just)]
        print(f"- justification (token F1 mean): {mean(f1_j):.3f}")

    ovl = [compact_overlap_ratio(str(a), str(b)) for a, b in zip(gold_resp, pred_resp)]
    print(f"- response (compact char overlap mean): {mean(ovl):.3f}")

    mismatch_count = int((merged["Status"] != merged["Pred_Status"]).sum())
    print(f"\nStatus mismatches: {mismatch_count}")

    if args.routing_detail:
        print("\n=== Per-row routing (gold vs pred) ===")
        for i, r in merged.iterrows():
            subj = str(r.get("Subject", ""))[:60]
            g_st = str(r.get("Status", "")).strip()
            p_st = str(r.get("Pred_Status", "")).strip()
            g_rt = str(r.get("Request Type", "")).strip()
            p_rt = str(r.get("Pred_Request Type", "")).strip()
            g_pa = str(r.get("Product Area", "")).strip()
            p_pa = str(r.get("Pred_Product Area", "")).strip()
            # A row is OK only when all three routing fields agree; request
            # type and product area are compared case-insensitively here.
            ok = (
                _norm_status(g_st) == _norm_status(p_st)
                and g_rt.lower() == p_rt.lower()
                and g_pa.lower() == p_pa.lower()
            )
            mark = "OK" if ok else "MISMATCH"
            print(f"[{mark}] row={i} subject={subj!r}…")
            print(f" status: gold={g_st!r} pred={p_st!r}")
            print(f" request_type: gold={g_rt!r} pred={p_rt!r}")
            print(f" product_area: gold={g_pa!r} pred={p_pa!r}")

    report_cols = key_cols + [
        "Status",
        "Pred_Status",
        "Request Type",
        "Pred_Request Type",
        "Product Area",
        "Pred_Product Area",
        "Response",
        "Pred_Response",
    ]
    if "Justification" in merged.columns:
        report_cols.append("Justification")
        # Only request the predicted column when it exists, so a prediction
        # file without justifications still produces a report.
        if "Pred_Justification" in merged.columns:
            report_cols.append("Pred_Justification")

    Path(args.report).parent.mkdir(parents=True, exist_ok=True)
    merged[report_cols].to_csv(args.report, index=False)
    print(f"Wrote report: {args.report}")
# Run the evaluation only when this file is executed as a script.
if __name__ == "__main__":
    main()