Skip to content

Commit 53dea46

Browse files
author
Louie Lu
committed
Website data: emit full per-domain + subcategory category_scores
The website's category_scores has three levels (overall metrics, per-domain categories, per-domain subcategories) and reports empty buckets as 0. The first build only emitted the overall difficulty metrics, which would drop the per-domain/subcategory breakdown for regenerated routers. Now reproduces the live structure exactly: validated identical to the website for unchanged routers (sqwish/r2/auto/vllm); azure updates to its re-evaluated numbers; baselines preserved; new routers get the full structure.
1 parent cf7d3ef commit 53dea46

1 file changed

Lines changed: 77 additions & 22 deletions

File tree

scripts/website/build_site_data.py

Lines changed: 77 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,24 @@ def load_predictions(
119119
return json.load(handle)
120120

121121

122-
def load_difficulty_map() -> dict[str, str]:
123-
"""Map global index -> difficulty (easy/medium/hard) from the dataset."""
122+
def load_dataset_maps() -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
123+
"""Return (difficulty, domain, category) maps keyed by global index.
124+
125+
Domain/category keys are stripped of the leading ``"<n> "`` index to match
126+
the website's ``categories``/``subcategories`` keys
127+
(e.g. ``"5 Science"`` -> ``"Science"``).
128+
"""
124129
from datasets import load_from_disk
125130

126131
dataset_dir = os.environ.get("ROUTERARENA_DATASET_DIR", str(REPO_ROOT / "dataset"))
127132
ds = load_from_disk(str(Path(dataset_dir) / "routerarena"))
128-
return {row["Global Index"]: str(row["Difficulty"]).lower() for row in ds}
133+
difficulty, domain, category = {}, {}, {}
134+
for row in ds:
135+
gi = row["Global Index"]
136+
difficulty[gi] = str(row["Difficulty"]).lower()
137+
domain[gi] = re.sub(r"^\d+\s+", "", str(row["Domain"])).strip()
138+
category[gi] = re.sub(r"^\d+\s+", "", str(row["Category"])).strip()
139+
return difficulty, domain, category
129140

130141

131142
def _regular(predictions: list[dict[str, Any]]) -> list[dict[str, Any]]:
@@ -161,19 +172,18 @@ def compute_headline_metrics(
161172
}
162173

163174

164-
def compute_category_scores(
165-
predictions: list[dict[str, Any]],
166-
flip_labels: list[dict[str, Any]],
175+
def _difficulty_metrics(
176+
entries: list[dict[str, Any]],
177+
flip_map: dict[str, int],
167178
difficulty: dict[str, str],
168179
) -> dict[str, dict[str, Any]]:
169-
"""Per-difficulty (easy/medium/hard/all) accuracy, cost, robustness."""
170-
flip_map = {item["global index"]: item["flip"] for item in flip_labels}
180+
"""{easy/medium/hard/all: {accuracy, cost, robustness}} over the given entries."""
171181
buckets: dict[str, dict[str, list[float]]] = {
172182
b: {"acc": [], "cost": []} for b in (*DIFFICULTY_BUCKETS, "all")
173183
}
174184
flip_buckets: dict[str, list[int]] = {b: [] for b in (*DIFFICULTY_BUCKETS, "all")}
175185

176-
for p in _regular(predictions):
186+
for p in entries:
177187
gi = p.get("global index")
178188
acc = _numeric(p.get("accuracy")) or 0.0
179189
cost = _numeric(p.get("cost")) or 0.0
@@ -182,25 +192,68 @@ def compute_category_scores(
182192
for t in targets:
183193
buckets[t]["acc"].append(acc)
184194
buckets[t]["cost"].append(cost)
185-
if gi in flip_map:
195+
if isinstance(gi, str) and gi in flip_map:
186196
for t in targets:
187197
flip_buckets[t].append(flip_map[gi])
188198

189199
out: dict[str, dict[str, Any]] = {}
190200
for b in (*DIFFICULTY_BUCKETS, "all"):
191-
accs = buckets[b]["acc"]
192-
costs = buckets[b]["cost"]
193-
flips = flip_buckets[b]
201+
accs, costs, flips = buckets[b]["acc"], buckets[b]["cost"], flip_buckets[b]
202+
# Empty buckets are reported as 0 (matches the website's convention).
194203
out[b] = {
195-
"accuracy": round(sum(accs) / len(accs) * 100, 1) if accs else None,
196-
"cost": round(sum(costs) / len(costs), 4) if costs else None,
197-
"robustness": round((1 - sum(flips) / len(flips)) * 100, 1)
198-
if flips
199-
else None,
204+
"accuracy": round(sum(accs) / len(accs) * 100, 1) if accs else 0,
205+
"cost": round(sum(costs) / len(costs), 4) if costs else 0,
206+
"robustness": round((1 - sum(flips) / len(flips)) * 100, 1) if flips else 0,
200207
}
201208
return out
202209

203210

211+
def compute_category_scores(
212+
predictions: list[dict[str, Any]],
213+
flip_labels: list[dict[str, Any]],
214+
difficulty: dict[str, str],
215+
domain: dict[str, str],
216+
category: dict[str, str],
217+
) -> dict[str, Any]:
218+
"""Website per-router entry: overall + per-domain + per-subcategory metrics.
219+
220+
Mirrors the website schema:
221+
{"metrics": {<difficulty>},
222+
"categories": {<domain>: {"metrics": {<difficulty>},
223+
"subcategories": {<category>: {"metrics": {<difficulty>}}}}}}
224+
"""
225+
flip_map = {item["global index"]: item["flip"] for item in flip_labels}
226+
regular = _regular(predictions)
227+
228+
by_domain: dict[str, list[dict[str, Any]]] = {}
229+
for p in regular:
230+
gi = p.get("global index")
231+
dom = domain.get(gi) if isinstance(gi, str) else None
232+
if dom:
233+
by_domain.setdefault(dom, []).append(p)
234+
235+
categories: dict[str, Any] = {}
236+
for dom, dom_entries in sorted(by_domain.items()):
237+
by_cat: dict[str, list[dict[str, Any]]] = {}
238+
for p in dom_entries:
239+
gi = p.get("global index")
240+
cat = category.get(gi) if isinstance(gi, str) else None
241+
if cat:
242+
by_cat.setdefault(cat, []).append(p)
243+
categories[dom] = {
244+
"metrics": _difficulty_metrics(dom_entries, flip_map, difficulty),
245+
"subcategories": {
246+
cat: {"metrics": _difficulty_metrics(entries, flip_map, difficulty)}
247+
for cat, entries in sorted(by_cat.items())
248+
},
249+
}
250+
251+
return {
252+
"metrics": _difficulty_metrics(regular, flip_map, difficulty),
253+
"categories": categories,
254+
}
255+
256+
204257
def main(argv: Optional[list[str]] = None) -> int:
205258
parser = argparse.ArgumentParser(description=__doc__)
206259
parser.add_argument(
@@ -233,7 +286,9 @@ def main(argv: Optional[list[str]] = None) -> int:
233286
json.loads(cs_path.read_text(encoding="utf-8")) if cs_path.exists() else {}
234287
)
235288

236-
difficulty = {} if args.skip_category else load_difficulty_map()
289+
difficulty, domain, category = (
290+
({}, {}, {}) if args.skip_category else load_dataset_maps()
291+
)
237292
updated, regenerated, missing = [], [], []
238293

239294
for meta in routers:
@@ -286,9 +341,9 @@ def main(argv: Optional[list[str]] = None) -> int:
286341
and cat_key
287342
and any(_numeric(p.get("accuracy")) is not None for p in _regular(preds))
288343
):
289-
category_scores[cat_key] = {
290-
"metrics": compute_category_scores(preds, flips, difficulty)
291-
}
344+
category_scores[cat_key] = compute_category_scores(
345+
preds, flips, difficulty, domain, category
346+
)
292347
regenerated.append(prediction)
293348

294349
leaderboard.sort(

0 commit comments

Comments
 (0)