Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 55 additions & 12 deletions .github/workflows/pr-evaluation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,54 @@ jobs:
contents: read
pull-requests: write
steps:
- name: Checkout repository
- name: Checkout PR branch for file detection
uses: actions/checkout@v4
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0

- name: Detect new prediction file
- name: Detect changed prediction file
id: detect
shell: bash
run: |
set -euo pipefail
git fetch origin main
NEW_FILES=$(git diff --name-status origin/main...HEAD -- router_inference/predictions/*.json | awk '$1 == "A" {print $2}')
if [[ -z "$NEW_FILES" ]]; then
echo "No newly added prediction file detected; skipping evaluation."
# Compare against the fork's base branch (the branch the PR was created from)
# This ensures each router submission is evaluated independently
BASE_REF="${{ github.event.pull_request.base.ref }}"
BASE_SHA="${{ github.event.pull_request.base.sha }}"

if [[ -z "$BASE_SHA" ]]; then
echo "Error: Could not determine PR base SHA" >&2
exit 1
fi

# Fetch the base branch to ensure it's available for comparison
git fetch origin "$BASE_REF" || true

# Try to fetch the specific base SHA if it's not already available
if ! git cat-file -e "$BASE_SHA" 2>/dev/null; then
echo "Base SHA $BASE_SHA not found locally, attempting to fetch..."
git fetch origin "$BASE_SHA" || git fetch origin "$BASE_REF" || true
fi

# For PRs from forks, we want to compare against the fork's base branch state
# Use three-dot diff to show changes from merge-base to HEAD (only PR changes)
# This isolates the evaluation to changes in this specific fork submission
CHANGED_FILES=$(git diff --name-status "$BASE_SHA"...HEAD -- router_inference/predictions/*.json 2>&1 | awk '$1 == "A" || $1 == "M" {print $2}')
if [[ -z "$CHANGED_FILES" ]]; then
echo "No changed prediction file detected; skipping evaluation."
echo "router=" >> "$GITHUB_OUTPUT"
exit 0
fi
if [[ $(echo "$NEW_FILES" | wc -l) -ne 1 ]]; then
echo "Expected exactly one new prediction file, found:" >&2
echo "$NEW_FILES" >&2
if [[ $(echo "$CHANGED_FILES" | wc -l) -ne 1 ]]; then
echo "Expected exactly one changed prediction file, found:" >&2
echo "$CHANGED_FILES" >&2
exit 1
fi
ROUTER_NAME=$(basename "$NEW_FILES" .json)
ROUTER_NAME=$(basename "$CHANGED_FILES" .json)
echo "router=$ROUTER_NAME" >> "$GITHUB_OUTPUT"

# Detect split based on prediction file size
# Detect split based on prediction file size (from PR branch)
PREDICTION_FILE="router_inference/predictions/${ROUTER_NAME}.json"
if [[ ! -f "$PREDICTION_FILE" ]]; then
echo "Error: Prediction file not found at $PREDICTION_FILE" >&2
Expand All @@ -57,6 +79,24 @@ jobs:
fi
echo "split=$SPLIT" >> "$GITHUB_OUTPUT"

- name: Continue using PR branch for evaluation
if: ${{ steps.detect.outputs.router != '' }}
run: |
set -euo pipefail
# We stay on the PR branch to use the code from the PR
# This allows the PR to include both router submissions AND code improvements
# The prediction file is already available from the detection step
ROUTER_NAME="${{ steps.detect.outputs.router }}"
PREDICTION_FILE="router_inference/predictions/${ROUTER_NAME}.json"

# Verify the file exists and has content
if [[ ! -f "$PREDICTION_FILE" ]]; then
echo "Error: Prediction file not found at $PREDICTION_FILE" >&2
exit 1
fi
echo "Using PR branch for evaluation (includes both router submission and code changes)"
echo "Prediction file ready: $PREDICTION_FILE"

- name: Show detected router
if: ${{ steps.detect.outputs.router != '' }}
run: |
Expand All @@ -80,10 +120,13 @@ jobs:
ROUTERARENA_DATASET_DIR: ${{ github.workspace }}/dataset
run: |
set -euo pipefail
# Use the PR's base branch SHA for comparison (fork's base, not upstream main)
BASE_SHA="${{ github.event.pull_request.base.sha }}"
uv run python automation/process_pr_submission.py \
--pr "${{ github.event.pull_request.number }}" \
--router "${{ steps.detect.outputs.router }}" \
--split "${{ steps.detect.outputs.split }}" > evaluation_output.txt 2>&1
--split "${{ steps.detect.outputs.split }}" \
--base-ref "$BASE_SHA" > evaluation_output.txt 2>&1
# Extract metrics from output
if grep -q "Metrics:" evaluation_output.txt; then
python3 automation/extract_metrics.py evaluation_output.txt
Expand Down
10 changes: 6 additions & 4 deletions automation/process_pr_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def cleanup_worktree(worktree_path: Path, branch_name: str, *, keep: bool) -> No
def ensure_prediction_file_added(
worktree_path: Path, base_ref: str, router_name: str
) -> None:
"""Verify the PR adds a new prediction file for the specified router."""
"""Verify the PR adds or modifies a prediction file for the specified router."""

target_path = Path("router_inference") / "predictions" / f"{router_name}.json"

Expand All @@ -174,14 +174,15 @@ def ensure_prediction_file_added(

lines = [line.strip() for line in completed.stdout.splitlines() if line.strip()]
for line in lines:
if line.startswith("A\t") or line.startswith("A "):
# Allow both added (A) and modified (M) files
if line[0] in ("A", "M"):
return

raise RuntimeError(
textwrap.dedent(
f"""
Expected pull request to add a new prediction file {target_path}.
Diff against {base_ref} did not show a newly added file.
Expected pull request to add or modify a prediction file {target_path}.
Diff against {base_ref} did not show a newly added or modified file.
"""
).strip()
)
Expand Down Expand Up @@ -396,6 +397,7 @@ def main(argv: Optional[list[str]] = None) -> int:
"llm_evaluation/run.py",
args.router,
args.split,
"--force",
]

evaluation_logs = ""
Expand Down
24 changes: 18 additions & 6 deletions llm_evaluation/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,11 @@ def evaluate_single_prediction(


def process_router_predictions(
router_name: str, split: str, save_interval: int = 50, num_workers: int = 4
router_name: str,
split: str,
save_interval: int = 50,
num_workers: int = 4,
force: bool = False,
) -> None:
"""
Process router predictions by evaluating generated results with incremental saving.
Expand All @@ -370,6 +374,7 @@ def process_router_predictions(
split: Dataset split ("sub_10" or "full")
save_interval: Number of entries to process before saving (default: 50)
num_workers: Number of worker threads for parallel processing (default: 4)
force: If True, re-evaluate all entries even if already evaluated (default: False)
"""
logger.info(f"Starting LLM evaluation for router: {router_name} (split: {split})")
logger.info(f"Using {num_workers} worker threads for parallel processing")
Expand Down Expand Up @@ -401,12 +406,13 @@ def process_router_predictions(
"The dataset contains entries from LiveCodeBench, and it is common to wait for ~10 minutes to evaluate the sub_10 split of the dataset."
)

# Prepare tasks: filter out already evaluated entries
# Prepare tasks: filter out already evaluated entries (unless force is True)
# Note: This loop runs in the main thread before threading starts, so no lock needed
tasks = []
for i, prediction in enumerate(predictions):
# Check if already evaluated (has accuracy and cost)
if (
# Skip if already evaluated AND force is False
if not force and (
prediction.get("accuracy") is not None
and prediction.get("cost") is not None
):
Expand Down Expand Up @@ -640,8 +646,14 @@ def main():
parser.add_argument(
"--num-workers",
type=int,
default=8,
help="Number of worker threads for parallel processing (default: 4). Set to 1 for sequential processing.",
default=16,
help="Number of worker threads for parallel processing (default: 16). Set to 1 for sequential processing.",
)
parser.add_argument(
"--force",
action="store_true",
default=False,
help="Force re-evaluation of all entries, even if already evaluated (default: False)",
)

args = parser.parse_args()
Expand All @@ -666,7 +678,7 @@ def main():
args.save_interval if args.save_interval > 0 else len(predictions) + 1
)
process_router_predictions(
args.router_name, args.split, save_interval, args.num_workers
args.router_name, args.split, save_interval, args.num_workers, args.force
)
except KeyboardInterrupt:
logger.info("\nInterrupted by user. Saving partial results...")
Expand Down
2 changes: 1 addition & 1 deletion router_inference/predictions/glm-4-air-router.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"prompt": "Please read the following multiple-choice questions and provide the most likely correct answer based on the options given.\n\nContext: None\n\nQuestion: The average life expectancy of the most of firm is around ( ) years.\n\nOptions: \nA. 10\nB. 20\nC. 30\nD. 40\n\n\nProvide the correct letter choice in \\boxed{X}, where X is the correct letter choice. Keep the explanation or feedback within 3 sentences.",
"prediction": "glm-4-air",
"generated_result": {
"generated_answer": "\\boxed{A}\n\nThe average life expectancy for a firm is typically around 10 years. This is due to the high failure rate of businesses within the first few years of operation. Larger corporations may have longer life expectancies, but the general average is lower.",
"generated_answer": "\\boxed{A}\n\nThe average life expectancy for a firm was typically around 10 years. This is due to the high failure rate of businesses within the first few years of operation. Larger corporations may have longer life expectancies, but the general average is lower.",
"success": true,
"token_usage": {
"input_tokens": 98,
Expand Down