Skip to content

Commit abbe3e5

Browse files
committed
feat: Integrate GPQA dataset with proper answer extraction and option shuffling
Add support for GPQA (Graduate-Level Google-Proof Q&A) dataset integration into the RouterArena evaluation pipeline. Key changes: - Add GPQA evaluation config (config/eval_config/zero-shot/GPQA.json) - Create prepare_gpqa_data.py script to: * Load GPQA dataset from HuggingFace (Idavidrein/gpqa) * Extract correct answers from \"Correct Answer\" field (fixes empty answer bug) * Shuffle MCQ options deterministically to distribute answers across A-D * Generate formatted prompts and ground truth files - Update evaluation pipeline (llm_evaluation/run.py, evaluate_models.py): * Add GPQA split handling in load_ground_truth_dataset() * Support GPQA dataset name detection - Update inference pipeline (llm_inference/run.py, model_inference.py): * Add GPQA split support for router predictions * Handle GPQA-specific data loading - Add OpenRouter router configuration - Update model cost configurations for GPQA models - Add universal model name mappings
1 parent a94b93f commit abbe3e5

10 files changed

Lines changed: 426 additions & 49 deletions

File tree

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
{
2+
"eval_params": {
3+
"dataset": "GPQA",
4+
"eval_metrics": [
5+
"mcq_accuracy"
6+
],
7+
"setting": "zero-shot",
8+
"prompt": "Please read the following multiple-choice questions and provide the most likely correct answer based on the options given.\n\nContext: {Context}\n\nQuestion: {Question}\n\nOptions: \n{Options}\n\nProvide the correct letter choice in \\boxed{{X}}, where X is the correct letter choice. Keep the explanation or feedback within 3 sentences."
9+
},
10+
"management": {
11+
"sub_dir": {
12+
"input_config": "input_config/",
13+
"raw_results": "raw_results.json",
14+
"result_vis": "result_vis.png",
15+
"output_config": "output_config.json"
16+
}
17+
}
18+
}

llm_evaluation/evaluate_models.py

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,15 @@ def load_dataset_configs(self):
144144
def load_cost_config(self):
145145
"""Load cost configuration from model_cost/cost.json"""
146146
# Try multiple possible paths for cost file
147+
# Get the directory of this file and construct paths relative to project root
148+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
149+
project_root = os.path.dirname(current_file_dir) # Go up from llm_evaluation/ to project root
150+
147151
possible_paths = [
148-
"./model_cost/cost.json",
149-
"../model_cost/cost.json",
150-
"model_cost/cost.json",
152+
os.path.join(project_root, "model_cost", "cost.json"), # From project root
153+
"./model_cost/cost.json", # Current working directory
154+
"../model_cost/cost.json", # Parent directory
155+
"model_cost/cost.json", # Relative to current dir
151156
]
152157

153158
cost_file = None
@@ -160,6 +165,7 @@ def load_cost_config(self):
160165
print(
161166
f"Warning: Could not find cost configuration file. Tried: {possible_paths}"
162167
)
168+
print(f"Current working directory: {os.getcwd()}")
163169
self.cost_config = {}
164170
return
165171

@@ -177,35 +183,45 @@ def calculate_inference_cost(
177183
self, model_name: str, token_usage: Dict[str, int]
178184
) -> float:
179185
"""Calculate inference cost based on token usage and model pricing."""
180-
if not token_usage or not self.cost_config:
186+
if not token_usage:
187+
return 0.0
188+
189+
if not self.cost_config:
190+
print("Warning: Cost config is empty!")
181191
return 0.0
182192

183193
# Remove _batch suffix if present for cost lookup
184194
cost_lookup_name = model_name
185195
if model_name.endswith("_batch"):
186196
cost_lookup_name = model_name[:-6] # Remove '_batch' suffix
187197

