[Feat.] Refactor llm_inference/run.py to use ParallelInferenceManager with batch inference #59
Merged
7 commits
7329832
Refactor run.py to use ParallelInferenceManager with batch inference
yl231 b1418ca
fix format
yl231 ebb7396
Refactor run.py to use ParallelInferenceManager with batch inference
yl231 c98f592
Format code: fix line length and whitespace
yl231 868951e
Integrate clear_failed_entries into workflow to allow retries
yl231 854dcea
fix format
yl231 f4440b2
renamed readme
yl231
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,140 @@ | ||
| # Parallel Inference Quick Start Guide | ||
|
|
||
| ## Overview | ||
|
|
||
| This directory supports parallel inference. Models are processed sequentially, one at a time, and the queries for each model are distributed across multiple parallel workers. | ||
|
|
||
| ## Quick Start | ||
|
|
||
| ### 1. Process All Models (Recommended) | ||
|
|
||
| Process all 26 models from `model_cost/model_cost.json` with 16 workers per model: | ||
|
|
||
| ```bash | ||
| cd /path/to/RouterArena  # replace with the path to your checkout of the project root | ||
| uv run python llm_inference/batch_inference.py --num-workers 16 | ||
| ``` | ||
|
|
||
| **What this does:** | ||
| - Loads all model names from `model_cost/model_cost.json` | ||
| - Processes each model sequentially | ||
| - Uses 16 workers per model to process queries in parallel | ||
| - Skips already processed queries automatically | ||
| - Saves results to `./cached_results/{model_name}.jsonl` | ||
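The resume behavior in the list above (skipping queries that already have a cached result) could be sketched as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: the `query_id` field and the helpers `load_done_ids` and `pending_queries` are hypothetical, and the actual record layout of `./cached_results/{model_name}.jsonl` is not shown in this diff.

```python
import json
import os
import tempfile


def load_done_ids(cache_path, id_key="query_id"):
    """Return the set of query ids already present in a JSONL cache file.

    `id_key` is a hypothetical field name; the real cache schema may differ.
    """
    done = set()
    if not os.path.exists(cache_path):
        return done  # no cache yet, so nothing has been processed
    with open(cache_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                continue  # ignore a partially written line
            if id_key in record:
                done.add(record[id_key])
    return done


def pending_queries(queries, cache_path, id_key="query_id"):
    """Keep only the queries that have no cached result yet."""
    done = load_done_ids(cache_path, id_key)
    return [q for q in queries if q[id_key] not in done]


# Demo with a throwaway cache file holding results for queries 0 and 2.
cache_dir = tempfile.mkdtemp()
cache_path = os.path.join(cache_dir, "demo-model.jsonl")
with open(cache_path, "w", encoding="utf-8") as f:
    f.write(json.dumps({"query_id": 0, "response": "a"}) + "\n")
    f.write(json.dumps({"query_id": 2, "response": "c"}) + "\n")

queries = [{"query_id": i, "prompt": f"q{i}"} for i in range(4)]
remaining = pending_queries(queries, cache_path)
print([q["query_id"] for q in remaining])
```

Reading the cache once per model and filtering up front keeps the workers free of any bookkeeping: they only ever see queries that still need a result.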
|
|
||
| ### 2. Process Specific Models | ||
|
|
||
| Process only certain models: | ||
|
|
||
| ```bash | ||
| uv run python llm_inference/batch_inference.py \ | ||
| --models gemini-2.0-flash-001 gpt-5-mini claude-sonnet-4-5 \ | ||
| --num-workers 16 | ||
| ``` | ||
|
|
||
| ### 3. Single Model Inference | ||
|
|
||
| Process a single model with parallel workers: | ||
|
|
||
| ```bash | ||
| # With 16 workers (parallel) | ||
| uv run python llm_inference/main.py \ | ||
| --model_name gemini-2.0-flash-001 \ | ||
| --num-workers 16 | ||
|
|
||
| # Sequential (backward compatible) | ||
| uv run python llm_inference/main.py \ | ||
| --model_name gemini-2.0-flash-001 | ||
| ``` | ||
|
|
||
| ## Configuration Options | ||
|
|
||
| ### batch_inference.py | ||
|
|
||
| | Option | Default | Description | | ||
| |--------|---------|-------------| | ||
| | `--num-workers` | 16 | Number of parallel workers per model | | ||
| | `--num-runs` | 1 | Target number of successful inference runs per query | | ||
| | `--models` | All | Specific models to process (space-separated) | | ||
| | `--cache-dir` | `./cached_results` | Cache directory path | | ||
| | `--model-cost-path` | `./model_cost/model_cost.json` | Path to model cost file | | ||
| | `--input-file` | `./llm_inference/datasets/router_data.json` | Input data file | | ||
|
|
||
| ### main.py | ||
|
|
||
| | Option | Default | Description | | ||
| |--------|---------|-------------| | ||
| | `--model_name` | Required | Model name to process | | ||
| | `--num-workers` | 1 | Number of parallel workers (1 = sequential) | | ||
| | `--run-full` | False | Process full dataset | | ||
|
|
||
| ## Architecture | ||
|
|
||
| ``` | ||
| ┌─────────────────────────────────────────────────────┐ | ||
| │ batch_inference.py │ | ||
| │ ┌───────────────────────────────────────────────┐ │ | ||
| │ │ Load models from model_cost.json (26 models)│ │ | ||
| │ │ Load dataset (8400 queries) │ │ | ||
| │ └───────────────────────────────────────────────┘ │ | ||
| │ │ | ||
| │ FOR EACH MODEL (Sequential): │ | ||
| │ ┌───────────────────────────────────────────────┐ │ | ||
| │ │ Model 1: gemini-3-pro-preview │ │ | ||
| │ │ ┌─────────────────────────────────────────┐ │ │ | ||
| │ │ │ Check cache → 8200 done, 200 remaining │ │ │ | ||
| │ │ │ Launch 16 workers (Parallel) │ │ │ | ||
| │ │ │ ┌──────┐ ┌──────┐ ┌──────┐ │ │ │ | ||
| │ │ │ │Worker│ │Worker│ ... │Worker│ │ │ │ | ||
| │ │ │ │ 1 │ │ 2 │ │ 16 │ │ │ │ | ||
| │ │ │ │~12 q │ │~12 q │ │~13 q │ │ │ │ | ||
| │ │ │ └──────┘ └──────┘ └──────┘ │ │ │ | ||
| │ │ │ Wait for completion │ │ │ | ||
| │ │ │ Save results │ │ │ | ||
| │ │ └─────────────────────────────────────────┘ │ │ | ||
| │ └───────────────────────────────────────────────┘ │ | ||
| │ ┌───────────────────────────────────────────────┐ │ | ||
| │ │ Model 2: gemini-3-flash-preview │ │ | ||
| │ │ (Same parallel processing...) │ │ | ||
| │ └───────────────────────────────────────────────┘ │ | ||
| │ ... │ | ||
| │ ┌───────────────────────────────────────────────┐ │ | ||
| │ │ Model 26: meta-llama_llama-3.1-405b-instruct│ │ | ||
| │ │ (Same parallel processing...) │ │ | ||
| │ └───────────────────────────────────────────────┘ │ | ||
| └─────────────────────────────────────────────────────┘ | ||
| ``` | ||
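The control flow in the diagram above (an outer sequential loop over models, an inner pool of parallel workers over queries) could be sketched roughly as follows. This is an illustration only: `ParallelInferenceManager` is not shown in this diff, so the use of threads, the `run_query` stub, and the function names here are all assumptions rather than the PR's actual code.

```python
from concurrent.futures import ThreadPoolExecutor


def run_query(model_name, query):
    """Hypothetical stand-in for a real API call to `model_name`."""
    return {"model": model_name, "query": query, "success": True}


def process_model(model_name, queries, num_workers=16):
    """Fan the queries for one model out across a pool of worker threads."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        # pool.map returns results in the same order as the input queries.
        return list(pool.map(lambda q: run_query(model_name, q), queries))


def process_all_models(models, queries, num_workers=16):
    """Process models one at a time; parallelism exists only within a model."""
    results = {}
    for model_name in models:  # sequential outer loop
        results[model_name] = process_model(model_name, queries, num_workers)
    return results


# Demo: 2 placeholder models, 200 pending queries, 16 workers per model.
demo_queries = [f"q{i}" for i in range(200)]
results = process_all_models(["model-a", "model-b"], demo_queries, num_workers=16)
print({name: len(rows) for name, rows in results.items()})
```

Threads are a natural fit when each query is a network-bound API call; a pool hands out queries dynamically, so the per-worker counts in the diagram (roughly 12 to 13 each for 200 remaining queries) are averages rather than fixed assignments in this sketch.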
|
|
||
| ## Next Steps | ||
|
|
||
| After inference completes: | ||
|
|
||
| 1. **Run Evaluation**: | ||
|
|
||
| ```bash | ||
| uv run python llm_evaluation/batch_evaluate.py \ | ||
| --cached-results-dir ./cached_results \ | ||
| --max-workers 16 | ||
| ``` | ||
|
|
||
| 2. **Compute Scores**: | ||
|
|
||
| ```bash | ||
| uv run python router_evaluation/compute_scores.py <router_name> | ||
| ``` | ||
|
|
||
| ## Files Modified/Created | ||
|
|
||
| 1. ✅ `llm_inference/parallel_inference.py` - Parallel inference manager | ||
| 2. ✅ `llm_inference/pipeline.py` - Added parallel support | ||
| 3. ✅ `llm_inference/main.py` - Added --num-workers argument | ||
| 4. ✅ `llm_inference/batch_inference.py` - Batch processing script | ||
| 5. ✅ `docs/PARALLEL_INFERENCE_IMPLEMENTATION.md` - Implementation details | ||
| 6. ✅ `llm_inference/README_PARALLEL.md` - This guide | ||
|
|
||
| ## Support | ||
|
|
||
| For issues or questions: | ||
| 1. Check logs for error messages | ||
| 2. Review `docs/PARALLEL_INFERENCE_IMPLEMENTATION.md` for details | ||
| 3. Verify model names in `model_cost/model_cost.json` | ||
| 4. Ensure the dataset is prepared (run `uv run python scripts/process_datasets/prep_datasets.py`) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,248 @@ | ||
| # SPDX-FileCopyrightText: Copyright contributors to the RouterArena project | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """ | ||
| Batch LLM Inference Script | ||
|
|
||
| This script processes multiple models from model_cost.json sequentially, | ||
| using parallel workers for query processing within each model. | ||
|
|
||
| Architecture: | ||
| - Processes models sequentially (one at a time) | ||
| - Within each model, uses k workers for parallel query processing | ||
| - Example: 8400 queries with 16 workers → each worker handles ~525 queries | ||
|
|
||
| Usage: | ||
| # Process all models from model_cost.json with 16 workers per model | ||
| uv run python llm_inference/batch_inference.py --num-workers 16 | ||
|
|
||
| # Process specific models only | ||
| uv run python llm_inference/batch_inference.py \ | ||
| --models gemini-2.0-flash-001 gpt-5-mini \ | ||
| --num-workers 16 | ||
| """ | ||
|
|
||
| import argparse | ||
| import json | ||
| import os | ||
| import sys | ||
| import logging | ||
| import datetime | ||
| from typing import List, Optional | ||
| from parallel_inference import ParallelInferenceManager | ||
|
|
||
| # Add parent directory to path for imports | ||
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../"))) | ||
|
|
||
| # Change to project root BEFORE loading .env file | ||
| current_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| base_dir = os.path.abspath(os.path.join(current_dir, "../")) | ||
| os.chdir(base_dir) | ||
|
|
||
|
|
||
| # Load environment variables from .env file (now in project root) | ||
| try: | ||
| from dotenv import load_dotenv | ||
|
|
||
| load_dotenv() | ||
| except ImportError: | ||
| # dotenv is optional | ||
| pass | ||
|
|
||
|
|
||
| # Set up logging | ||
| logging.basicConfig( | ||
| level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" | ||
| ) | ||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def load_model_list_from_cost_file( | ||
| model_cost_path: str = "./model_cost/model_cost.json", | ||
| specified_models: Optional[List[str]] = None, | ||
| ) -> List[str]: | ||
| """ | ||
| Load list of models to process from model_cost.json | ||
|
|
||
| Args: | ||
| model_cost_path: Path to model_cost.json | ||
| specified_models: Optional list of specific models to process | ||
|
|
||
| Returns: | ||
| List of model names to process | ||
| """ | ||
| if not os.path.exists(model_cost_path): | ||
| raise FileNotFoundError(f"model_cost.json not found at {model_cost_path}") | ||
|
|
||
| with open(model_cost_path, "r", encoding="utf-8") as f: | ||
| model_cost = json.load(f) | ||
|
|
||
| all_models = list(model_cost.keys()) | ||
|
|
||
| if specified_models: | ||
| # Validate specified models exist in model_cost.json | ||
| invalid_models = [m for m in specified_models if m not in all_models] | ||
| if invalid_models: | ||
| logger.warning( | ||
| f"These models were not found in model_cost.json and will be skipped: {invalid_models}" | ||
| ) | ||
|
|
||
| models = [m for m in specified_models if m in all_models] | ||
| logger.info( | ||
| f"Processing {len(models)} specified models (out of {len(all_models)} available)" | ||
| ) | ||
| else: | ||
| models = all_models | ||
| logger.info(f"Processing all {len(models)} models from model_cost.json") | ||
|
|
||
| return models | ||
|
|
||
|
|
||
| def main(): | ||
| """Main function to handle batch inference.""" | ||
| parser = argparse.ArgumentParser( | ||
| description="Batch LLM Inference - Process multiple models from model_cost.json", | ||
| formatter_class=argparse.RawDescriptionHelpFormatter, | ||
| epilog=""" | ||
| Examples: | ||
| # Process all models from model_cost.json with 16 workers per model | ||
| uv run python llm_inference/batch_inference.py --num-workers 16 | ||
|
|
||
| # Process all models with 2 runs per query | ||
| uv run python llm_inference/batch_inference.py \\ | ||
| --num-workers 16 \\ | ||
| --num-runs 2 | ||
|
|
||
| # Process specific models only with 3 runs per query | ||
| uv run python llm_inference/batch_inference.py \\ | ||
| --models gemini-2.0-flash-001 gpt-5-mini \\ | ||
| --num-workers 16 \\ | ||
| --num-runs 3 | ||
|
|
||
| # Process with custom cache directory and 8 workers | ||
| uv run python llm_inference/batch_inference.py \\ | ||
| --cache-dir ./my_cache \\ | ||
| --num-workers 8 | ||
| """, | ||
| ) | ||
|
|
||
| parser.add_argument( | ||
| "--num-workers", | ||
| type=int, | ||
| default=16, | ||
| help="Number of parallel workers per model (default: 16)", | ||
| ) | ||
| parser.add_argument( | ||
| "--num-runs", | ||
| type=int, | ||
| default=1, | ||
| help="Target number of successful inference runs per query (default: 1)", | ||
| ) | ||
| parser.add_argument( | ||
| "--models", | ||
| nargs="+", | ||
| help="Specific models to process (default: all models from model_cost.json)", | ||
| ) | ||
| parser.add_argument( | ||
| "--cache-dir", | ||
| default="./cached_results", | ||
| help="Directory where cached results are stored (default: ./cached_results)", | ||
| ) | ||
| parser.add_argument( | ||
| "--model-cost-path", | ||
| default="./model_cost/model_cost.json", | ||
| help="Path to model_cost.json (default: ./model_cost/model_cost.json)", | ||
| ) | ||
| parser.add_argument( | ||
| "--input-file", | ||
| default="./llm_inference/datasets/router_data.json", | ||
| help="Path to input data file (default: ./llm_inference/datasets/router_data.json)", | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
|
|
||
| try: | ||
| # Note: Working directory already changed to project root at module load time | ||
| start_time = datetime.datetime.now() | ||
|
|
||
| logger.info("\n" + "=" * 80) | ||
| logger.info("BATCH INFERENCE STARTING") | ||
| logger.info("=" * 80) | ||
| logger.info(f"Start time: {start_time.strftime('%Y-%m-%d %H:%M:%S')}") | ||
| logger.info(f"Workers per model: {args.num_workers}") | ||
| logger.info(f"Target runs per query: {args.num_runs}") | ||
| logger.info(f"Cache directory: {args.cache_dir}") | ||
| logger.info(f"Input file: {args.input_file}") | ||
| logger.info("=" * 80 + "\n") | ||
|
|
||
| # Validate input file exists | ||
| if not os.path.exists(args.input_file): | ||
| raise FileNotFoundError( | ||
| f"Input file not found: {args.input_file}\n" | ||
| "Please run: uv run python scripts/process_datasets/prep_datasets.py" | ||
| ) | ||
|
|
||
| # Load models to process | ||
| models = load_model_list_from_cost_file( | ||
| model_cost_path=args.model_cost_path, specified_models=args.models | ||
| ) | ||
|
|
||
| if not models: | ||
| logger.error("No models to process!") | ||
| return 1 | ||
|
|
||
| logger.info(f"Models to process: {models}\n") | ||
|
|
||
| # Initialize parallel inference manager | ||
| manager = ParallelInferenceManager( | ||
| cache_dir=args.cache_dir, workers=args.num_workers | ||
| ) | ||
|
|
||
| # Load input data once (will be reused for all models) | ||
| data = manager.load_input_data(args.input_file) | ||
| logger.info(f"Loaded {len(data)} queries from input file\n") | ||
|
|
||
| # Process all models sequentially | ||
| all_stats = manager.process_all_models( | ||
| models=models, | ||
| data=data, | ||
| num_workers=args.num_workers, | ||
| num_runs=args.num_runs, | ||
| ) | ||
|
|
||
| # Final summary | ||
| end_time = datetime.datetime.now() | ||
| duration = (end_time - start_time).total_seconds() / 60 | ||
|
|
||
| logger.info("\n" + "=" * 80) | ||
| logger.info("BATCH INFERENCE COMPLETED") | ||
| logger.info("=" * 80) | ||
| logger.info(f"End time: {end_time.strftime('%Y-%m-%d %H:%M:%S')}") | ||
| logger.info(f"Total duration: {duration:.1f} minutes") | ||
| logger.info("=" * 80) | ||
|
|
||
| # Summary statistics | ||
| total_processed = sum(s["processed"] for s in all_stats.values()) | ||
| total_successful = sum(s["successful"] for s in all_stats.values()) | ||
| total_failed = sum(s["failed"] for s in all_stats.values()) | ||
|
|
||
| logger.info("\nSummary Statistics:") | ||
| logger.info(f" Models processed: {len(models)}") | ||
| logger.info(f" Total queries processed: {total_processed}") | ||
| logger.info(f" Total successful: {total_successful}") | ||
| logger.info(f" Total failed: {total_failed}") | ||
|
|
||
| if total_processed > 0: | ||
| success_rate = (total_successful / total_processed) * 100 | ||
| logger.info(f" Success rate: {success_rate:.1f}%") | ||
|
|
||
| logger.info("=" * 80 + "\n") | ||
|
|
||
| return 0 | ||
|
|
||
| except Exception as e: | ||
| logger.error(f"Error in batch inference: {e}", exc_info=True) | ||
| return 1 | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| sys.exit(main()) | ||
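The summary step at the end of `main()` can be exercised on its own. The sketch below assumes, as the script does, that the per-model statistics are a dict mapping each model name to a dict with `processed`, `successful`, and `failed` counts; the `summarize` helper and the example numbers are hypothetical and exist only for illustration.

```python
def summarize(all_stats):
    """Aggregate per-model counters into totals plus an overall success rate."""
    totals = {
        key: sum(stats[key] for stats in all_stats.values())
        for key in ("processed", "successful", "failed")
    }
    processed = totals["processed"]
    # Guard against division by zero when every query was already cached.
    totals["success_rate"] = (
        totals["successful"] / processed * 100 if processed > 0 else None
    )
    return totals


# Made-up example statistics for two placeholder models.
example_stats = {
    "model-a": {"processed": 200, "successful": 190, "failed": 10},
    "model-b": {"processed": 100, "successful": 95, "failed": 5},
}
summary = summarize(example_stats)
print(summary)
```

Returning `None` for the rate when nothing was processed mirrors the script's `if total_processed > 0` guard, which simply skips the success-rate line in that case.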