-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathresponse_quality_report.py
More file actions
126 lines (101 loc) · 4.18 KB
/
response_quality_report.py
File metadata and controls
126 lines (101 loc) · 4.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
"""Offline quality diagnostics for generated CSV outputs.
Computes cheap, interpretable metrics without needing hidden labels:
- response length (chars / words)
- escalation rate
- numeric-string leakage heuristic (digit sequences in the response that are not present in the retrieved evidence)
- lexical overlap between response tokens and retrieved chunk tokens (requires rebuilding retrieval hits)
This is meant for hackathon iteration: catch verbose outputs and grounding drift early.
"""
from __future__ import annotations
import argparse
import os
import re
import sys
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from config import CACHE_PATH, DATA_DIR, TOP_K
from csv_io import TicketCsvError, read_tickets_csv
from corpus import tokenize
from grounding import has_unsupported_numbers, lexical_overlap
from retrieve import BM25Index, rerank_hits
def _norm_company(val: object) -> str | None:
if val is None or (isinstance(val, float) and pd.isna(val)):
return None
s = str(val).strip()
if not s or s.lower() == "none":
return None
return s
def _brand_for_search(company: str | None, issue: str, subject: str, index: BM25Index) -> str:
if company:
m = company.strip().lower()
if m == "hackerrank":
return "hackerrank"
if m == "claude":
return "claude"
if m == "visa":
return "visa"
return index.infer_brand(f"{subject}\n{issue}")
@dataclass(frozen=True)
class RowMetrics:
    """Grounding diagnostics for a single prediction row.

    Frozen so that instances are immutable and hashable once computed.
    """

    # Lexical overlap between the response and the reranked retrieval hits;
    # the producer in this module uses 0.0 when retrieval returns no hits.
    overlap: float
    # True when the response appears to contain numbers not backed by the hits.
    numeric_leak: bool
def metrics_for_row(index: BM25Index, issue: str, subject: str, company_raw: object, response: str) -> RowMetrics:
    """Re-run retrieval for one ticket and score its generated response.

    Args:
        index: BM25 index used for the search.
        issue: Ticket body text.
        subject: Ticket subject line.
        company_raw: Raw ``company`` cell; normalised before brand selection.
        response: The generated response being scored.

    Returns:
        A ``RowMetrics``. When retrieval (after reranking) yields no hits, the
        overlap is 0.0 and no numeric leak is reported.
    """
    query = f"{subject}\n{issue}"
    brand = _brand_for_search(_norm_company(company_raw), issue, subject, index)
    candidates, _ = index.search(query, brand, TOP_K)
    candidates = rerank_hits(query, candidates)
    if not candidates:
        return RowMetrics(overlap=0.0, numeric_leak=False)
    return RowMetrics(
        overlap=lexical_overlap(response, candidates),
        numeric_leak=has_unsupported_numbers(response, candidates),
    )
def _cell_text(value: object) -> str:
    """Return a CSV cell as text, mapping missing values (None/NaN/pd.NA) to ``""``.

    ``str(value or "")`` is not sufficient: NaN is truthy, so an empty cell
    would become the literal string ``"nan"``, and ``pd.NA`` raises when
    truth-tested.
    """
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value)


def _percentile(values: list[float], q: float, default: float) -> float:
    """Lower nearest-rank percentile: sorted element at index ``int(q * (len - 1))``.

    Returns *default* when *values* is empty.
    """
    if not values:
        return default
    ordered = sorted(values)
    return ordered[int(q * (len(ordered) - 1))]


def main() -> None:
    """CLI entry point: print length, escalation and grounding diagnostics for a prediction CSV.

    The CSV must contain the columns ``issue``, ``subject``, ``company``,
    ``response`` and ``status``. Column names are matched case-insensitively
    and ignoring surrounding whitespace. The process exits with status 2 if
    the file cannot be read, and raises ``SystemExit`` with a message if
    required columns are missing.
    """
    ap = argparse.ArgumentParser(
        description="Diagnostics on prediction CSV (lengths, escalation rate, lexical overlap vs retrieval).",
    )
    ap.add_argument("--pred", type=str, default=str(Path("..") / "support_tickets" / "output.csv"))
    ap.add_argument("--offline", action="store_true", help="Force ORCHESTRATE_DISABLE_LLM=1 (informational; retrieval ignores it)")
    args = ap.parse_args()
    if args.offline:
        os.environ["ORCHESTRATE_DISABLE_LLM"] = "1"

    try:
        df = read_tickets_csv(args.pred, label="--pred")
    except (FileNotFoundError, TicketCsvError) as e:
        print(f"error: {e}", file=sys.stderr)
        sys.exit(2)

    # Map normalised header -> actual header so lookups tolerate case/whitespace.
    norm_cols = {str(c).strip().lower(): c for c in df.columns}
    required = {"issue", "subject", "company", "response", "status"}
    missing = required - set(norm_cols)
    if missing:
        raise SystemExit(f"CSV missing columns: {sorted(missing)}")
    issue_c = norm_cols["issue"]
    sub_c = norm_cols["subject"]
    comp_c = norm_cols["company"]
    resp_c = norm_cols["response"]
    stat_c = norm_cols["status"]

    index = BM25Index.load(CACHE_PATH, DATA_DIR)

    overlaps: list[float] = []
    lengths: list[int] = []
    leaks = 0
    esc = 0
    for _, row in df.iterrows():
        issue = _cell_text(row.get(issue_c))
        subject = _cell_text(row.get(sub_c))
        resp = _cell_text(row.get(resp_c))
        if _cell_text(row.get(stat_c)).strip().lower() == "escalated":
            esc += 1
        lengths.append(len(resp.split()))
        # The raw company cell is passed through; metrics_for_row normalises it.
        m = metrics_for_row(index, issue, subject, row.get(comp_c), resp)
        overlaps.append(m.overlap)
        if m.numeric_leak:
            leaks += 1

    n = max(1, len(df))  # guard the averages against an empty CSV
    print(f"rows: {len(df)}")
    print(f"escalated_rate: {esc/n:.2%}")
    print(f"avg_response_words: {sum(lengths)/n:.1f}")
    print(f"p95_response_words: {_percentile(lengths, 0.95, 0)}")
    print(f"avg_lexical_overlap: {sum(overlaps)/n:.3f}")
    # Previously this printed the minimum; use the same rank rule as p95.
    print(f"p05_lexical_overlap: {_percentile(overlaps, 0.05, 0.0):.3f}")
    print(f"numeric_leak_rows: {leaks} ({leaks/n:.2%})")


if __name__ == "__main__":
    main()