188-
# Normalize model name to match cost config
189-
if model_name_manager:
190-
normalized_name = model_name_manager.get_universal_name(cost_lookup_name)
198+
# Try exact match first (cost config uses original model names)
199+
if cost_lookup_name in self.cost_config:
200+
cost_info = self.cost_config[cost_lookup_name]
191201
else:
192-
normalized_name = cost_lookup_name
202+
# Normalize model name to match cost config
203+
if model_name_manager:
204+
normalized_name = model_name_manager.get_universal_name(cost_lookup_name)
205+
else:
206+
normalized_name = cost_lookup_name
193207

194-
# Try to find exact match first
195-
if normalized_name in self.cost_config:
196-
cost_info = self.cost_config[normalized_name]
197-
else:
198-
# Try to find partial matches
199-
cost_info = None
200-
for config_name in self.cost_config.keys():
201-
if config_name in normalized_name or normalized_name in config_name:
202-
cost_info = self.cost_config[config_name]
203-
break
208+
# Try to find exact match with normalized name
209+
if normalized_name in self.cost_config:
210+
cost_info = self.cost_config[normalized_name]
211+
else:
212+
# Try to find partial matches
213+
cost_info = None
214+
for config_name in self.cost_config.keys():
215+
if config_name in normalized_name or normalized_name in config_name:
216+
cost_info = self.cost_config[config_name]
217+
break
204218

205219
if not cost_info:
206220
print(
207-
f"Warning: No cost configuration found for model {model_name} (lookup: {cost_lookup_name}, normalized: {normalized_name})"
221+
f"Warning: No cost configuration found for model {model_name} (lookup: {cost_lookup_name})"
208222
)
223+
if len(self.cost_config) > 0:
224+
print(f"Available cost config keys (first 10): {list(self.cost_config.keys())[:10]}")
209225
return 0.0
210226

