diff --git a/.github/workflows/pr-evaluation.yml b/.github/workflows/pr-evaluation.yml index 38ad1761..3d177bca 100644 --- a/.github/workflows/pr-evaluation.yml +++ b/.github/workflows/pr-evaluation.yml @@ -13,32 +13,54 @@ jobs: contents: read pull-requests: write steps: - - name: Checkout repository + - name: Checkout PR branch for file detection uses: actions/checkout@v4 with: + ref: ${{ github.event.pull_request.head.sha }} fetch-depth: 0 - - name: Detect new prediction file + - name: Detect changed prediction file id: detect shell: bash run: | set -euo pipefail - git fetch origin main - NEW_FILES=$(git diff --name-status origin/main...HEAD -- router_inference/predictions/*.json | awk '$1 == "A" {print $2}') - if [[ -z "$NEW_FILES" ]]; then - echo "No newly added prediction file detected; skipping evaluation." + # Compare against the PR's base branch (the branch the PR targets, not the fork's head branch) + # This ensures each router submission is evaluated independently + BASE_REF="${{ github.event.pull_request.base.ref }}" + BASE_SHA="${{ github.event.pull_request.base.sha }}" + + if [[ -z "$BASE_SHA" ]]; then + echo "Error: Could not determine PR base SHA" >&2 + exit 1 + fi + + # Fetch the base branch to ensure it's available for comparison + git fetch origin "$BASE_REF" || true + + # Try to fetch the specific base SHA if it's not already available + if ! git cat-file -e "$BASE_SHA" 2>/dev/null; then + echo "Base SHA $BASE_SHA not found locally, attempting to fetch..." 
+ git fetch origin "$BASE_SHA" || git fetch origin "$BASE_REF" || true + fi + + # For PRs from forks, we want to compare against the PR's base commit (the target branch state) + # Use three-dot diff to show changes from merge-base to HEAD (only PR changes) + # This isolates the evaluation to changes in this specific fork submission + CHANGED_FILES=$(git diff --name-status "$BASE_SHA"...HEAD -- router_inference/predictions/*.json 2>&1 | awk '$1 == "A" || $1 == "M" {print $2}') + if [[ -z "$CHANGED_FILES" ]]; then + echo "No changed prediction file detected; skipping evaluation." echo "router=" >> "$GITHUB_OUTPUT" exit 0 fi - if [[ $(echo "$NEW_FILES" | wc -l) -ne 1 ]]; then - echo "Expected exactly one new prediction file, found:" >&2 - echo "$NEW_FILES" >&2 + if [[ $(echo "$CHANGED_FILES" | wc -l) -ne 1 ]]; then + echo "Expected exactly one changed prediction file, found:" >&2 + echo "$CHANGED_FILES" >&2 exit 1 fi - ROUTER_NAME=$(basename "$NEW_FILES" .json) + ROUTER_NAME=$(basename "$CHANGED_FILES" .json) echo "router=$ROUTER_NAME" >> "$GITHUB_OUTPUT" - # Detect split based on prediction file size + # Detect split based on prediction file size (from PR branch) PREDICTION_FILE="router_inference/predictions/${ROUTER_NAME}.json" if [[ ! -f "$PREDICTION_FILE" ]]; then echo "Error: Prediction file not found at $PREDICTION_FILE" >&2 @@ -57,6 +79,24 @@ jobs: fi echo "split=$SPLIT" >> "$GITHUB_OUTPUT" + - name: Continue using PR branch for evaluation + if: ${{ steps.detect.outputs.router != '' }} + run: | + set -euo pipefail + # We stay on the PR branch to use the code from the PR + # This allows the PR to include both router submissions AND code improvements + # The prediction file is already available from the detection step + ROUTER_NAME="${{ steps.detect.outputs.router }}" + PREDICTION_FILE="router_inference/predictions/${ROUTER_NAME}.json" + + # Verify the file exists + if [[ ! 
-f "$PREDICTION_FILE" ]]; then + echo "Error: Prediction file not found at $PREDICTION_FILE" >&2 + exit 1 + fi + echo "Using PR branch for evaluation (includes both router submission and code changes)" + echo "Prediction file ready: $PREDICTION_FILE" + - name: Show detected router if: ${{ steps.detect.outputs.router != '' }} run: | @@ -80,10 +120,13 @@ jobs: ROUTERARENA_DATASET_DIR: ${{ github.workspace }}/dataset run: | set -euo pipefail + # Use the PR's base SHA (pull_request.base.sha, i.e. the PR's target branch) for comparison + BASE_SHA="${{ github.event.pull_request.base.sha }}" uv run python automation/process_pr_submission.py \ --pr "${{ github.event.pull_request.number }}" \ --router "${{ steps.detect.outputs.router }}" \ - --split "${{ steps.detect.outputs.split }}" > evaluation_output.txt 2>&1 + --split "${{ steps.detect.outputs.split }}" \ + --base-ref "$BASE_SHA" > evaluation_output.txt 2>&1 # Extract metrics from output if grep -q "Metrics:" evaluation_output.txt; then python3 automation/extract_metrics.py evaluation_output.txt diff --git a/automation/process_pr_submission.py b/automation/process_pr_submission.py index 5c869e76..8bf4a158 100644 --- a/automation/process_pr_submission.py +++ b/automation/process_pr_submission.py @@ -151,7 +151,7 @@ def cleanup_worktree(worktree_path: Path, branch_name: str, *, keep: bool) -> No def ensure_prediction_file_added( worktree_path: Path, base_ref: str, router_name: str ) -> None: - """Verify the PR adds a new prediction file for the specified router.""" + """Verify the PR adds or modifies a prediction file for the specified router.""" target_path = Path("router_inference") / "predictions" / f"{router_name}.json" @@ -174,14 +174,15 @@ def ensure_prediction_file_added( lines = [line.strip() for line in completed.stdout.splitlines() if line.strip()] for line in lines: - if line.startswith("A\t") or line.startswith("A "): + # Allow both added (A) and modified (M) files + if line[0] in ("A", "M"): return raise RuntimeError( 
textwrap.dedent( f""" - Expected pull request to add a new prediction file {target_path}. - Diff against {base_ref} did not show a newly added file. + Expected pull request to add or modify a prediction file {target_path}. + Diff against {base_ref} did not show a newly added or modified file. """ ).strip() ) @@ -396,6 +397,7 @@ def main(argv: Optional[list[str]] = None) -> int: "llm_evaluation/run.py", args.router, args.split, + "--force", ] evaluation_logs = "" diff --git a/llm_evaluation/run.py b/llm_evaluation/run.py index 2a7136ac..c4a184ae 100644 --- a/llm_evaluation/run.py +++ b/llm_evaluation/run.py @@ -359,7 +359,11 @@ def evaluate_single_prediction( def process_router_predictions( - router_name: str, split: str, save_interval: int = 50, num_workers: int = 4 + router_name: str, + split: str, + save_interval: int = 50, + num_workers: int = 4, + force: bool = False, ) -> None: """ Process router predictions by evaluating generated results with incremental saving. @@ -370,6 +374,7 @@ def process_router_predictions( split: Dataset split ("sub_10" or "full") save_interval: Number of entries to process before saving (default: 50) num_workers: Number of worker threads for parallel processing (default: 4) + force: If True, re-evaluate all entries even if already evaluated (default: False) """ logger.info(f"Starting LLM evaluation for router: {router_name} (split: {split})") logger.info(f"Using {num_workers} worker threads for parallel processing") @@ -401,12 +406,13 @@ def process_router_predictions( "The dataset contains entries from LiveCodeBench, and it is common to wait for ~10 minutes to evaluate the sub_10 split of the dataset." 
) - # Prepare tasks: filter out already evaluated entries + # Prepare tasks: filter out already evaluated entries (unless force is True) # Note: This loop runs in the main thread before threading starts, so no lock needed tasks = [] for i, prediction in enumerate(predictions): # Check if already evaluated (has accuracy and cost) - if ( + # Skip if already evaluated AND force is False + if not force and ( prediction.get("accuracy") is not None and prediction.get("cost") is not None ): @@ -640,8 +646,14 @@ def main(): parser.add_argument( "--num-workers", type=int, - default=8, - help="Number of worker threads for parallel processing (default: 4). Set to 1 for sequential processing.", + default=16, + help="Number of worker threads for parallel processing (default: 16). Set to 1 for sequential processing.", + ) + parser.add_argument( + "--force", + action="store_true", + default=False, + help="Force re-evaluation of all entries, even if already evaluated (default: False)", ) args = parser.parse_args() @@ -666,7 +678,7 @@ def main(): args.save_interval if args.save_interval > 0 else len(predictions) + 1 ) process_router_predictions( - args.router_name, args.split, save_interval, args.num_workers + args.router_name, args.split, save_interval, args.num_workers, args.force ) except KeyboardInterrupt: logger.info("\nInterrupted by user. Saving partial results...") diff --git a/router_inference/predictions/glm-4-air-router.json b/router_inference/predictions/glm-4-air-router.json index b744a74f..75f8f164 100644 --- a/router_inference/predictions/glm-4-air-router.json +++ b/router_inference/predictions/glm-4-air-router.json @@ -4,7 +4,7 @@ "prompt": "Please read the following multiple-choice questions and provide the most likely correct answer based on the options given.\n\nContext: None\n\nQuestion: The average life expectancy of the most of firm is around ( ) years.\n\nOptions: \nA. 10\nB. 20\nC. 30\nD. 
40\n\n\nProvide the correct letter choice in \\boxed{X}, where X is the correct letter choice. Keep the explanation or feedback within 3 sentences.", "prediction": "glm-4-air", "generated_result": { - "generated_answer": "\\boxed{A}\n\nThe average life expectancy for a firm is typically around 10 years. This is due to the high failure rate of businesses within the first few years of operation. Larger corporations may have longer life expectancies, but the general average is lower.", + "generated_answer": "\\boxed{A}\n\nThe average life expectancy for a firm was typically around 10 years. This is due to the high failure rate of businesses within the first few years of operation. Larger corporations may have longer life expectancies, but the general average is lower.", "success": true, "token_usage": { "input_tokens": 98,