@@ -119,13 +119,24 @@ def load_predictions(
119119 return json .load (handle )
120120
121121
122- def load_difficulty_map () -> dict [str , str ]:
123- """Map global index -> difficulty (easy/medium/hard) from the dataset."""
def load_dataset_maps() -> tuple[dict[str, str], dict[str, str], dict[str, str]]:
    """Return (difficulty, domain, category) maps keyed by global index.

    The dataset is loaded from ``<dataset_dir>/routerarena`` where
    ``<dataset_dir>`` is ``$ROUTERARENA_DATASET_DIR`` if set, otherwise
    ``REPO_ROOT / "dataset"``.

    Domain/category labels are stripped of the leading ``"<n> "`` index to match
    the website's ``categories``/``subcategories`` keys
    (e.g. ``"5 Science"`` -> ``"Science"``).

    Rows whose domain or category is missing (``None``) or blank after
    stripping are left out of the respective map, so a missing cell never
    surfaces as a literal ``"None"`` label in the generated output.
    Difficulty is always recorded, lower-cased.
    """
    # Imported locally so this function does not rely on a module-level
    # ``import re`` being present elsewhere in the file.
    import re

    from datasets import load_from_disk

    # Compiled once instead of being re-specified for every row and column.
    index_prefix = re.compile(r"^\d+\s+")

    def _clean_label(value: Any) -> str:
        """Drop the leading numeric index; return "" for a missing value."""
        if value is None:
            return ""
        return index_prefix.sub("", str(value)).strip()

    dataset_dir = os.environ.get("ROUTERARENA_DATASET_DIR", str(REPO_ROOT / "dataset"))
    ds = load_from_disk(str(Path(dataset_dir) / "routerarena"))

    difficulty: dict[str, str] = {}
    domain: dict[str, str] = {}
    category: dict[str, str] = {}
    for row in ds:
        gi = row["Global Index"]
        difficulty[gi] = str(row["Difficulty"]).lower()

        dom = _clean_label(row["Domain"])
        if dom:
            domain[gi] = dom

        cat = _clean_label(row["Category"])
        if cat:
            category[gi] = cat
    return difficulty, domain, category
129140
130141
131142def _regular (predictions : list [dict [str , Any ]]) -> list [dict [str , Any ]]:
@@ -161,19 +172,18 @@ def compute_headline_metrics(
161172 }
162173
163174
164- def compute_category_scores (
165- predictions : list [dict [str , Any ]],
166- flip_labels : list [ dict [str , Any ] ],
175+ def _difficulty_metrics (
176+ entries : list [dict [str , Any ]],
177+ flip_map : dict [str , int ],
167178 difficulty : dict [str , str ],
168179) -> dict [str , dict [str , Any ]]:
169- """Per-difficulty (easy/medium/hard/all) accuracy, cost, robustness."""
170- flip_map = {item ["global index" ]: item ["flip" ] for item in flip_labels }
180+ """{easy/medium/hard/all: {accuracy, cost, robustness}} over the given entries."""
171181 buckets : dict [str , dict [str , list [float ]]] = {
172182 b : {"acc" : [], "cost" : []} for b in (* DIFFICULTY_BUCKETS , "all" )
173183 }
174184 flip_buckets : dict [str , list [int ]] = {b : [] for b in (* DIFFICULTY_BUCKETS , "all" )}
175185
176- for p in _regular ( predictions ) :
186+ for p in entries :
177187 gi = p .get ("global index" )
178188 acc = _numeric (p .get ("accuracy" )) or 0.0
179189 cost = _numeric (p .get ("cost" )) or 0.0
@@ -182,25 +192,68 @@ def compute_category_scores(
182192 for t in targets :
183193 buckets [t ]["acc" ].append (acc )
184194 buckets [t ]["cost" ].append (cost )
185- if gi in flip_map :
195+ if isinstance ( gi , str ) and gi in flip_map :
186196 for t in targets :
187197 flip_buckets [t ].append (flip_map [gi ])
188198
189199 out : dict [str , dict [str , Any ]] = {}
190200 for b in (* DIFFICULTY_BUCKETS , "all" ):
191- accs = buckets [b ]["acc" ]
192- costs = buckets [b ]["cost" ]
193- flips = flip_buckets [b ]
201+ accs , costs , flips = buckets [b ]["acc" ], buckets [b ]["cost" ], flip_buckets [b ]
202+ # Empty buckets are reported as 0 (matches the website's convention).
194203 out [b ] = {
195- "accuracy" : round (sum (accs ) / len (accs ) * 100 , 1 ) if accs else None ,
196- "cost" : round (sum (costs ) / len (costs ), 4 ) if costs else None ,
197- "robustness" : round ((1 - sum (flips ) / len (flips )) * 100 , 1 )
198- if flips
199- else None ,
204+ "accuracy" : round (sum (accs ) / len (accs ) * 100 , 1 ) if accs else 0 ,
205+ "cost" : round (sum (costs ) / len (costs ), 4 ) if costs else 0 ,
206+ "robustness" : round ((1 - sum (flips ) / len (flips )) * 100 , 1 ) if flips else 0 ,
200207 }
201208 return out
202209
203210
def compute_category_scores(
    predictions: list[dict[str, Any]],
    flip_labels: list[dict[str, Any]],
    difficulty: dict[str, str],
    domain: dict[str, str],
    category: dict[str, str],
) -> dict[str, Any]:
    """Build one router's website entry: overall, per-domain and per-subcategory metrics.

    The result follows the website schema:
      {"metrics": {<difficulty>},
       "categories": {<domain>: {"metrics": {<difficulty>},
                                 "subcategories": {<category>: {"metrics": {<difficulty>}}}}}}

    Only the entries returned by ``_regular(predictions)`` are considered.
    An entry is assigned to a domain (and, within it, to a subcategory) only
    when its ``"global index"`` is a string with a non-empty label in the
    corresponding map.  Domains and subcategories are emitted in sorted order.
    """
    flip_map = {label["global index"]: label["flip"] for label in flip_labels}
    regular = _regular(predictions)

    def _label_for(entry: dict[str, Any], labels: dict[str, str]) -> Optional[str]:
        """Look up the entry's label, ignoring non-string global indices."""
        key = entry.get("global index")
        if not isinstance(key, str):
            return None
        return labels.get(key)

    # Single pass: collect each domain's entries and, nested below it, the
    # entries of every subcategory that occurs within that domain.
    domain_entries: dict[str, list[dict[str, Any]]] = {}
    subcat_entries: dict[str, dict[str, list[dict[str, Any]]]] = {}
    for entry in regular:
        dom_name = _label_for(entry, domain)
        if not dom_name:
            continue
        domain_entries.setdefault(dom_name, []).append(entry)
        per_cat = subcat_entries.setdefault(dom_name, {})
        cat_name = _label_for(entry, category)
        if cat_name:
            per_cat.setdefault(cat_name, []).append(entry)

    categories: dict[str, Any] = {}
    for dom_name in sorted(domain_entries):
        subcategories: dict[str, Any] = {}
        for cat_name in sorted(subcat_entries[dom_name]):
            members = subcat_entries[dom_name][cat_name]
            subcategories[cat_name] = {
                "metrics": _difficulty_metrics(members, flip_map, difficulty)
            }
        categories[dom_name] = {
            "metrics": _difficulty_metrics(
                domain_entries[dom_name], flip_map, difficulty
            ),
            "subcategories": subcategories,
        }

    overall = _difficulty_metrics(regular, flip_map, difficulty)
    return {"metrics": overall, "categories": categories}
255+
256+
204257def main (argv : Optional [list [str ]] = None ) -> int :
205258 parser = argparse .ArgumentParser (description = __doc__ )
206259 parser .add_argument (
@@ -233,7 +286,9 @@ def main(argv: Optional[list[str]] = None) -> int:
233286 json .loads (cs_path .read_text (encoding = "utf-8" )) if cs_path .exists () else {}
234287 )
235288
236- difficulty = {} if args .skip_category else load_difficulty_map ()
289+ difficulty , domain , category = (
290+ ({}, {}, {}) if args .skip_category else load_dataset_maps ()
291+ )
237292 updated , regenerated , missing = [], [], []
238293
239294 for meta in routers :
@@ -286,9 +341,9 @@ def main(argv: Optional[list[str]] = None) -> int:
286341 and cat_key
287342 and any (_numeric (p .get ("accuracy" )) is not None for p in _regular (preds ))
288343 ):
289- category_scores [cat_key ] = {
290- "metrics" : compute_category_scores ( preds , flips , difficulty )
291- }
344+ category_scores [cat_key ] = compute_category_scores (
345+ preds , flips , difficulty , domain , category
346+ )
292347 regenerated .append (prediction )
293348
294349 leaderboard .sort (
0 commit comments