@@ -100,17 +100,23 @@ def compute_arena_score(cost, accuracy, beta=0.1, c_max=200, c_min=0.0044):
100100 return S
101101
102102
103- def load_predictions_file (router_name : str ) -> List [Dict [str , Any ]]:
103+ def load_predictions_file (router_name : str , split : str | None = None ) -> List [Dict [str , Any ]]:
104104 """
105105 Load router predictions from JSON file.
106106
107107 Args:
108108 router_name: Name of the router
109+ split: Dataset split (optional). "gpqa" and "robustness" select "{router_name}-{split}.json"; any other value (or None) selects "{router_name}.json".
109110
110111 Returns:
111112 List of prediction dictionaries
112113 """
113- prediction_path = f"./router_inference/predictions/{ router_name } .json"
114+ # Construct prediction path based on split (same logic as llm_inference/run.py)
115+ if split and split in ["gpqa" , "robustness" ]:
116+ filename = f"{ router_name } -{ split } "
117+ else :
118+ filename = router_name
119+ prediction_path = f"./router_inference/predictions/{ filename } .json"
114120
115121 if not os .path .exists (prediction_path ):
116122 raise FileNotFoundError (
@@ -136,15 +142,21 @@ def load_predictions_from_path(path: str) -> List[Dict[str, Any]]:
136142 return json .load (f )
137143
138144
139- def save_predictions_file (predictions : List [Dict [str , Any ]], router_name : str ) -> None :
145+ def save_predictions_file (predictions : List [Dict [str , Any ]], router_name : str , split : str | None = None ) -> None :
140146 """
141147 Save predictions back to file.
142148
143149 Args:
144150 predictions: List of prediction dictionaries
145151 router_name: Name of the router
152+ split: Dataset split (optional). "gpqa" and "robustness" write to "{router_name}-{split}.json"; any other value (or None) writes to "{router_name}.json".
146153 """
147- prediction_path = f"./router_inference/predictions/{ router_name } .json"
154+ # Construct filename based on split (same logic as load_predictions_file)
155+ if split and split in ["gpqa" , "robustness" ]:
156+ filename = f"{ router_name } -{ split } "
157+ else :
158+ filename = router_name
159+ prediction_path = f"./router_inference/predictions/{ filename } .json"
148160
149161 # Create directory if it doesn't exist
150162 os .makedirs (os .path .dirname (prediction_path ), exist_ok = True )
@@ -170,7 +182,33 @@ def load_ground_truth_dataset(split: str) -> Dict[str, Dict[str, Any]]:
170182 """
171183 from datasets import load_from_disk
172184 import pandas as pd
173-
185+ ground_truth_map = {}
186+
187+ # Handle GPQA split
188+ if split == "gpqa" :
189+ gpqa_gt_path = "./dataset/gpqa_ground_truth.json"
190+ if not os .path .exists (gpqa_gt_path ):
191+ raise FileNotFoundError (
192+ f"GPQA ground truth not found at { gpqa_gt_path } . "
193+ f"Please create it using the preparation script."
194+ )
195+ logger .info (f"Loading GPQA ground truth from { gpqa_gt_path } ..." )
196+ with open (gpqa_gt_path , "r" , encoding = "utf-8" ) as f :
197+ gpqa_data = json .load (f )
198+
199+ for item in gpqa_data :
200+ global_index = item ["global_index" ]
201+ ground_truth_map [global_index ] = {
202+ "question" : item .get ("question" , "" ),
203+ "global_index" : global_index ,
204+ "context" : item .get ("context" , "" ),
205+ "answer" : item ["answer" ],
206+ "options" : item .get ("options" , []),
207+ "metadata" : item .get ("metadata" , {}),
208+ }
209+
210+ logger .info (f"Loaded { len (ground_truth_map )} GPQA ground truth samples" )
211+ return ground_truth_map
174212 if split not in ["sub_10" , "full" ]:
175213 raise ValueError (f"Invalid split: { split } . Must be 'sub_10' or 'full'" )
176214
@@ -354,9 +392,10 @@ def evaluate_single_prediction(
354392 )
355393
356394 # Calculate inference cost
395+ # Use original model name for cost lookup since cost config uses original names
357396 token_usage = generated_result .get ("token_usage" , {})
358397 inference_cost = evaluator .calculate_inference_cost (
359- universal_model_name , token_usage
398+ model_name , token_usage # Use original model_name instead of universal_model_name
360399 )
361400
362401 # Update the prediction with evaluation results
@@ -396,7 +435,7 @@ def process_router_predictions(
396435 logger .info (f"Using { num_workers } worker threads for parallel processing" )
397436
398437 # Load predictions
399- predictions = load_predictions_file (router_name )
438+ predictions = load_predictions_file (router_name , split = split )
400439
401440 # Separate regular and optimality entries
402441 regular_predictions = [p for p in predictions if not p .get ("for_optimality" , False )]
@@ -439,11 +478,13 @@ def process_router_predictions(
439478 # Note: This loop runs in the main thread before threading starts, so no lock needed
440479 tasks = []
441480 for i , prediction in enumerate (predictions ):
442- # Check if already evaluated (has accuracy and cost)
481+ # Check if already evaluated (has accuracy and cost > 0 )
443482 # Skip if already evaluated AND force is False
483+ # Note: cost > 0 check ensures costs were actually calculated (0.0 means not calculated)
444484 if not force and (
445485 prediction .get ("accuracy" ) is not None
446486 and prediction .get ("cost" ) is not None
487+ and prediction .get ("cost" , 0 ) > 0 # Cost must be > 0 to be considered evaluated
447488 ):
448489 already_evaluated_count += 1
449490 evaluated_count += 1
@@ -494,7 +535,7 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
494535 with save_lock :
495536 # Save the entire predictions list
496537 # This is safe because each thread modifies a different index
497- save_predictions_file (predictions , router_name )
538+ save_predictions_file (predictions , router_name , split = split )
498539
499540 elapsed_time = (
500541 datetime .datetime .now () - start_time
@@ -542,7 +583,7 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
542583
543584 # Final save
544585 with save_lock :
545- save_predictions_file (predictions , router_name )
586+ save_predictions_file (predictions , router_name , split = split )
546587
547588 # Final summary
548589 end_time = datetime .datetime .now ()
@@ -901,7 +942,7 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
901942 target_path ,
902943 )
903944
904- predictions = load_predictions_file (router_name )
945+ predictions = load_predictions_file (router_name , split = None ) # Load base file for robustness
905946
906947 try :
907948 robustness_predictions = load_predictions_from_path (target_path )
@@ -1096,10 +1137,10 @@ def main():
10961137 "split" ,
10971138 nargs = "?" ,
10981139 type = str ,
1099- choices = ["sub_10" , "full" , "robustness" ],
1140+ choices = ["sub_10" , "full" , "robustness" , "gpqa" ],
11001141 help = (
11011142 "Dataset split to use for evaluation ('sub_10' for testing with answers, "
1102- "'full' for submission, 'robustness' to compute robustness score only)."
1143+ "'full' for submission, 'robustness' to compute robustness score only, 'gpqa' for GPQA dataset )."
11031144 ),
11041145 )
11051146 parser .add_argument (
@@ -1161,7 +1202,7 @@ def main():
11611202 # Run evaluation
11621203 try :
11631204 # If save_interval is 0, only save at the end
1164- predictions = load_predictions_file (args .router_name )
1205+ predictions = load_predictions_file (args .router_name , split = args . split )
11651206 save_interval = (
11661207 args .save_interval if args .save_interval > 0 else len (predictions ) + 1
11671208 )
@@ -1177,8 +1218,8 @@ def main():
11771218 logger .info ("\n Interrupted by user. Saving partial results..." )
11781219 try :
11791220 # Try to save current state if possible
1180- predictions = load_predictions_file (args .router_name )
1181- save_predictions_file (predictions , args .router_name )
1221+ predictions = load_predictions_file (args .router_name , split = args . split )
1222+ save_predictions_file (predictions , args .router_name , split = args . split )
11821223 logger .info ("Partial results saved successfully." )
11831224 except Exception as e :
11841225 logger .warning (f"Could not save partial results: { e } " )
0 commit comments