211227
# Calculate cost
@@ -239,6 +255,7 @@ def determine_dataset_from_global_index(self, global_index: str) -> str:
239255
"FinQA": "FinQA",
240256
"GeoBench": "GeoBench",
241257
"GeoGraphyData": "GeoGraphyData_100k", # Fix the dataset name
258+
"GPQA": "GPQA",
242259
"GSM8K": "GSM8K",
243260
"LiveCodeBench": "LiveCodeBench",
244261
"MATH": "MATH",
@@ -468,7 +485,18 @@ def _get_ground_truth(self, global_index: str, dataset_name: str) -> Optional[An
468485
except Exception as e:
469486
print(f"Error loading LiveCodeBench dataset: {e}")
470487
return None
471-
488+
elif dataset_name == "GPQA":
489+
gpqa_gt_path = "./dataset/gpqa_ground_truth.json"
490+
if os.path.exists(gpqa_gt_path):
491+
try:
492+
with open(gpqa_gt_path, "r", encoding="utf-8") as f:
493+
gpqa_data = json.load(f)
494+
for item in gpqa_data:
495+
if item.get("global_index") == global_index:
496+
return item["answer"]
497+
except Exception as e:
498+
print(f"Error loading GPQA ground truth: {e}")
499+
return None
472500
# For other datasets, find the entry with matching global_index
473501
if self.all_data is None:
474502
return None

llm_evaluation/run.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,23 @@ def compute_arena_score(cost, accuracy, beta=0.1, c_max=200, c_min=0.0044):
100100
return S
101101

102102

103-
def load_predictions_file(router_name: str) -> List[Dict[str, Any]]:
103+
def load_predictions_file(router_name: str, split: str | None = None) -> List[Dict[str, Any]]:
104104
"""
105105
Load router predictions from JSON file.
106106
107107
Args:
108108
router_name: Name of the router
109+
split: Dataset split (optional). Used to determine prediction file name.
109110
110111
Returns:
111112
List of prediction dictionaries
112113
"""
113-
prediction_path = f"./router_inference/predictions/{router_name}.json"
114+
# Construct prediction path based on split (same logic as llm_inference/run.py)
115+
if split and split in ["gpqa", "robustness"]:
116+
filename = f"{router_name}-{split}"
117+
else:
118+
filename = router_name
119+
prediction_path = f"./router_inference/predictions/{filename}.json"
114120

115121
if not os.path.exists(prediction_path):
116122
raise FileNotFoundError(
@@ -136,15 +142,21 @@ def load_predictions_from_path(path: str) -> List[Dict[str, Any]]:
136142
return json.load(f)
137143

138144

139-
def save_predictions_file(predictions: List[Dict[str, Any]], router_name: str) -> None:
145+
def save_predictions_file(predictions: List[Dict[str, Any]], router_name: str, split: str | None = None) -> None:
140146
"""
141147
Save predictions back to file.
142148
143149
Args:
144150
predictions: List of prediction dictionaries
145151
router_name: Name of the router
152+
split: Dataset split (optional). Used to determine prediction file name.
146153
"""
147-
prediction_path = f"./router_inference/predictions/{router_name}.json"
154+
# Construct filename based on split (same logic as load_predictions_file)
155+
if split and split in ["gpqa", "robustness"]:
156+
filename = f"{router_name}-{split}"
157+
else:
158+
filename = router_name
159+
prediction_path = f"./router_inference/predictions/{filename}.json"
148160

149161
# Create directory if it doesn't exist
150162
os.makedirs(os.path.dirname(prediction_path), exist_ok=True)
@@ -170,7 +182,33 @@ def load_ground_truth_dataset(split: str) -> Dict[str, Dict[str, Any]]:
170182
"""
171183
from datasets import load_from_disk
172184
import pandas as pd
173-
185+
ground_truth_map = {}
186+
187+
# Handle GPQA split
188+
if split == "gpqa":
189+
gpqa_gt_path = "./dataset/gpqa_ground_truth.json"
190+
if not os.path.exists(gpqa_gt_path):
191+
raise FileNotFoundError(
192+
f"GPQA ground truth not found at {gpqa_gt_path}. "
193+
f"Please create it using the preparation script."
194+
)
195+
logger.info(f"Loading GPQA ground truth from {gpqa_gt_path}...")
196+
with open(gpqa_gt_path, "r", encoding="utf-8") as f:
197+
gpqa_data = json.load(f)
198+
199+
for item in gpqa_data:
200+
global_index = item["global_index"]
201+
ground_truth_map[global_index] = {
202+
"question": item.get("question", ""),
203+
"global_index": global_index,
204+
"context": item.get("context", ""),
205+
"answer": item["answer"],
206+
"options": item.get("options", []),
207+
"metadata": item.get("metadata", {}),
208+
}
209+
210+
logger.info(f"Loaded {len(ground_truth_map)} GPQA ground truth samples")
211+
return ground_truth_map
174212
if split not in ["sub_10", "full"]:
175213
raise ValueError(f"Invalid split: {split}. Must be 'sub_10' or 'full'")
176214

@@ -354,9 +392,10 @@ def evaluate_single_prediction(
354392
)
355393

356394
# Calculate inference cost
395+
# Use original model name for cost lookup since cost config uses original names
357396
token_usage = generated_result.get("token_usage", {})
358397
inference_cost = evaluator.calculate_inference_cost(
359-
universal_model_name, token_usage
398+
model_name, token_usage # Use original model_name instead of universal_model_name
360399
)
361400

362401
# Update the prediction with evaluation results
@@ -396,7 +435,7 @@ def process_router_predictions(
396435
logger.info(f"Using {num_workers} worker threads for parallel processing")
397436

398437
# Load predictions
399-
predictions = load_predictions_file(router_name)
438+
predictions = load_predictions_file(router_name, split=split)
400439

401440
# Separate regular and optimality entries
402441
regular_predictions = [p for p in predictions if not p.get("for_optimality", False)]
@@ -439,11 +478,13 @@ def process_router_predictions(
439478
# Note: This loop runs in the main thread before threading starts, so no lock needed
440479
tasks = []
441480
for i, prediction in enumerate(predictions):
442-
# Check if already evaluated (has accuracy and cost)
481+
# Check if already evaluated (has accuracy and cost > 0)
443482
# Skip if already evaluated AND force is False
483+
# Note: cost > 0 check ensures costs were actually calculated (0.0 means not calculated)
444484
if not force and (
445485
prediction.get("accuracy") is not None
446486
and prediction.get("cost") is not None
487+
and prediction.get("cost", 0) > 0 # Cost must be > 0 to be considered evaluated
447488
):
448489
already_evaluated_count += 1
449490
evaluated_count += 1
@@ -494,7 +535,7 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
494535
with save_lock:
495536
# Save the entire predictions list
496537
# This is safe because each thread modifies a different index
497-
save_predictions_file(predictions, router_name)
538+
save_predictions_file(predictions, router_name, split=split)
498539

499540
elapsed_time = (
500541
datetime.datetime.now() - start_time
@@ -542,7 +583,7 @@ def evaluate_task(seq_idx: int, prediction: Dict[str, Any]) -> bool:
542583

543584
# Final save
544585
with save_lock:
545-
save_predictions_file(predictions, router_name)
586+
save_predictions_file(predictions, router_name, split=split)
546587

547588
# Final summary
548589
end_time = datetime.datetime.now()
@@ -901,7 +942,7 @@ def run_robustness_only(router_name: str, robustness_path: Optional[str]) -> Non
901942
target_path,
902943
)
903944

904-
predictions = load_predictions_file(router_name)
945+
predictions = load_predictions_file(router_name, split=None) # Load base file for robustness
905946

906947
try:
907948
robustness_predictions = load_predictions_from_path(target_path)
@@ -1096,10 +1137,10 @@ def main():
10961137
"split",
10971138
nargs="?",
10981139
type=str,
1099-
choices=["sub_10", "full", "robustness"],
1140+
choices=["sub_10", "full", "robustness", "gpqa"],
11001141
help=(
11011142
"Dataset split to use for evaluation ('sub_10' for testing with answers, "
1102-
"'full' for submission, 'robustness' to compute robustness score only)."
1143+
"'full' for submission, 'robustness' to compute robustness score only, 'gpqa' for GPQA dataset)."
11031144
),
11041145
)
11051146
parser.add_argument(
@@ -1161,7 +1202,7 @@ def main():
11611202
# Run evaluation
11621203
try:
11631204
# If save_interval is 0, only save at the end
1164-
predictions = load_predictions_file(args.router_name)
1205+
predictions = load_predictions_file(args.router_name, split=args.split)
11651206
save_interval = (
11661207
args.save_interval if args.save_interval > 0 else len(predictions) + 1
11671208
)
@@ -1177,8 +1218,8 @@ def main():
11771218
logger.info("\nInterrupted by user. Saving partial results...")
11781219
try:
11791220
# Try to save current state if possible
1180-
predictions = load_predictions_file(args.router_name)
1181-
save_predictions_file(predictions, args.router_name)
1221+
predictions = load_predictions_file(args.router_name, split=args.split)
1222+
save_predictions_file(predictions, args.router_name, split=args.split)
11821223
logger.info("Partial results saved successfully.")
11831224
except Exception as e:
11841225
logger.warning(f"Could not save partial results: {e}")

llm_inference/model_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def _get_provider(self, model_name: str) -> str:
172172
"qwen/qwen3-vl-235b-a22b-instruct": "openrouter",
173173
"qwen/qwen3-coder": "openrouter",
174174
"x-ai/grok-code-fast-1": "openrouter",
175+
"xiaomi/mimo-v2-flash": "openrouter",
175176
"xiaomi/mimo-v2-flash:free": "openrouter",
176177
"openai/gpt-oss-120b": "openrouter",
177178
"qwen/qwen3-235b-a22b-2507": "openrouter",

0 commit comments

Comments
 (0)