Skip to content

Commit 8a5870e

Browse files
committed
Merge origin/main into research-pipeline - resolved conflicts and aligned with README.md (using sub_10)
2 parents 0d36c5d + 68e4bbd commit 8a5870e

7 files changed

Lines changed: 116 additions & 56 deletions

File tree

llm_evaluation/eval_reasoning.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import os
55
import json
6+
from typing import Any, Dict
67
from metrics import (
78
mcq_exact_match,
89
mcq_accuracy,
@@ -201,8 +202,8 @@ def eval(pred_dir, eval_params, pipeline_config, all_data):
201202
scores: Dictionary of evaluation scores
202203
raw_results: Detailed results for each prediction
203204
"""
204-
scores = dict()
205-
all_raw_results = dict()
205+
scores: Dict[str, float] = {}
206+
all_raw_results: Dict[str, Dict[str, Any]] = {}
206207

207208
# Get the appropriate scorers for this dataset and metrics
208209
dataset_name = eval_params["dataset"]

llm_evaluation/evaluate_models.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def load_all_data(self):
9898
print("Loading ground truth data...")
9999
try:
100100
# Load data directly without LiveCodeBench dependency
101-
from datasets import load_from_disk # type: ignore[import-untyped]
102-
import pandas as pd # type: ignore[import-untyped]
101+
from datasets import load_from_disk
102+
import pandas as pd
103103

104104
# Load the router eval benchmark dataset
105105
router_eval_bench = load_from_disk("./dataset/routerarena")
@@ -368,14 +368,21 @@ def evaluate_model(self, model_name: str, rerun=False) -> Dict[str, Any]:
368368

369369
# Evaluate each entry in this dataset
370370
for entry in dataset_entries:
371-
global_index = entry.get("global_index")
371+
global_index_val = entry.get("global_index")
372372
generated_answer = entry.get("generated_answer", "")
373373

374374
try:
375375
# Get ground truth for this entry
376-
ground_truth = self._get_ground_truth(global_index, dataset_name)
376+
if not isinstance(global_index_val, str):
377+
print(
378+
f"Warning: Invalid global_index {global_index_val} for dataset {dataset_name}"
379+
)
380+
continue
381+
ground_truth = self._get_ground_truth(
382+
global_index_val, dataset_name
383+
)
377384
if ground_truth is None:
378-
print(f"Warning: No ground truth found for {global_index}")
385+
print(f"Warning: No ground truth found for {global_index_val}")
379386
continue
380387

381388
# Evaluate using the appropriate scorer
@@ -415,7 +422,7 @@ def evaluate_model(self, model_name: str, rerun=False) -> Dict[str, Any]:
415422
"metric": "error",
416423
"inference_cost": 0.0,
417424
}
418-
print(f"Error evaluating {global_index}: {e}")
425+
print(f"Error evaluating {global_index_val}: {e}")
419426
continue
420427

421428
dataset_scores[dataset_name] = len(dataset_entries)
@@ -432,7 +439,7 @@ def evaluate_model(self, model_name: str, rerun=False) -> Dict[str, Any]:
432439
print(f"Evaluation completed. Evaluated {evaluated_count} new entries.")
433440
return self._compile_final_results(universal_model_name, cached_results)
434441

435-
def _get_ground_truth(self, global_index: str, dataset_name: str) -> Optional[str]:
442+
def _get_ground_truth(self, global_index: str, dataset_name: str) -> Optional[Any]:
436443
"""Get ground truth for a specific global_index from the dataset."""
437444
# Load dataset if not already loaded
438445
if self.all_data is None:
@@ -475,7 +482,7 @@ def _get_ground_truth(self, global_index: str, dataset_name: str) -> Optional[st
475482
return None
476483

477484
def _evaluate_single_entry(
478-
self, generated_answer: str, ground_truth: str, scorer, dataset_name: str
485+
self, generated_answer: str, ground_truth: Any, scorer, dataset_name: str
479486
) -> tuple:
480487
"""Evaluate a single entry using the appropriate scorer."""
481488
try:
@@ -593,7 +600,10 @@ def main():
593600

594601
args = parser.parse_args()
595602

596-
universal_name = model_name_manager.get_universal_name(args.model_name)
603+
if model_name_manager is not None:
604+
universal_name = model_name_manager.get_universal_name(args.model_name)
605+
else:
606+
universal_name = args.model_name
597607
print(f"Input model name: {args.model_name}")
598608
print(f"Universal model name: {universal_name}")
599609

llm_evaluation/livecodebench_util.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@
2020
import time
2121
import zlib
2222
from io import StringIO
23-
from typing import Optional
23+
from typing import Optional, Any, Dict
24+
25+
# Global storage for original references used by reliability_guard
26+
originals: Dict[str, Any] = {}
2427

2528

2629
def has_code(response):
@@ -126,8 +129,7 @@ def post_process_tests_inputs(raw_text, is_stdin):
126129

127130
# If no matches are found, fall back to line-by-line parsing
128131
cleaned_lines = cleaned_string.split("\n")
129-
if test_cases is None:
130-
test_cases = []
132+
test_cases = []
131133
for line in cleaned_lines:
132134
try:
133135
test_case = json.loads(line)
@@ -229,10 +231,9 @@ def prepare_test_input_output_std(test_case):
229231

230232
def run_test_func(completion, is_extracted, test_input, test_output):
231233
# print(f"inside: {completion}")
234+
# Define the namespace in which to execute the completion code
235+
namespace: Dict[str, Any] = {}
232236
if not is_extracted:
233-
# Define the namespace in which to execute the completion code
234-
namespace = {}
235-
236237
# Execute the generated code in the namespace
237238

238239
exec(completion, namespace)
@@ -273,8 +274,6 @@ def run_test_func(completion, is_extracted, test_input, test_output):
273274

274275
return True, result_output
275276
else:
276-
namespace = {}
277-
278277
# Execute the generated code in the namespace
279278

280279
exec(completion, namespace)
@@ -313,7 +312,7 @@ def run_test_std(completion, test_input, test_output):
313312
# Simulate that the code is being run as the main script
314313
completion = '__name__ = "__main__"\n' + completion
315314

316-
namespace = {}
315+
namespace: Dict[str, Any] = {}
317316
exec(completion, namespace)
318317

319318
output_value = output.getvalue().strip()
@@ -409,7 +408,7 @@ def swallow_io(redirect_input=True):
409408

410409
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
411410
if redirect_input:
412-
with contextlib.redirect_stdin(StringIO()): # Redirect stdin if enabled
411+
with redirect_stdin(StringIO()): # Redirect stdin if enabled
413412
yield stream
414413
else:
415414
yield stream # Do not redirect stdin
@@ -443,8 +442,18 @@ def readable(self, *args, **kwargs):
443442
return False
444443

445444

446-
class redirect_stdin(contextlib._RedirectStream): # type: ignore
447-
_stream = "stdin"
445+
class redirect_stdin:
446+
def __init__(self, new_target: Any):
447+
self._new_target = new_target
448+
self._old_target: Any = None
449+
450+
def __enter__(self) -> Any:
451+
self._old_target = sys.stdin
452+
sys.stdin = self._new_target
453+
return self._new_target
454+
455+
def __exit__(self, exc_type, exc, tb) -> None:
456+
sys.stdin = self._old_target
448457

449458

450459
@contextlib.contextmanager
@@ -484,7 +493,9 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
484493
builtins.exit = cast(Any, None) # type: ignore[assignment]
485494
builtins.quit = cast(Any, None) # type: ignore[assignment]
486495

487-
import os
496+
# Prepare Any-typed aliases to avoid mypy assignment errors
497+
os_mod: Any = os
498+
subprocess_mod: Any = subprocess
488499

489500
os.environ["OMP_NUM_THREADS"] = "1"
490501

@@ -516,20 +527,51 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
516527
os.getcwd = cast(Any, None) # type: ignore[assignment]
517528
os.chdir = cast(Any, None) # type: ignore[assignment]
518529

519-
import shutil
530+
# Disable destructive os functions (guard where platform-specific)
531+
for name in [
532+
"kill",
533+
"system",
534+
"putenv",
535+
"remove",
536+
"removedirs",
537+
"rmdir",
538+
"fchdir",
539+
"setuid",
540+
"fork",
541+
"forkpty",
542+
"killpg",
543+
"rename",
544+
"renames",
545+
"truncate",
546+
"replace",
547+
"unlink",
548+
"fchmod",
549+
"fchown",
550+
"chmod",
551+
"chown",
552+
"chroot",
553+
"getcwd",
554+
"chdir",
555+
"lchflags",
556+
"lchmod",
557+
"lchown",
558+
]:
559+
try:
560+
setattr(os_mod, name, None)
561+
except Exception:
562+
pass
520563

521564
shutil.rmtree = cast(Any, None) # type: ignore[assignment]
522565
shutil.move = cast(Any, None) # type: ignore[assignment]
523566
shutil.chown = cast(Any, None) # type: ignore[assignment]
524567

525-
import subprocess
568+
# Disable subprocess.Popen
569+
setattr(subprocess_mod, "Popen", None)
526570

527571
setattr(subprocess, "Popen", cast(Any, None)) # type: ignore[misc]
528572

529573
# __builtins__["help"] = None # this line is commented out as it results into error
530574

531-
import sys
532-
533575
sys.modules["ipdb"] = None # type: ignore[assignment]
534576
sys.modules["joblib"] = None # type: ignore[assignment]
535577
sys.modules["resource"] = None # type: ignore[assignment]
@@ -600,7 +642,7 @@ def restore_original_references():
600642
setattr(shutil, func_name, original_func)
601643

602644
# Restore 'subprocess' functions
603-
subprocess.Popen = originals["subprocess"]["Popen"]
645+
setattr(subprocess, "Popen", originals["subprocess"]["Popen"])
604646

605647
# Restore sys modules
606648
for module_name, original_module in originals["sys_modules"].items():

llm_evaluation/metrics.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def classification_score(prediction, ground_truth, **kwargs):
130130
score = 0.0
131131
else:
132132
best_match = None
133-
highest_similarity = 0
133+
highest_similarity = 0.0
134134
for string in all_classes:
135135
similarity = difflib.SequenceMatcher(None, string, prediction).ratio()
136136
if similarity > highest_similarity:
@@ -528,24 +528,25 @@ def math_equal(
528528

529529
try: # 1. numerical equal
530530
if is_digit(prediction) and is_digit(reference):
531-
prediction = parse_digits(prediction)
532-
reference = parse_digits(reference)
533-
# number questions
534-
if include_percentage:
535-
gt_result = [reference / 100, reference, reference * 100]
536-
else:
537-
gt_result = [reference]
538-
for item in gt_result:
539-
try:
540-
if is_close:
541-
if numeric_equal(prediction, item):
542-
return True
543-
else:
544-
if item == prediction:
545-
return True
546-
except Exception:
547-
continue
548-
return False
531+
pred_val = parse_digits(prediction)
532+
ref_val = parse_digits(reference)
533+
if pred_val is not None and ref_val is not None:
534+
# number questions
535+
if include_percentage:
536+
gt_result: list[float] = [ref_val / 100, ref_val, ref_val * 100]
537+
else:
538+
gt_result = [ref_val]
539+
for item in gt_result:
540+
try:
541+
if is_close:
542+
if numeric_equal(pred_val, item):
543+
return True
544+
else:
545+
if item == pred_val:
546+
return True
547+
except Exception:
548+
continue
549+
return False
549550
except Exception:
550551
pass
551552

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ packages = ["."]
100100

101101
[tool.uv]
102102
dev-dependencies = []
103+
104+
[tool.mypy]
105+
plugins = ['pydantic.mypy']
106+
ignore_missing_imports = true
107+
check_untyped_defs = true
108+
follow_imports = "silent"

router_inference/check_config_prediction_files.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
Usage:
1616
python router_inference/check_config_prediction_files.py <router_name> <split>
1717
18-
split: either "10" for 10% split or "full" for full dataset
18+
split: either "sub_10" for 10% split or "full" for full dataset
1919
"""
2020

2121
import argparse
@@ -89,15 +89,15 @@ def load_dataset(split: str) -> List[Dict[str, Any]]:
8989
Load dataset file.
9090
9191
Args:
92-
split: Either "10" or "full"
92+
split: Either "sub_10" or "full"
9393
9494
Returns:
9595
List of dataset entries
9696
"""
9797
dataset_path = DATASET_PATHS.get(split)
9898

9999
if not dataset_path:
100-
raise ValueError(f"Invalid split: {split}. Must be '10' or 'full'")
100+
raise ValueError(f"Invalid split: {split}. Must be 'sub_10' or 'full'")
101101

102102
if not os.path.exists(dataset_path):
103103
raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
@@ -143,15 +143,15 @@ def check_prediction_size(
143143
144144
Args:
145145
predictions: List of prediction dictionaries
146-
split: Either "10" or "full"
146+
split: Either "sub_10" or "full"
147147
148148
Returns:
149149
Tuple of (is_valid, error_message)
150150
"""
151151
expected_size = EXPECTED_SIZES.get(split)
152152

153153
if expected_size is None:
154-
return False, f"Invalid split: {split}. Must be '10' or 'full'"
154+
return False, f"Invalid split: {split}. Must be 'sub_10' or 'full'"
155155

156156
actual_size = len(predictions)
157157

router_inference/generate_prediction_file.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
Usage:
1212
python router_inference/generate_prediction_file.py <router_name> <split>
1313
14-
split: either "10" for 10% split (809 entries) or "full" (8400 entries)
14+
split: either "sub_10" for 10% split (809 entries) or "full" (8400 entries)
1515
"""
1616

1717
import argparse
@@ -56,15 +56,15 @@ def load_dataset(split: str) -> List[Dict[str, Any]]:
5656
Load dataset file.
5757
5858
Args:
59-
split: Either "10" or "full"
59+
split: Either "sub_10" or "full"
6060
6161
Returns:
6262
List of dataset entries
6363
"""
6464
dataset_path = DATASET_PATHS.get(split)
6565

6666
if not dataset_path:
67-
raise ValueError(f"Invalid split: {split}. Must be '10' or 'full'")
67+
raise ValueError(f"Invalid split: {split}. Must be 'sub_10' or 'full'")
6868

6969
if not os.path.exists(dataset_path):
7070
raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

0 commit comments

Comments
 (0)