diff --git a/.gitignore b/.gitignore index 917e1120..99ef81c0 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,4 @@ genai_bench*.log # MkDocs site/ .cache/ +together_text-to-text_*/ diff --git a/README.md b/README.md index e9442f78..f7c40877 100644 --- a/README.md +++ b/README.md @@ -40,10 +40,63 @@ It provides detailed insights into model serving performance, offering both a us - 📝 **Rich Logs**: Automatically flushed to both terminal and file upon experiment completion. - 📈 **Experiment Analyzer**: Generates comprehensive Excel reports with pricing and raw metrics data, plus flexible plot configurations (default 2x4 grid) that visualize key performance metrics including throughput, latency (TTFT, E2E, TPOT), error rates, and RPS across different traffic scenarios and concurrency levels. Supports custom plot layouts and multi-line comparisons. +- 🧪 **Synthetic Tore-style prompts (optional)**: Generate synthetic requests that mimic tore-speed’s synthetic dataset prep, including a cached prefix region and exact input/output token counts for precise performance experiments. + +### Open-loop QPS mode (non-Locust) + +- Enable with `--non-locust` to use an open-loop arrival process (tore-speed style). Arrivals are scheduled globally by inter-arrival intervals; completions may lag depending on server speed. +- Use `--qps-level` (repeatable; floats allowed) to specify QPS levels and `--qps-distribution` (uniform|exponential|constant) for inter-arrival sampling. +- Duration of each level comes from `--max-time-per-run` (in minutes; floats allowed). Internally converted to seconds. +- Example (tore-speed compatible synthetic run): + +```bash +genai-bench benchmark \ + --non-locust \ + --qps-level 0.1 --qps-level 0.3 \ + --qps-distribution uniform \ + --max-requests-per-run 1500 --max-time-per-run 2 \ + --api-backend together --api-base https://api.together.xyz \ + --api-model-name --model-tokenizer \ + --task text-to-text \ + --traffic-scenario "D(10000,825)" \ + --synthetic --synthetic-cached-input-length 3000 +``` + +Notes: +- Arrival rate (QPS) is the planned schedule; observed RPS depends on completions within the time window. +- In synthetic mode, dataset file loading is skipped; prompts are constructed to exact token counts with a cached prefix region matching tore-speed semantics. + ## How to Start Please check [User Guide](https://docs.sglang.ai/genai-bench/user-guide/) and [CONTRIBUTING.md](https://docs.sglang.ai/genai-bench/development/contributing/) for how to install and use genai-bench. +### Synthetic data mode (tore-speed compatible) + +Genai-bench can synthesize prompts similar to tore-speed’s `--dataset_type synthetic`, with a fixed-size cached prefix and exact token counts enforced at the tokenizer level. + +- Enable with the `--synthetic` flag and provide a deterministic traffic scenario for input/output tokens (e.g., `D(10000,825)`). +- Specify the cached prefix size (in tokens) with `--synthetic-cached-input-length`. + +Example (concurrency mode): + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-model-name \ + --model-tokenizer \ + --task text-to-text \ + --traffic-scenario "D(10000,825)" \ + --max-requests-per-run 1500 --max-time-per-run 2 \ + --num-concurrency 128 --spawn-rate 128 \ + --synthetic --synthetic-cached-input-length 3000 \ + --additional-request-params '{"stream": true}' +``` + +Notes: +- The sampler ensures the prompt contains exactly the requested number of input tokens. 
The leading `--synthetic-cached-input-length` tokens are filled with a repeated base phrase to emulate a cacheable prefix; a unique marker and a long instruction are appended to the uncached suffix region. +- This is useful for cache stress tests and apples-to-apples comparisons with tore-speed’s synthetic mode. + ## Benchmark Metrics Definition This section puts together the standard metrics required for LLM serving performance analysis. We classify metrics to two types: **single-request level metrics**, representing the metrics collected from one request. And **aggregated level metrics**, summarizing the single-request metrics from one run (with specific traffic scenario and num concurrency). diff --git a/TOGETHER_AI_INTEGRATION.md b/TOGETHER_AI_INTEGRATION.md new file mode 100644 index 00000000..323caca0 --- /dev/null +++ b/TOGETHER_AI_INTEGRATION.md @@ -0,0 +1,138 @@ +# Together AI Integration + +This document describes the Together AI backend integration for genai-bench. + +## Overview + +The Together AI backend has been fully integrated into genai-bench, allowing you to benchmark models hosted on Together AI's platform. + +## Features + +- **Chat Completions**: Support for text-to-text and image-text-to-text tasks +- **Embeddings**: Support for text-to-embeddings tasks +- **Streaming**: Full support for streaming responses +- **Authentication**: API key-based authentication + +## Usage + +### Basic Usage + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-key YOUR_TOGETHER_API_KEY \ + --api-model-name meta-llama/Llama-2-7b-chat-hf \ + --task text-to-text \ + --num-concurrency 1,2,4,8 \ + --batch-size 1,2,4 \ + --dataset-path /path/to/your/dataset.json +``` + +### Environment Variables + +You can also set the API key via environment variable: + +```bash +export TOGETHER_API_KEY=your_api_key_here +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-model-name meta-llama/Llama-2-7b-chat-hf \ + --task text-to-text \ + # ... other options +``` + +### Supported Models + +Together AI supports a wide range of models. Some popular options include: + +- `meta-llama/Llama-2-7b-chat-hf` +- `meta-llama/Llama-2-13b-chat-hf` +- `meta-llama/Llama-2-70b-chat-hf` +- `mistralai/Mistral-7B-Instruct-v0.1` +- `togethercomputer/RedPajama-INCITE-Chat-3B-v1` +- And many more... + +### Supported Tasks + +- `text-to-text`: Standard chat completions +- `image-text-to-text`: Multimodal chat with images +- `text-to-embeddings`: Text embedding generation + +## Implementation Details + +### Files Added/Modified + +1. **User Implementation**: `genai_bench/user/together_user.py` + - Implements `TogetherUser` class extending `BaseUser` + - Supports chat completions and embeddings + - Handles streaming responses + +2. **Authentication**: `genai_bench/auth/together/` + - `auth.py`: Basic Together AI authentication + - `model_auth_adapter.py`: Adapter for model authentication + +3. **CLI Integration**: + - Added "together" to supported backends in `option_groups.py` + - Added together backend handling in `cli.py` + - Added TogetherUser to validation mapping + +### API Compatibility + +The Together AI backend uses OpenAI-compatible API endpoints: +- Chat completions: `/v1/chat/completions` +- Embeddings: `/v1/embeddings` + +This ensures compatibility with existing benchmarking scenarios and metrics collection. 
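
Because the endpoints are OpenAI-compatible, you can sanity-check connectivity outside genai-bench with a plain streaming HTTP call. The sketch below is illustrative only — the prompt and `max_tokens` are placeholder values, and it assumes `TOGETHER_API_KEY` is exported — and it mirrors how `TogetherUser.parse_chat_response` consumes `data: `-prefixed SSE chunks up to the `[DONE]` sentinel:

```python
import json
import os
import time

import requests

API_BASE = "https://api.together.xyz"
API_KEY = os.environ["TOGETHER_API_KEY"]

# Placeholder request: the model is taken from the supported-models list above;
# the prompt and max_tokens are arbitrary values for a quick smoke test.
payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 32,
    "stream": True,
    "stream_options": {"include_usage": True},
}

start = time.monotonic()
ttft = None
text = ""
with requests.post(
    f"{API_BASE}/v1/chat/completions",
    json=payload,
    headers={"Authorization": f"Bearer {API_KEY}"},
    stream=True,
) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines():
        # The stream is served as SSE: "data: {...}" chunks ending with "data: [DONE]".
        if not line or not line.startswith(b"data: "):
            continue
        chunk = line[len(b"data: "):]
        if chunk == b"[DONE]":
            break
        data = json.loads(chunk)
        choices = data.get("choices") or []
        piece = choices[0].get("delta", {}).get("content") if choices else None
        if piece:
            if ttft is None:
                ttft = time.monotonic() - start  # time to first token
            text += piece

print("TTFT (s):", ttft)
print("Output:", text)
```

If this prints a first-token latency and some generated text, the same credentials and `--api-base` will work for the full benchmark commands below.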
+ +## Example Commands + +### Text-to-Text Benchmarking + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-key $TOGETHER_API_KEY \ + --api-model-name meta-llama/Llama-2-7b-chat-hf \ + --task text-to-text \ + --num-concurrency 1,2,4,8,16 \ + --batch-size 1,2,4,8 \ + --dataset-path examples/dataset_configs/huggingface_simple.json +``` + +### Embeddings Benchmarking + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-key $TOGETHER_API_KEY \ + --api-model-name togethercomputer/RedPajama-INCITE-Chat-3B-v1 \ + --task text-to-embeddings \ + --num-concurrency 1,2,4,8 \ + --batch-size 1,2,4,8 \ + --dataset-path examples/dataset_configs/huggingface_simple.json +``` + +### Multimodal Benchmarking + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-key $TOGETHER_API_KEY \ + --api-model-name meta-llama/Llama-2-7b-chat-hf \ + --task image-text-to-text \ + --num-concurrency 1,2,4 \ + --batch-size 1,2 \ + --dataset-path examples/dataset_configs/config_llava-bench-in-the-wild.json +``` + +## Notes + +- The Together AI backend requires a valid API key from [Together AI](https://together.ai) +- All standard genai-bench features are supported (metrics collection, reporting, etc.) +- The implementation follows the same patterns as other backends for consistency +- Streaming responses are fully supported for accurate latency measurements diff --git a/docs/examples/index.md b/docs/examples/index.md index 2036a353..7da9bce1 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -3,6 +3,27 @@ This section provides practical examples and configurations for GenAI Bench. ## Quick Examples +### Open-loop QPS (non-Locust) — tore-speed style + +Use an open-loop arrival process that schedules requests by inter-arrival times. + +```bash +genai-bench benchmark \ + --non-locust \ + --qps-level 0.1 --qps-level 0.3 \ + --qps-distribution uniform \ + --max-requests-per-run 1500 --max-time-per-run 2 \ + --api-backend together --api-base https://api.together.xyz \ + --api-model-name --model-tokenizer \ + --task text-to-text \ + --traffic-scenario "D(10000,825)" \ + --synthetic --synthetic-cached-input-length 3000 +``` + +Notes: +- `--max-time-per-run` is in minutes (floats allowed); internally converted to seconds. It also drives the open-loop schedule duration per level. +- Arrival rate (QPS) sets the schedule; completion-based metrics (RPS) reflect how many finished within the window. + ### OpenAI GPT-4 Benchmark @@ -75,6 +96,24 @@ GenAI Bench supports various traffic patterns: - `N(480,240)/(300,150)` - Normal distribution - `U(50,100)/(200,250)` - Uniform distribution +### Synthetic Tore-style Prompts + +To mimic tore-speed’s synthetic dataset with a cached prefix and exact token counts: + +```bash +genai-bench benchmark \ + --api-backend together \ + --api-base https://api.together.xyz \ + --api-model-name \ + --model-tokenizer \ + --task text-to-text \ + --traffic-scenario "D(10000,825)" \ + --synthetic --synthetic-cached-input-length 3000 \ + --max-requests-per-run 1500 --max-time-per-run 2 +``` + +This constructs prompts with a leading 3000-token cacheable region and a unique uncached suffix, matching tore-speed synthetic behavior. 
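
To make the construction concrete, the sketch below approximates the token-assembly logic used by the sampler in `genai_bench/sampling/text.py`: a repeated base phrase fills the cached prefix, a unique per-request marker keeps the uncached remainder distinct, and filler tokens bring the prompt to the exact target length. This is a simplified illustration — the `gpt2` tokenizer is only an example, and the long-instruction tail and padding handled by the real sampler are omitted:

```python
from typing import List

from transformers import AutoTokenizer


def build_synthetic_prompt(
    tokenizer, target_tokens: int, cached_tokens: int, request_id: int
) -> str:
    """Assemble a prompt with an exact token budget and a shared cached prefix."""
    base = tokenizer.encode("hi,", add_special_tokens=False)
    if not base:
        base = tokenizer.encode(".", add_special_tokens=False)

    def repeat_to(length: int) -> List[int]:
        out: List[int] = []
        while len(out) < length:
            out.extend(base)
        return out[:length]

    # Cached region: identical across requests so the server can reuse its prefix cache.
    prefix = repeat_to(min(cached_tokens, target_tokens))
    # Unique marker: keeps the uncached remainder distinct per request.
    marker = tokenizer.encode(f"req-{request_id}", add_special_tokens=False)
    marker = marker[: max(0, target_tokens - len(prefix))]
    # Filler: pad with the base phrase so the total is exactly target_tokens.
    suffix = repeat_to(target_tokens - len(prefix) - len(marker))
    return tokenizer.decode(prefix + marker + suffix, skip_special_tokens=True)


tok = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer; use your model's tokenizer
prompt = build_synthetic_prompt(tok, target_tokens=10000, cached_tokens=3000, request_id=1)
# Decode/re-encode can drift by a few tokens; the real sampler truncates or pads to compensate.
print(len(tok.encode(prompt, add_special_tokens=False)))
```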
+ ### Embedding Scenarios - `E(64)` - 64 tokens per document diff --git a/genai_bench/analysis/excel_report.py b/genai_bench/analysis/excel_report.py index 59d806b9..5f6273f5 100644 --- a/genai_bench/analysis/excel_report.py +++ b/genai_bench/analysis/excel_report.py @@ -107,8 +107,9 @@ def _create_sheet_with_common_layout( sheet.append(row) num_rows += 1 - # Merge GPU Type column cells - merge_cells(sheet, 2, num_rows, 1) + # Merge GPU Type column cells only when there is at least one data row + if num_rows >= 2: + merge_cells(sheet, 2, num_rows, 1) apply_number_format(sheet, exclude_columns=["A", "B", "C"]) column_width_autofit(sheet) @@ -418,9 +419,9 @@ def create_aggregated_metrics_sheet( metrics: AggregatedMetrics = run_data[scenario][iteration][ # type: ignore[call-overload, assignment] "aggregated_metrics" ] - assert isinstance( - metrics, AggregatedMetrics - ), f"Expected AggregatedMetrics, got {type(metrics)}" + assert isinstance(metrics, AggregatedMetrics), ( + f"Expected AggregatedMetrics, got {type(metrics)}" + ) metrics_dict = metrics.model_dump() row = [] for field in metadata_headers: @@ -490,18 +491,19 @@ def create_single_request_metrics_sheet( sheet.append(row) rows_for_scenario += 1 row_for_iteration += 1 - merge_cells( - sheet, - start_row_iteration, - row_for_iteration + start_row_iteration - 1, - 1, - ) - merge_cells( - sheet, - start_row_iteration, - row_for_iteration + start_row_iteration - 1, - 2, - ) + if row_for_iteration >= 1: + merge_cells( + sheet, + start_row_iteration, + row_for_iteration + start_row_iteration - 1, + 1, + ) + merge_cells( + sheet, + start_row_iteration, + row_for_iteration + start_row_iteration - 1, + 2, + ) start_row_iteration += row_for_iteration start_row += rows_for_scenario diff --git a/genai_bench/analysis/flexible_plot_report.py b/genai_bench/analysis/flexible_plot_report.py index 7f9cd27c..4d472d92 100644 --- a/genai_bench/analysis/flexible_plot_report.py +++ b/genai_bench/analysis/flexible_plot_report.py @@ -889,7 +889,7 @@ def validate_plot_config_with_data( plot_spec.x_field, sample_agg_metrics, # type: ignore[arg-type] ): - errors.append(f"Plot {i+1}: Invalid x_field '{plot_spec.x_field}'") + errors.append(f"Plot {i + 1}: Invalid x_field '{plot_spec.x_field}'") # Validate Y field paths (single or multiple) try: @@ -901,22 +901,22 @@ def validate_plot_config_with_data( ): if len(y_field_specs) == 1: errors.append( - f"Plot {i+1}: Invalid y_field '{y_field_spec.field}'" + f"Plot {i + 1}: Invalid y_field '{y_field_spec.field}'" ) else: errors.append( - f"Plot {i+1}: Invalid y_fields[{j}] '{y_field_spec.field}'" + f"Plot {i + 1}: Invalid y_fields[{j}] '{y_field_spec.field}'" ) except Exception as e: - errors.append(f"Plot {i+1}: Error validating Y-fields: {e}") + errors.append(f"Plot {i + 1}: Error validating Y-fields: {e}") # Validate position bounds layout = config.layout row, col = plot_spec.position if row >= layout.rows or col >= layout.cols: errors.append( - f"Plot {i+1}: Position ({row}, {col}) exceeds layout bounds " - f"({layout.rows-1}, {layout.cols-1})" + f"Plot {i + 1}: Position ({row}, {col}) exceeds layout bounds " + f"({layout.rows - 1}, {layout.cols - 1})" ) return errors diff --git a/genai_bench/auth/factory.py b/genai_bench/auth/factory.py index b54d2213..b5ae40df 100644 --- a/genai_bench/auth/factory.py +++ b/genai_bench/auth/factory.py @@ -8,6 +8,7 @@ from genai_bench.auth.oci.session import OCISessionAuth from genai_bench.auth.oci.user_principal import OCIUserPrincipalAuth from genai_bench.auth.openai.auth import 
OpenAIAuth +from genai_bench.auth.together.auth import TogetherAuth class AuthFactory: @@ -25,6 +26,18 @@ def create_openai_auth(api_key: str) -> OpenAIAuth: """ return OpenAIAuth(api_key=api_key) + @staticmethod + def create_together_auth(api_key: str) -> TogetherAuth: + """Create Together authentication provider. + + Args: + api_key (str): Together API key + + Returns: + TogetherAuth: Together auth provider + """ + return TogetherAuth(api_key=api_key) + @staticmethod + def create_oci_auth( + auth_type: str, diff --git a/genai_bench/auth/together/__init__.py b/genai_bench/auth/together/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/genai_bench/auth/together/auth.py b/genai_bench/auth/together/auth.py new file mode 100644 index 00000000..e345a5ba --- /dev/null +++ b/genai_bench/auth/together/auth.py @@ -0,0 +1,44 @@ +import os +from typing import Any, Dict, Optional + +from genai_bench.auth.auth_provider import AuthProvider + + +class TogetherAuth(AuthProvider): + """Together.ai Authentication Provider.""" + + def __init__(self, api_key: Optional[str] = None): + """Initialize Together Auth Provider. + + Args: + api_key (Optional[str]): Together API key. If None, will try to get from + TOGETHER_API_KEY environment variable. + + Raises: + ValueError: If no API key is provided or found in environment + """ + self.api_key = api_key or os.getenv("TOGETHER_API_KEY") + if not self.api_key or not self.api_key.strip(): + raise ValueError( + "Together API key must be provided or set in " + "TOGETHER_API_KEY environment variable" + ) + + def get_config(self) -> Dict[str, Any]: + """Get Together configuration. + + Returns: + Dict[str, Any]: Empty configuration dictionary + as Together doesn't need additional config + """ + return {} + + def get_credentials(self) -> str: + """Get Together API key. + + Returns: + str: Together API key + """ + if not self.api_key: + raise ValueError("Together API key is not set") + return self.api_key diff --git a/genai_bench/auth/together/model_auth_adapter.py b/genai_bench/auth/together/model_auth_adapter.py new file mode 100644 index 00000000..c99b4259 --- /dev/null +++ b/genai_bench/auth/together/model_auth_adapter.py @@ -0,0 +1,58 @@ +"""Together model authentication adapter for backward compatibility.""" + +from typing import Any, Dict + +from genai_bench.auth.model_auth_provider import ModelAuthProvider +from genai_bench.auth.together.auth import TogetherAuth + + +class TogetherModelAuthAdapter(ModelAuthProvider): + """Adapter to use existing Together auth as model auth provider.""" + + def __init__(self, together_auth: TogetherAuth): + """Initialize Together model auth adapter. + + Args: + together_auth: Existing Together auth instance + """ + self.together_auth = together_auth + + def get_headers(self) -> Dict[str, str]: + """Get authentication headers for Together API requests. + + Returns: + Dict[str, str]: Headers with Authorization + """ + # Together uses Bearer token in Authorization header + if self.together_auth.api_key: + return {"Authorization": f"Bearer {self.together_auth.api_key}"} + return {} + + def get_config(self) -> Dict[str, Any]: + """Get Together model service configuration. + + Returns: + Dict[str, Any]: Configuration dictionary + """ + return { + "auth_type": self.get_auth_type(), + "has_api_key": bool(self.together_auth.api_key), + } + + def get_auth_type(self) -> str: + """Get the authentication type identifier. 
+ + Returns: + 'api_key' + """ + return "api_key" + + def get_credentials(self) -> str: + """Get Together credentials. + + Returns: + API key string + """ + if self.together_auth.api_key: + return self.together_auth.api_key + return "" diff --git a/genai_bench/auth/unified_factory.py b/genai_bench/auth/unified_factory.py index 0ec8f394..21519afa 100644 --- a/genai_bench/auth/unified_factory.py +++ b/genai_bench/auth/unified_factory.py @@ -21,6 +21,8 @@ from genai_bench.auth.openai.auth import OpenAIAuth from genai_bench.auth.openai.model_auth_adapter import OpenAIModelAuthAdapter from genai_bench.auth.storage_auth_provider import StorageAuthProvider +from genai_bench.auth.together.auth import TogetherAuth +from genai_bench.auth.together.model_auth_adapter import TogetherModelAuthAdapter class UnifiedAuthFactory: @@ -32,7 +34,7 @@ def create_model_auth(provider: str, **kwargs) -> ModelAuthProvider: Args: provider: Provider type ('openai', 'oci', 'aws-bedrock', - 'azure-openai', 'gcp-vertex') + 'azure-openai', 'gcp-vertex', 'together') **kwargs: Provider-specific arguments Returns: @@ -84,10 +86,15 @@ def create_model_auth(provider: str, **kwargs) -> ModelAuthProvider: api_key=kwargs.get("api_key"), ) + elif provider == "together": + api_key = kwargs.get("api_key") + together_auth = TogetherAuth(api_key=api_key) + return TogetherModelAuthAdapter(together_auth) + else: raise ValueError( f"Unsupported model provider: {provider}. " - f"Supported: openai, oci, aws-bedrock, azure-openai, gcp-vertex" + f"Supported: openai, oci, aws-bedrock, azure-openai, gcp-vertex, together" ) @staticmethod diff --git a/genai_bench/cli/cli.py b/genai_bench/cli/cli.py index 1ff09985..bc26c20c 100644 --- a/genai_bench/cli/cli.py +++ b/genai_bench/cli/cli.py @@ -18,6 +18,7 @@ from genai_bench.cli.option_groups import ( api_options, distributed_locust_options, + open_loop_options, experiment_options, model_auth_options, object_storage_options, @@ -64,6 +65,7 @@ def cli(ctx): @experiment_options @sampling_options @distributed_locust_options +@open_loop_options @object_storage_options @storage_auth_options @click.pass_context @@ -117,9 +119,21 @@ def benchmark( dataset_config, dataset_prompt_column, dataset_image_column, + # Synthetic Tore-style options (added via sampling_options) + synthetic, + synthetic_input_length, + synthetic_input_length_stdev, + synthetic_output_length, + synthetic_output_length_stdev, + synthetic_cached_input_length, num_workers, master_port, spawn_rate, + # Open-loop options + non_locust, + qps_level, + qps_distribution, + random_seed, upload_results, namespace, # Storage auth options @@ -221,6 +235,10 @@ def benchmark( } ) + elif api_backend == "together": + # Together uses API key for authentication + auth_kwargs["api_key"] = model_api_key or api_key + elif api_backend in ["vllm", "sglang"]: # vLLM and SGLang use OpenAI-compatible API auth_kwargs["api_key"] = model_api_key or api_key @@ -274,6 +292,12 @@ def benchmark( dataset_path=dataset_path, prompt_column=dataset_prompt_column, image_column=dataset_image_column, + synthetic=ctx.params.get("synthetic", False), + synthetic_input_length=ctx.params.get("synthetic_input_length"), + synthetic_input_length_stdev=ctx.params.get("synthetic_input_length_stdev"), + synthetic_output_length=ctx.params.get("synthetic_output_length"), + synthetic_output_length_stdev=ctx.params.get("synthetic_output_length_stdev"), + synthetic_cached_input_length=ctx.params.get("synthetic_cached_input_length"), ) # Load data using the factory @@ -344,43 +368,55 @@ def 
benchmark( ) experiment_metadata_file.write_text(experiment_metadata.model_dump_json(indent=4)) - # Initialize environment - environment = Environment(user_classes=[user_class]) - # Assign the selected task to the user class - environment.user_classes[0].tasks = [user_task] - environment.sampler = sampler - - # Set up distributed runner - config = DistributedConfig( - num_workers=num_workers, - master_port=master_port, - ) - runner = DistributedRunner( - environment=environment, - config=config, - dashboard=dashboard, - ) - runner.setup() + if not non_locust: + # Initialize environment + environment = Environment(user_classes=[user_class]) + # Assign the selected task to the user class + environment.user_classes[0].tasks = [user_task] + environment.sampler = sampler + + # Set up distributed runner + config = DistributedConfig( + num_workers=num_workers, + master_port=master_port, + ) + runner = DistributedRunner( + environment=environment, + config=config, + dashboard=dashboard, + ) + runner.setup() + else: + # Non-Locust open-loop mode uses an in-process metrics collector + from genai_bench.metrics.aggregated_metrics_collector import ( + AggregatedMetricsCollector, + ) + aggregated_metrics_collector = AggregatedMetricsCollector() # Worker process doesn't need to run the main benchmark flow as it only # sends requests and collects response - if num_workers > 0 and isinstance(environment.runner, WorkerRunner): + if (not non_locust) and num_workers > 0 and isinstance(environment.runner, WorkerRunner): return # Get metrics collector from runner for master/local mode - if not runner.metrics_collector: - raise RuntimeError("Metrics collector not initialized") - aggregated_metrics_collector = runner.metrics_collector - - # Iterate over each scenario_str and concurrency level, - # and run the experiment - iteration_values = batch_size if iteration_type == "batch_size" else num_concurrency + if not non_locust: + if not runner.metrics_collector: + raise RuntimeError("Metrics collector not initialized") + aggregated_metrics_collector = runner.metrics_collector + + # Iterate over each scenario_str and concurrency/QPS level, and run the experiment + if non_locust and qps_level: + iteration_values = list(qps_level) + iteration_type = "num_concurrency" + else: + iteration_values = batch_size if iteration_type == "batch_size" else num_concurrency total_runs = len(traffic_scenario) * len(iteration_values) with dashboard.live: for scenario_str in traffic_scenario: dashboard.reset_plot_metrics() sanitized_scenario_str = sanitize_string(scenario_str) - runner.update_scenario(scenario_str) + if not non_locust: + runner.update_scenario(scenario_str) # Store metrics for current scenario for interim plot scenario_metrics = { @@ -399,7 +435,8 @@ def benchmark( ) # Update batch size for each iteration - runner.update_batch_size(batch_size) + if not non_locust: + runner.update_batch_size(batch_size) aggregated_metrics_collector.set_run_metadata( iteration, scenario_str, iteration_type @@ -410,22 +447,48 @@ def benchmark( dashboard.start_run(max_time_per_run, start_time, max_requests_per_run) # Use custom spawn rate if provided, otherwise use concurrency - actual_spawn_rate = ( - spawn_rate if spawn_rate is not None else concurrency - ) - logger.info( - f"Starting benchmark with concurrency={concurrency}, " - f"spawn_rate={actual_spawn_rate}" - ) - environment.runner.start(concurrency, spawn_rate=actual_spawn_rate) + if not non_locust: + actual_spawn_rate = ( + spawn_rate if spawn_rate is not None else concurrency + 
) + logger.info( + f"Starting benchmark with concurrency={concurrency}, " + f"spawn_rate={actual_spawn_rate}" + ) + environment.runner.start(concurrency, spawn_rate=actual_spawn_rate) - total_run_time = manage_run_time( - max_time_per_run=max_time_per_run, - max_requests_per_run=max_requests_per_run, - environment=environment, - ) + total_run_time = manage_run_time( + max_time_per_run=max_time_per_run, + max_requests_per_run=max_requests_per_run, + environment=environment, + ) - environment.runner.stop() + environment.runner.stop() + else: + # Open-loop QPS: treat iteration 'concurrency' as target QPS + from genai_bench.openloop.runner import OpenLoopRunner + ol = OpenLoopRunner( + sampler=sampler, + api_backend=api_backend, + api_base=api_base, + api_model_name=api_model_name, + auth_provider=auth_provider, + aggregated_metrics_collector=aggregated_metrics_collector, + dashboard=dashboard, + ) + logger.info( + f"Starting open-loop run with qps={concurrency}, " + f"duration_s={max_time_per_run}, distribution={qps_distribution}" + ) + total_run_time = ol.run( + qps_level=concurrency, + duration_s=max_time_per_run, + distribution=qps_distribution, + random_seed=random_seed, + max_requests=max_requests_per_run, + max_time_s=None, + scenario=scenario_str, + ) # Aggregate metrics after each run end_time = time.monotonic() @@ -496,7 +559,8 @@ def benchmark( time.sleep(2) # Final cleanup - runner.cleanup() + if not non_locust: + runner.cleanup() # Flash all the logs to terminal if delayed_log_handler: diff --git a/genai_bench/cli/option_groups.py b/genai_bench/cli/option_groups.py index f8ec4b60..583f4edb 100644 --- a/genai_bench/cli/option_groups.py +++ b/genai_bench/cli/option_groups.py @@ -83,6 +83,7 @@ def api_options(func): "aws-bedrock", "azure-openai", "gcp-vertex", + "together", "vllm", "sglang", ], @@ -95,6 +96,35 @@ def api_options(func): "open-source servers.", )(func) return func +def open_loop_options(func): + func = click.option( + "--non-locust", + is_flag=True, + default=False, + help="Use open-loop QPS generator (tore-speed style) instead of Locust.", + )(func) + func = click.option( + "--qps-level", + type=click.FLOAT, + multiple=True, + default=None, + help="Open-loop QPS levels (can be specified multiple times).", + )(func) + func = click.option( + "--qps-distribution", + type=click.Choice(["uniform", "exponential", "constant"], case_sensitive=False), + default="uniform", + help="Interarrival distribution for open-loop mode (default: uniform).", + )(func) + func = click.option( + "--random-seed", + type=int, + default=42, + help="Random seed for interarrival generation (default: 42).", + )(func) + return func + + # Model endpoint authentication options @@ -295,6 +325,43 @@ def sampling_options(func): help="Path to JSON configuration file for advanced dataset options. 
" "This allows full control over dataset loading parameters.", )(func) + # Synthetic Tore-style generation options (optional) + func = click.option( + "--synthetic", + is_flag=True, + default=False, + help="Enable Tore-style synthetic prompt generation.", + )(func) + func = click.option( + "--synthetic-input-length", + type=int, + default=None, + help="Synthetic input length (tokens).", + )(func) + func = click.option( + "--synthetic-input-length-stdev", + type=int, + default=None, + help="Stddev for synthetic input length (tokens).", + )(func) + func = click.option( + "--synthetic-output-length", + type=int, + default=None, + help="Synthetic output length (tokens).", + )(func) + func = click.option( + "--synthetic-output-length-stdev", + type=int, + default=None, + help="Stddev for synthetic output length (tokens).", + )(func) + func = click.option( + "--synthetic-cached-input-length", + type=int, + default=None, + help="Number of input tokens to allocate to cached prefix.", + )(func) func = click.option( "--dataset-image-column", type=str, @@ -397,12 +464,11 @@ def experiment_options(func): )(func) func = click.option( "--max-time-per-run", - type=int, + type=float, required=True, prompt=True, - help="The max duration per experiment run. Unit: minute. " - "One experiment run will exit if max_time_per_run is " - "reached. ", + help="The max duration per experiment run in minutes (floats allowed). " + "Each run exits when this wall-clock limit or max-requests is reached.", )(func) func = click.option( "--warmup-ratio", diff --git a/genai_bench/cli/validation.py b/genai_bench/cli/validation.py index 163fa030..3e09ea3b 100644 --- a/genai_bench/cli/validation.py +++ b/genai_bench/cli/validation.py @@ -16,6 +16,7 @@ from genai_bench.user.oci_cohere_user import OCICohereUser from genai_bench.user.oci_genai_user import OCIGenAIUser from genai_bench.user.openai_user import OpenAIUser +from genai_bench.user.together_user import TogetherUser logger = init_logger(__name__) @@ -27,6 +28,7 @@ AWSBedrockUser.BACKEND_NAME: AWSBedrockUser, AzureOpenAIUser.BACKEND_NAME: AzureOpenAIUser, GCPVertexUser.BACKEND_NAME: GCPVertexUser, + TogetherUser.BACKEND_NAME: TogetherUser, "vllm": OpenAIUser, # vLLM uses OpenAI-compatible API "sglang": OpenAIUser, # SGLang uses OpenAI-compatible API } diff --git a/genai_bench/data/config.py b/genai_bench/data/config.py index 07592dc2..556c27a0 100644 --- a/genai_bench/data/config.py +++ b/genai_bench/data/config.py @@ -70,6 +70,27 @@ class DatasetConfig(BaseModel): description="Overrides pillows internal DDOS protection", ) + # Synthetic Tore-style options (optional) + synthetic: bool = Field( + False, + description="Enable Tore-style synthetic prompt generation.", + ) + synthetic_input_length: Optional[int] = Field( + None, description="Target number of input tokens for synthetic prompts" + ) + synthetic_input_length_stdev: Optional[int] = Field( + None, description="Stddev for input tokens (optional)" + ) + synthetic_output_length: Optional[int] = Field( + None, description="Target number of output tokens for synthetic prompts" + ) + synthetic_output_length_stdev: Optional[int] = Field( + None, description="Stddev for output tokens (optional)" + ) + synthetic_cached_input_length: Optional[int] = Field( + None, description="Number of input tokens to allocate to cached prefix" + ) + @classmethod def from_file(cls, config_path: str) -> "DatasetConfig": """Load configuration from a JSON file.""" @@ -124,4 +145,10 @@ def from_cli_args( image_column=image_column, 
prompt_lambda=None, unsafe_allow_large_images=False, + synthetic=bool(kwargs.get("synthetic", False)), + synthetic_input_length=kwargs.get("synthetic_input_length"), + synthetic_input_length_stdev=kwargs.get("synthetic_input_length_stdev"), + synthetic_output_length=kwargs.get("synthetic_output_length"), + synthetic_output_length_stdev=kwargs.get("synthetic_output_length_stdev"), + synthetic_cached_input_length=kwargs.get("synthetic_cached_input_length"), ) diff --git a/genai_bench/data/loaders/factory.py b/genai_bench/data/loaders/factory.py index b8e066d5..57c6c1a5 100644 --- a/genai_bench/data/loaders/factory.py +++ b/genai_bench/data/loaders/factory.py @@ -40,6 +40,12 @@ def _load_text_data( dataset_config: DatasetConfig, output_modality: str ) -> List[str]: """Load text data.""" + # Synthetic-only path: skip file/HF loading entirely + if bool(getattr(dataset_config, "synthetic", False)): + logger.info( + "Synthetic mode enabled: skipping dataset file loading and returning a minimal placeholder." + ) + return ["synthetic"] loader = TextDatasetLoader(dataset_config) data = loader.load_request() diff --git a/genai_bench/metrics/aggregated_metrics_collector.py b/genai_bench/metrics/aggregated_metrics_collector.py index 6801f58b..e8f4dbd5 100644 --- a/genai_bench/metrics/aggregated_metrics_collector.py +++ b/genai_bench/metrics/aggregated_metrics_collector.py @@ -235,7 +235,12 @@ def aggregate_metrics_data( f"check logs from genai-bench and server!" ) - # Calculate requests per minute + # Total responses (success + error) + self.aggregated_metrics.num_requests = ( + self.aggregated_metrics.num_completed_requests + + self.aggregated_metrics.num_error_requests + ) + # Requests/sec (completed only) self.aggregated_metrics.requests_per_second = ( ( self.aggregated_metrics.num_completed_requests @@ -244,9 +249,14 @@ def aggregate_metrics_data( if self.aggregated_metrics.run_duration > 0 else 0 ) - self.aggregated_metrics.num_requests = ( - self.aggregated_metrics.num_completed_requests - + self.aggregated_metrics.num_error_requests + # Tore-speed style: responses returned per second (success + error) + self.aggregated_metrics.summary_actual_qps = ( + ( + self.aggregated_metrics.num_requests + / self.aggregated_metrics.run_duration + ) + if self.aggregated_metrics.run_duration > 0 + else 0 ) def set_run_metadata( diff --git a/genai_bench/metrics/metrics.py b/genai_bench/metrics/metrics.py index 22fb78d2..70759e40 100644 --- a/genai_bench/metrics/metrics.py +++ b/genai_bench/metrics/metrics.py @@ -157,6 +157,16 @@ class AggregatedMetrics(BaseModel): requests_per_second: float = Field( 0.0, description="The average number of completed requests per second" ) + summary_actual_qps: float = Field( + 0.0, + description="Responses returned per second over the full run (success + error)", + ) + arrival_requests_per_second: float = Field( + 0.0, description="Planned/actual arrival rate (arrivals per second) for the run" + ) + total_arrivals: int = Field( + 0, description="Total number of requests scheduled/launched during the arrival window" + ) error_codes_frequency: Dict[int, int] = Field( default_factory=dict, description="Frequency of error codes" ) diff --git a/genai_bench/metrics/request_metrics_collector.py b/genai_bench/metrics/request_metrics_collector.py index 4201b838..2f6da4a3 100644 --- a/genai_bench/metrics/request_metrics_collector.py +++ b/genai_bench/metrics/request_metrics_collector.py @@ -28,12 +28,12 @@ def calculate_metrics( response (UserResponse): The customized UserResponse object 
containing the response data needed to calculate metrics. """ - assert ( - response.num_prefill_tokens is not None - ), "response.num_prefill_tokens is None" - assert ( - response.time_at_first_token is not None - ), "response.time_at_first_token is None" + assert response.num_prefill_tokens is not None, ( + "response.num_prefill_tokens is None" + ) + assert response.time_at_first_token is not None, ( + "response.time_at_first_token is None" + ) assert response.start_time is not None, "response.start_time is None" assert response.end_time is not None, "response.end_time is None" diff --git a/genai_bench/openloop/runner.py b/genai_bench/openloop/runner.py new file mode 100644 index 00000000..edd61f23 --- /dev/null +++ b/genai_bench/openloop/runner.py @@ -0,0 +1,364 @@ +import asyncio +import orjson +import time +import random +from typing import List, Optional + +import aiohttp + +from genai_bench.logging import init_logger +from genai_bench.metrics.request_metrics_collector import RequestMetricsCollector +from genai_bench.protocol import ( + UserChatRequest, + UserEmbeddingRequest, + UserImageChatRequest, + UserResponse, + UserChatResponse, +) +from genai_bench.scenarios.base import Scenario + + +logger = init_logger(__name__) + + +class OpenLoopRunner: + """ + Open-loop QPS runner that schedules global inter-arrivals (tore-speed style) + and emits RequestLevelMetrics via AggregatedMetricsCollector. + """ + + def __init__( + self, + *, + sampler, + api_backend: str, + api_base: str, + api_model_name: str, + auth_provider, + aggregated_metrics_collector, + dashboard=None, + ) -> None: + self.sampler = sampler + self.api_backend = api_backend + self.api_base = api_base + self.api_model_name = api_model_name + self.auth_provider = auth_provider + self.aggregated = aggregated_metrics_collector + self.dashboard = dashboard + + self.headers = None + if auth_provider and hasattr(auth_provider, "get_credentials"): + token = auth_provider.get_credentials() + if token: + self.headers = { + "Authorization": f"Bearer {token}", + "Content-Type": "application/json", + } + + # AIOHTTP settings aligned with tore-speed + self._aio_timeout = aiohttp.ClientTimeout(total=6 * 60 * 60) + self._aio_read_bufsize = 256 * 1024 + + def _wait_intervals( + self, qps_level: float, duration_s: int, random_seed: int, distribution: str + ) -> List[float]: + mean = 1.0 / qps_level + random.seed(random_seed) + out: List[float] = [] + for _ in range(int(qps_level * duration_s)): + if distribution == "exponential": + out.append(random.expovariate(1.0 / mean)) + elif distribution == "uniform": + out.append(random.uniform(0, 2 * mean)) + elif distribution == "constant": + out.append(mean) + else: + raise ValueError(f"Invalid distribution: {distribution}") + return out + + def _prepare_request(self, scenario_input): + # Accept either a prebuilt Scenario or a scenario string, for parity with Locust path + if isinstance(scenario_input, str): + scenario_obj = Scenario.from_string(scenario_input) + else: + scenario_obj = scenario_input + req = self.sampler.sample(scenario_obj) + return req + + # Removed session reuse; sessions are created per-request to match tore-speed + + async def _send_request(self, req) -> UserResponse: + # Currently implement OpenAI-compatible endpoints for text chat and embeddings + try: + if isinstance(req, (UserChatRequest, UserImageChatRequest)): + endpoint = "/v1/chat/completions" + if isinstance(req, UserImageChatRequest): + text_content = [{"type": "text", "text": req.prompt}] # type: 
ignore[attr-defined] + image_content = [ + {"type": "image_url", "image_url": {"url": image}} # type: ignore[attr-defined] + for image in req.image_content # type: ignore[attr-defined] + ] + content = text_content + image_content + else: + content = req.prompt + + payload = { + "model": req.model, + "messages": [ + { + "role": "user", + "content": content, + } + ], + "max_tokens": req.additional_request_params.get("max_tokens", None) + or req.__dict__.get("max_tokens"), + "temperature": req.additional_request_params.get("temperature", 0.0), + "ignore_eos": req.additional_request_params.get( + "ignore_eos", bool(req.__dict__.get("max_tokens")) + ), + # Force streaming to compute TTFT/TPOT properly + "stream": True, + "stream_options": {"include_usage": True}, + **{k: v for k, v in req.additional_request_params.items() if k not in {"stream"}}, + } + + start_time = time.monotonic() + async with aiohttp.ClientSession( + headers=self.headers, + timeout=self._aio_timeout, + read_bufsize=self._aio_read_bufsize, + ) as session: + async with session.post( + url=f"{self.api_base}{endpoint}", json=payload + ) as resp: + if resp.status != 200: + # Stream entire error body for parity with tore-speed + error_message_bytes = b"" + async for chunk_bytes in resp.content: + error_message_bytes += chunk_bytes + text = error_message_bytes.decode("utf-8") + return UserResponse(status_code=resp.status, error_message=text) + + stream_chunk_prefix = "data: " + end_chunk = b"[DONE]" + + generated_text = "" + tokens_received = 0 + time_at_first_token: Optional[float] = None + finish_reason = None + previous_data = None + num_prompt_tokens = None + + async for raw_line in resp.content: + chunk = (raw_line or b"").strip() + if not chunk: + continue + # Gate on SSE style lines like tore-speed does + if not chunk.startswith(stream_chunk_prefix.encode()): + continue + chunk = chunk[len(stream_chunk_prefix) :] + if chunk.strip() == end_chunk: + break + try: + data = orjson.loads(chunk) + except Exception: + previous_data = chunk + continue + + if data.get("error") is not None: + return UserResponse( + status_code=data["error"].get("code", -1), + error_message=data["error"].get("message", "Unknown error"), + ) + + if (not data.get("choices")) and finish_reason and data.get("usage"): + usage = data["usage"] + num_prompt_tokens = usage.get("prompt_tokens") + tokens_received = usage.get("completion_tokens", 0) + if not time_at_first_token: + time_at_first_token = time.monotonic() + break + + try: + delta = data["choices"][0]["delta"] + content_piece = delta.get("content") or delta.get("reasoning_content") + usage = delta.get("usage") + + if usage: + tokens_received = usage.get("completion_tokens", tokens_received) + if content_piece: + if not time_at_first_token: + time_at_first_token = time.monotonic() + generated_text += content_piece + + finish_reason = data["choices"][0].get("finish_reason", None) + if finish_reason and data.get("usage"): + usage = data["usage"] + num_prompt_tokens = usage.get("prompt_tokens") + tokens_received = usage.get("completion_tokens", tokens_received) + break + except (IndexError, KeyError): + previous_data = data + continue + + previous_data = data + + end_time = time.monotonic() + + if not tokens_received: + tokens_received = self.sampler.get_token_length( + generated_text, add_special_tokens=False + ) + + # Fallback: if server didn't return prompt_tokens in usage, derive from request + if num_prompt_tokens is None: + num_prompt_tokens = getattr(req, "num_prefill_tokens", None) + if 
num_prompt_tokens is None: + num_prompt_tokens = self.sampler.get_token_length( + req.prompt, add_special_tokens=False + ) + + if not time_at_first_token: + time_at_first_token = end_time + + return UserChatResponse( + status_code=200, + generated_text=generated_text, + tokens_received=tokens_received, + time_at_first_token=time_at_first_token, + num_prefill_tokens=num_prompt_tokens, + start_time=start_time, + end_time=end_time, + ) + + elif isinstance(req, UserEmbeddingRequest): + endpoint = "/v1/embeddings" + payload = { + "model": req.model, + "input": req.documents, + **req.additional_request_params, + } + start_time = time.monotonic() + async with aiohttp.ClientSession( + headers=self.headers, + timeout=self._aio_timeout, + read_bufsize=self._aio_read_bufsize, + ) as session: + async with session.post( + url=f"{self.api_base}{endpoint}", json=payload + ) as resp: + end_time = time.monotonic() + if resp.status == 200: + data = await resp.json() + num_prompt_tokens = data.get("usage", {}).get("prompt_tokens") + return UserResponse( + status_code=200, + start_time=start_time, + end_time=end_time, + time_at_first_token=end_time, + num_prefill_tokens=num_prompt_tokens, + ) + else: + # Stream entire error body for parity with tore-speed + error_message_bytes = b"" + async for chunk_bytes in resp.content: + error_message_bytes += chunk_bytes + text = error_message_bytes.decode("utf-8") + return UserResponse(status_code=resp.status, error_message=text) + + else: + return UserResponse(status_code=400, error_message="Unsupported request type") + except aiohttp.ClientConnectionError as e: + return UserResponse(status_code=503, error_message=f"Connection error: {e}") + except asyncio.TimeoutError as e: + return UserResponse(status_code=408, error_message=f"Request timed out: {e}") + except Exception as e: + return UserResponse(status_code=500, error_message=str(e)) + + async def _send_one(self, req) -> None: + response = await self._send_request(req) + # Convert to RequestLevelMetrics and add to collector + collector = RequestMetricsCollector() + if response.status_code == 200: + collector.calculate_metrics(response) + else: + collector.metrics.error_code = response.status_code + collector.metrics.error_message = response.error_message + self.aggregated.add_single_request_metrics(collector.metrics) + # Update dashboard live if available + if self.dashboard is not None: + live = self.aggregated.get_live_metrics() + total_requests = ( + self.aggregated.aggregated_metrics.num_completed_requests + + self.aggregated.aggregated_metrics.num_error_requests + ) + self.dashboard.handle_single_request( + live, total_requests, collector.metrics.error_code + ) + + def run( + self, + *, + qps_level: float, + duration_s: int, + distribution: str, + random_seed: int, + max_requests: Optional[int], + max_time_s: Optional[int], + scenario: str, + ) -> float: + intervals = self._wait_intervals(qps_level, duration_s, random_seed, distribution) + n = len(intervals) + if max_requests is not None: + n = min(n, max_requests) + intervals = intervals[:n] + + prepared = [self._prepare_request(scenario) for _ in range(n)] + + async def produce(): + # Periodic UI tick to advance time-based progress even before first completion + done_flag = {"done": False} + + async def tick_progress(): + if self.dashboard is None: + return + while not done_flag["done"]: + try: + progress = self.dashboard.calculate_time_based_progress() + self.dashboard.update_benchmark_progress_bars(progress) + except Exception: + pass + await 
asyncio.sleep(0.5) + + tick_task = None + if self.dashboard is not None: + tick_task = asyncio.create_task(tick_progress()) + tasks = [] + for wait_s, req in zip(intervals, prepared): + tasks.append(asyncio.create_task(self._send_one(req))) + await asyncio.sleep(wait_s) + if tasks: + await asyncio.gather(*tasks) + if tick_task is not None: + done_flag["done"] = True + # Give one last update chance + await asyncio.sleep(0) + tick_task.cancel() + + start = time.monotonic() + try: + if max_time_s is not None and max_time_s > 0: + asyncio.run(asyncio.wait_for(produce(), timeout=max_time_s)) + else: + asyncio.run(produce()) + except asyncio.TimeoutError: + logger.info("Open-loop run timed out per max_time_s") + end = time.monotonic() + # No shared session to close; each request used its own session + # Record arrivals as an arrival rate metric for this run + arrival_rate = (n / (duration_s if duration_s > 0 else 1)) + self.aggregated.aggregated_metrics.total_arrivals = n + self.aggregated.aggregated_metrics.arrival_requests_per_second = arrival_rate + return end - start + + diff --git a/genai_bench/sampling/text.py b/genai_bench/sampling/text.py index 3f96ae42..3e9b1110 100644 --- a/genai_bench/sampling/text.py +++ b/genai_bench/sampling/text.py @@ -1,4 +1,5 @@ import random +import time from typing import Any, Dict, List, Optional from genai_bench.data.config import DatasetConfig @@ -41,6 +42,12 @@ def __init__( self.data = data self.batch_size = 1 # Default batch size + # Synthetic Tore-style configuration + self.synthetic_enabled = bool(getattr(self.dataset_config, "synthetic", False)) + self.synthetic_cached_tokens = int( + getattr(self.dataset_config, "synthetic_cached_input_length", 0) or 0 + ) + self._synthetic_request_counter = 0 def sample(self, scenario: Optional[Scenario]) -> UserRequest: """ @@ -180,6 +187,67 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str: if not num_input_tokens: return random.choice(self.data) + # Synthetic Tore-style generation path: build exact-length prompt with cached prefix + if self.synthetic_enabled: + target_tokens = int(num_input_tokens) + + # Base phrase to assemble tokens from + base_phrase = "hi," + base_tokens = self.tokenizer.encode(base_phrase, add_special_tokens=False) + if not base_tokens: + # Fallback to a dot if tokenizer strips everything + base_tokens = self.tokenizer.encode(".", add_special_tokens=False) + + def repeat_tokens_to_length(tokens: List[int], length: int) -> List[int]: + if length <= 0: + return [] + repeated: List[int] = [] + while len(repeated) < length: + repeated.extend(tokens) + return repeated[:length] + + # Prefix (cached region) + num_prefix = min(self.synthetic_cached_tokens, target_tokens) + prefix_tokens = repeat_tokens_to_length(base_tokens, num_prefix) + + # Unique marker to differentiate prompts + self._synthetic_request_counter += 1 + marker_text = f"{self._synthetic_request_counter}-{int(time.time()*1000) % 1000000}" + marker_tokens = self.tokenizer.encode(marker_text, add_special_tokens=False) + + # Remaining tokens for suffix + remaining = target_tokens - len(prefix_tokens) - len(marker_tokens) + if remaining < 0: + # If marker overflowed, trim it + marker_tokens = marker_tokens[: max(0, target_tokens - len(prefix_tokens))] + remaining = target_tokens - len(prefix_tokens) - len(marker_tokens) + + # Suffix (uncached region) seeded with a long-instruction then filled by base tokens + tail_text = " Write a very long essay about San Francisco" + tail_tokens = self.tokenizer.encode(tail_text, 
add_special_tokens=False) + suffix_tokens: List[int] = [] + # Prefer to include tail once if it fits + if remaining > 0 and len(tail_tokens) <= remaining: + suffix_tokens.extend(tail_tokens) + remaining -= len(tail_tokens) + if remaining > 0: + suffix_tokens.extend(repeat_tokens_to_length(base_tokens, remaining)) + + full_tokens = prefix_tokens + marker_tokens + suffix_tokens + # Enforce exact length (truncate if any rounding issues) + if len(full_tokens) > target_tokens: + full_tokens = full_tokens[:target_tokens] + elif len(full_tokens) < target_tokens: + pad_token_id = ( + self.tokenizer.pad_token_id + if self.tokenizer.pad_token_id is not None + else (self.tokenizer.eos_token_id or base_tokens[0]) + ) + full_tokens.extend([pad_token_id] * (target_tokens - len(full_tokens))) + + return self.tokenizer.decode(full_tokens, skip_special_tokens=True) + + # Default path: assemble from dataset lines to desired token length data_copy = self.data.copy() prompt = "" left_tokens_to_sample = num_input_tokens diff --git a/genai_bench/scenarios/base.py b/genai_bench/scenarios/base.py index 83906be3..5a61c6a7 100644 --- a/genai_bench/scenarios/base.py +++ b/genai_bench/scenarios/base.py @@ -99,9 +99,9 @@ def from_string(cls, scenario_str: str) -> "Scenario": type_token = match.group(1) if match else scenario_str[0] cls.validate(scenario_str) scenario_class = cls._registry.get(type_token) - assert ( - scenario_class is not None - ), "scenario_class should not be None at this step" + assert scenario_class is not None, ( + "scenario_class should not be None at this step" + ) # Pass the parameter substring (if any) to parser params_str = scenario_str[len(type_token) :] return scenario_class.parse(params_str) diff --git a/genai_bench/user/aws_bedrock_user.py b/genai_bench/user/aws_bedrock_user.py index ec985e81..bac22286 100644 --- a/genai_bench/user/aws_bedrock_user.py +++ b/genai_bench/user/aws_bedrock_user.py @@ -45,8 +45,7 @@ def on_start(self): from botocore.config import Config except ImportError as e: raise ImportError( - "boto3 is required for AWS Bedrock. " - "Install it with: pip install boto3" + "boto3 is required for AWS Bedrock. 
Install it with: pip install boto3" ) from e # Get credentials from auth provider diff --git a/genai_bench/user/together_user.py b/genai_bench/user/together_user.py new file mode 100644 index 00000000..51751162 --- /dev/null +++ b/genai_bench/user/together_user.py @@ -0,0 +1,392 @@ +"""Customized user for Together backends.""" + +from locust import task + +import json +import random +import time +from typing import Any, Callable, Dict, Optional + +import requests +from requests import Response + +from genai_bench.auth.model_auth_provider import ModelAuthProvider +from genai_bench.logging import init_logger +from genai_bench.protocol import ( + UserChatRequest, + UserChatResponse, + UserEmbeddingRequest, + UserImageChatRequest, + UserResponse, +) +from genai_bench.user.base_user import BaseUser + +logger = init_logger(__name__) + + +class TogetherUser(BaseUser): + BACKEND_NAME = "together" + supported_tasks = { + "text-to-text": "chat", + "image-text-to-text": "chat", + "text-to-embeddings": "embeddings", + # Future support can be added here + } + + host: Optional[str] = None + auth_provider: Optional[ModelAuthProvider] = None + headers = None + + def on_start(self): + if not self.host or not self.auth_provider: + raise ValueError("API key and base must be set for TogetherUser.") + self.headers = { + "Authorization": f"Bearer {self.auth_provider.get_credentials()}", + "Content-Type": "application/json", + } + super().on_start() + + @task + def chat(self): + endpoint = "/v1/chat/completions" + user_request = self.sample() + + if not isinstance(user_request, UserChatRequest): + raise AttributeError( + f"user_request should be of type " + f"UserChatRequest for TogetherUser.chat, got " + f"{type(user_request)}" + ) + + if isinstance(user_request, UserImageChatRequest): + text_content = [{"type": "text", "text": user_request.prompt}] + image_content = [ + { + "type": "image_url", + "image_url": {"url": image}, + } + for image in user_request.image_content + ] + content = text_content + image_content + else: + # Backward compatibility for vLLM versions prior to v0.5.1. + # OpenAI API used a different text prompt format before + # multi-modality model support. + content = user_request.prompt + + payload = { + "model": user_request.model, + "messages": [ + { + "role": "user", + "content": content, + } + ], + "max_tokens": user_request.max_tokens, + "temperature": user_request.additional_request_params.get( + "temperature", 0.0 + ), + "ignore_eos": user_request.additional_request_params.get( + "ignore_eos", + bool(user_request.max_tokens), + ), + "stream": True, + "stream_options": { + "include_usage": True, + }, + **user_request.additional_request_params, + } + self.send_request( + True, + endpoint, + payload, + self.parse_chat_response, + user_request.num_prefill_tokens, + ) + + @task + def embeddings(self): + endpoint = "/v1/embeddings" + + user_request = self.sample() + + if not isinstance(user_request, UserEmbeddingRequest): + raise AttributeError( + f"user_request should be of type " + f"UserEmbeddingRequest for TogetherUser." 
+ f"embeddings, got {type(user_request)}" + ) + + random.shuffle(user_request.documents) + payload = { + "model": user_request.model, + "input": user_request.documents, + "encoding_format": user_request.additional_request_params.get( + "encoding_format", "float" + ), + **user_request.additional_request_params, + } + self.send_request(False, endpoint, payload, self.parse_embedding_response) + + def send_request( + self, + stream: bool, + endpoint: str, + payload: Dict[str, Any], + parse_strategy: Callable[..., UserResponse], + num_prefill_tokens: Optional[int] = None, + ) -> UserResponse: + """ + Sends a POST request, handling both streaming and non-streaming + responses. + + Args: + endpoint (str): The API endpoint. + payload (Dict[str, Any]): The JSON payload for the request. + stream (bool): Whether to stream the response. + parse_strategy (Callable[[Response, float], UserResponse]): + The function to parse the response. + num_prefill_tokens (Optional[int]): The num of tokens in the + prefill/prompt. Only need for streaming requests. + + Returns: + UserResponse: A response object containing status and metrics data. + """ + response = None + + try: + start_time = time.monotonic() + response = requests.post( + url=f"{self.host}{endpoint}", + json=payload, + stream=stream, + headers=self.headers, + ) + non_stream_post_end_time = time.monotonic() + + if response.status_code == 200: + metrics_response = parse_strategy( + response, + start_time, + num_prefill_tokens, + non_stream_post_end_time, + ) + else: + metrics_response = UserResponse( + status_code=response.status_code, + error_message=response.text, + ) + except requests.exceptions.ConnectionError as e: + metrics_response = UserResponse( + status_code=503, error_message=f"Connection error: {e}" + ) + except requests.exceptions.Timeout as e: + metrics_response = UserResponse( + status_code=408, error_message=f"Request timed out: {e}" + ) + except requests.exceptions.RequestException as e: + metrics_response = UserResponse( + status_code=500, # Assign a generic 500 + error_message=str(e), + ) + finally: + if response is not None: + response.close() + + self.collect_metrics(metrics_response, endpoint) + return metrics_response + + def parse_chat_response( + self, + response: Response, + start_time: float, + num_prefill_tokens: int, + _: float, + ) -> UserResponse: + """ + Parses a streaming response. + + Args: + response (Response): The response object. + start_time (float): The time when the request was started. + num_prefill_tokens (int): The num of tokens in the prefill/prompt. + _ (float): Placeholder for an unused var, to keep parse_*_response + have the same interface. + + Returns: + UserChatResponse: A response object with metrics and generated text. + """ + stream_chunk_prefix = "data: " + end_chunk = b"[DONE]" + + generated_text = "" + tokens_received = 0 + time_at_first_token = None + finish_reason = None + previous_data = None + num_prompt_tokens = None + for chunk in response.iter_lines(chunk_size=None): + # Caution: Adding logs here can make debug mode unusable. + chunk = chunk.strip() + + if not chunk: + continue + + chunk = chunk[len(stream_chunk_prefix) :] + if chunk == end_chunk: + break + data = json.loads(chunk) + + # Handle streaming error response as OpenAI API server handles it + # differently. 
Some might return 200 first and generate error response + # later in the chunk + if data.get("error") is not None: + return UserResponse( + status_code=data["error"].get("code", -1), + error_message=data["error"].get( + "message", "Unknown error, please check server logs" + ), + ) + + # Standard OpenAI API streams include "finish_reason" + # in the second-to-last chunk, + # followed by "usage" in the final chunk, + # which does not contain "finish_reason" + if ( + not data["choices"] + and finish_reason + and "usage" in data + and data["usage"] + ): + num_prefill_tokens, num_prompt_tokens, tokens_received = ( + self._get_usage_info(data, num_prefill_tokens) + ) + # Additional check for time_at_first_token when the response is + # too short + if not time_at_first_token: + tokens_received = data["usage"].get("completion_tokens", 0) + if tokens_received > 1: + logger.warning( + f"🚨🚨🚨 The first chunk the server returned " + f"has >1 tokens: {tokens_received}. It will " + f"affect the accuracy of time_at_first_token!" + ) + time_at_first_token = time.monotonic() + else: + raise Exception("Invalid Response") + break + + try: + delta = data["choices"][0]["delta"] + content = delta.get("content") or delta.get("reasoning_content") + usage = delta.get("usage") + + if usage: + tokens_received = usage["completion_tokens"] + if content: + if not time_at_first_token: + if tokens_received > 1: + logger.warning( + f"🚨🚨🚨 The first chunk the server returned " + f"has >1 tokens: {tokens_received}. It will " + f"affect the accuracy of time_at_first_token!" + ) + time_at_first_token = time.monotonic() + generated_text += content + + finish_reason = data["choices"][0].get("finish_reason", None) + + # SGLang v0.4.3 to v0.4.7 has finish_reason and usage + # in the last chunk + if finish_reason and "usage" in data and data["usage"]: + num_prefill_tokens, num_prompt_tokens, tokens_received = ( + self._get_usage_info(data, num_prefill_tokens) + ) + break + + except (IndexError, KeyError) as e: + logger.warning( + f"Error processing chunk: {e}, data: {data}, " + f"previous_data: {previous_data}, " + f"finish_reason: {finish_reason}, skipping" + ) + + previous_data = data + + end_time = time.monotonic() + logger.debug( + f"Generated text: {generated_text} \n" + f"Time at first token: {time_at_first_token} \n" + f"Finish reason: {finish_reason}\n" + f"Prompt Tokens: {num_prompt_tokens} \n" + f"Completion Tokens: {tokens_received}\n" + f"Start Time: {start_time}\n" + f"End Time: {end_time}" + ) + + if not tokens_received: + tokens_received = self.environment.sampler.get_token_length( + generated_text, add_special_tokens=False + ) + logger.warning( + "🚨🚨🚨 There is no usage info returned from the model " + "server. Estimated tokens_received based on the model " + "tokenizer." 
+ ) + return UserChatResponse( + status_code=200, + generated_text=generated_text, + tokens_received=tokens_received, + time_at_first_token=time_at_first_token, + num_prefill_tokens=num_prefill_tokens, + start_time=start_time, + end_time=end_time, + ) + + @staticmethod + def _get_usage_info(data, num_prefill_tokens): + num_prompt_tokens = data["usage"]["prompt_tokens"] + tokens_received = data["usage"]["completion_tokens"] + # For vision task + if num_prefill_tokens is None: + # use num_prompt_tokens as prefill to cover image tokens + num_prefill_tokens = num_prompt_tokens + if abs(num_prompt_tokens - num_prefill_tokens) >= 50: + logger.warning( + f"Significant difference detected in prompt tokens: " + f"The number of prompt tokens processed by the model " + f"server ({num_prompt_tokens}) differs from the number " + f"of prefill tokens returned by the sampler " + f"({num_prefill_tokens}) by " + f"{abs(num_prompt_tokens - num_prefill_tokens)} tokens." + ) + return num_prefill_tokens, num_prompt_tokens, tokens_received + + @staticmethod + def parse_embedding_response( + response: Response, start_time: float, _: Optional[int], end_time: float + ) -> UserResponse: + """ + Parses a non-streaming response. + + Args: + response (Response): The response object. + start_time (float): The time when the request was started. + _ (Optional[int]): Placeholder for an unused var, to keep + parse_*_response have the same interface. + end_time(float): The time when the request was finished. + + Returns: + UserResponse: A response object with metrics. + """ + + data = response.json() + num_prompt_tokens = data["usage"]["prompt_tokens"] + + return UserResponse( + status_code=200, + start_time=start_time, + end_time=end_time, + time_at_first_token=end_time, + num_prefill_tokens=num_prompt_tokens, + ) diff --git a/pyproject.toml b/pyproject.toml index 2739c5e6..c4d05b8a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "datasets>=3.1.0", "pillow>=11.1.0,<12.0.0", "huggingface_hub>=0.20.0", + "aiohttp>=3.9.0,<4.0.0", ] [project.scripts] diff --git a/tests/analysis/test_excel_na.py b/tests/analysis/test_excel_na.py index aa9be72c..8e0567c3 100644 --- a/tests/analysis/test_excel_na.py +++ b/tests/analysis/test_excel_na.py @@ -124,9 +124,9 @@ def test_time_unit_conversion_seconds_to_milliseconds(): # Check that TTFT column header shows milliseconds headers = [cell.value for cell in ws[1]] ttft_header = headers[3] # TTFT column - assert "ms" in str( - ttft_header - ), f"Expected TTFT header to show ms, got: {ttft_header}" + assert "ms" in str(ttft_header), ( + f"Expected TTFT header to show ms, got: {ttft_header}" + ) # Check that the actual TTFT value was converted from 0.5s to 500ms ttft_value = ws[2][3].value # Row 2, column 4 (TTFT value) @@ -134,9 +134,9 @@ def test_time_unit_conversion_seconds_to_milliseconds(): # Check that e2e_latency value was converted from 1.0s to 1000ms e2e_latency_value = ws[2][6].value # Row 2, column 7 (e2e_latency value) - assert ( - e2e_latency_value == 1000.0 - ), f"Expected e2e_latency value 1000.0ms, got: {e2e_latency_value}" + assert e2e_latency_value == 1000.0, ( + f"Expected e2e_latency value 1000.0ms, got: {e2e_latency_value}" + ) def test_time_unit_conversion_milliseconds_to_seconds(): @@ -170,9 +170,9 @@ def test_time_unit_conversion_milliseconds_to_seconds(): # Check that TTFT column header shows seconds headers = [cell.value for cell in ws[1]] ttft_header = headers[3] # TTFT column - assert "s" in str( - ttft_header - ), f"Expected TTFT 
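The streaming parser above assumes the usual OpenAI-style SSE layout: content deltas first, then a chunk carrying `finish_reason`, then a final usage-only chunk whose `choices` list is empty. A minimal sketch of the chunk sequence it expects, with illustrative token counts:

```python
# Illustrative SSE chunk sequence (payloads after the "data: " prefix) that
# parse_chat_response above is written to handle; token counts are examples.
chunks = [
    {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
    {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
    # Second-to-last chunk: carries finish_reason, no usage yet.
    {"choices": [{"delta": {}, "finish_reason": "stop"}]},
    # Final chunk: empty choices plus usage, which supplies authoritative token counts.
    {"choices": [], "usage": {"prompt_tokens": 10, "completion_tokens": 2}},
]
# The stream then terminates with the sentinel line "data: [DONE]".
```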
header to show s, got: {ttft_header}" + assert "s" in str(ttft_header), ( + f"Expected TTFT header to show s, got: {ttft_header}" + ) # Check that the actual TTFT value was converted from 500ms to 0.5s ttft_value = ws[2][3].value # Row 2, column 4 (TTFT value) @@ -180,6 +180,6 @@ def test_time_unit_conversion_milliseconds_to_seconds(): # Check that e2e_latency value was converted from 1000ms to 1.0s e2e_latency_value = ws[2][6].value # Row 2, column 7 (e2e_latency value) - assert ( - e2e_latency_value == 1 - ), f"Expected e2e_latency value 1s, got: {e2e_latency_value}" + assert e2e_latency_value == 1, ( + f"Expected e2e_latency value 1s, got: {e2e_latency_value}" + ) diff --git a/tests/auth/test_auth_factory.py b/tests/auth/test_auth_factory.py index a73be8d6..7c61db58 100644 --- a/tests/auth/test_auth_factory.py +++ b/tests/auth/test_auth_factory.py @@ -6,6 +6,7 @@ from genai_bench.auth.oci.session import OCISessionAuth from genai_bench.auth.oci.user_principal import OCIUserPrincipalAuth from genai_bench.auth.openai.auth import OpenAIAuth +from genai_bench.auth.together.auth import TogetherAuth MOCK_API_KEY = "genai-bench-test-123456789" MOCK_CONFIG_PATH = "~/.oci/config" @@ -21,6 +22,12 @@ def test_create_openai_auth(self): assert isinstance(auth, OpenAIAuth) assert auth.api_key == MOCK_API_KEY + def test_create_together_auth(self): + """Test creating Together auth provider.""" + auth = AuthFactory.create_together_auth(MOCK_API_KEY) + assert isinstance(auth, TogetherAuth) + assert auth.api_key == MOCK_API_KEY + def test_create_oci_user_principal_auth(self): """Test creating OCI user principal auth.""" auth = AuthFactory.create_oci_auth( diff --git a/tests/auth/test_together_auth.py b/tests/auth/test_together_auth.py new file mode 100644 index 00000000..431f4e7e --- /dev/null +++ b/tests/auth/test_together_auth.py @@ -0,0 +1,61 @@ +import pytest + +from genai_bench.auth.auth_provider import AuthProvider +from genai_bench.auth.together.auth import TogetherAuth + +MOCK_API_KEY = "genai-bench-test-123456789" + + +class MockAuthProvider(AuthProvider): + """Mock implementation of AuthProvider for testing.""" + + def get_config(self): + return {} + + def get_credentials(self): + return "mock-credentials" + + +def test_auth_provider_abstract(): + """Test that AuthProvider cannot be instantiated directly.""" + with pytest.raises(TypeError): + AuthProvider() + + +class TestTogetherAuth: + def test_init_with_key(self): + """Test initialization with API key.""" + auth = TogetherAuth(api_key=MOCK_API_KEY) + assert auth.api_key == MOCK_API_KEY + + def test_init_with_env(self, monkeypatch): + """Test initialization with environment variable.""" + monkeypatch.setenv("TOGETHER_API_KEY", MOCK_API_KEY) + auth = TogetherAuth() + assert auth.api_key == MOCK_API_KEY + + def test_init_no_key(self, monkeypatch): + """Test initialization with no API key.""" + monkeypatch.delenv("TOGETHER_API_KEY", raising=False) + with pytest.raises(ValueError): + TogetherAuth() + + def test_init_empty_key(self): + """Test initialization with empty API key.""" + with pytest.raises(ValueError): + TogetherAuth(api_key="") + + def test_init_whitespace_key(self): + """Test initialization with whitespace API key.""" + with pytest.raises(ValueError): + TogetherAuth(api_key=" ") + + def test_get_config(self): + """Test getting Together config.""" + auth = TogetherAuth(api_key=MOCK_API_KEY) + assert auth.get_config() == {} + + def test_get_credentials(self): + """Test getting Together credentials.""" + auth = 
TogetherAuth(api_key=MOCK_API_KEY) + assert auth.get_credentials() == MOCK_API_KEY diff --git a/tests/auth/test_unified_factory.py b/tests/auth/test_unified_factory.py index 06b9d39c..a7748908 100644 --- a/tests/auth/test_unified_factory.py +++ b/tests/auth/test_unified_factory.py @@ -96,6 +96,15 @@ def test_create_gcp_vertex_model_auth(self): assert auth.location == "us-central1" assert auth.credentials_path == "/path/to/creds.json" + def test_create_together_model_auth(self): + """Test creating Together model auth provider.""" + auth = UnifiedAuthFactory.create_model_auth("together", api_key="test_key") + + assert isinstance(auth, ModelAuthProvider) + assert auth.get_auth_type() == "api_key" + creds = auth.get_credentials() + assert creds["api_key"] == "test_key" + def test_create_model_auth_unsupported(self): """Test creating model auth with unsupported provider.""" with pytest.raises(ValueError) as exc_info: diff --git a/tests/openloop/test_arrival_metrics.py b/tests/openloop/test_arrival_metrics.py new file mode 100644 index 00000000..7e4ab8ef --- /dev/null +++ b/tests/openloop/test_arrival_metrics.py @@ -0,0 +1,61 @@ +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_arrival_metrics_recorded(mock_send): + # Quick success stub + async def _ok(self, req): + return UserResponse(status_code=200, start_time=0.0, end_time=0.01, time_at_first_token=0.001, num_prefill_tokens=10) + + mock_send.side_effect = _ok + runner, aggregated = _build_runner() + + # qps=5, duration=2 -> 10 planned arrivals + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=[0.0] * 10): + _ = runner.run( + qps_level=5, + duration_s=2, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=None, + scenario="D(100,10)", + ) + + metrics = aggregated.aggregated_metrics + assert metrics.total_arrivals == 10 + assert abs(metrics.arrival_requests_per_second - 5.0) < 1e-6 + + diff --git a/tests/openloop/test_arrival_pacing.py b/tests/openloop/test_arrival_pacing.py new file mode 100644 index 00000000..bf5cc317 --- /dev/null +++ b/tests/openloop/test_arrival_pacing.py @@ -0,0 +1,69 @@ +import time +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", 
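The tests above pin down the expected TogetherAuth behavior: accept an explicit key or fall back to `TOGETHER_API_KEY`, and reject missing, empty, or whitespace-only keys. The actual `genai_bench/auth/together/auth.py` is not shown in this diff, so the following is only a sketch consistent with those tests:

```python
# Sketch of an auth provider satisfying the tests above; the real TogetherAuth
# (which extends AuthProvider) may differ in details.
import os
from typing import Dict, Optional


class TogetherAuth:
    def __init__(self, api_key: Optional[str] = None) -> None:
        key = api_key if api_key is not None else os.environ.get("TOGETHER_API_KEY")
        if not key or not key.strip():
            raise ValueError("Together API key is required (set TOGETHER_API_KEY).")
        self.api_key = key

    def get_config(self) -> Dict:
        return {}  # no extra configuration beyond the key itself

    def get_credentials(self) -> str:
        return self.api_key
```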
+ num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_arrival_intervals_respected(mock_send): + # Record send times + send_times = [] + + async def _ok(self, req): + send_times.append(time.perf_counter()) + return UserResponse(status_code=200, start_time=0.0, end_time=0.01, time_at_first_token=0.001, num_prefill_tokens=10) + + mock_send.side_effect = _ok + runner, aggregated = _build_runner() + + # Intervals of 0.05s then 0.07s + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=[0.05, 0.07]): + t0 = time.perf_counter() + _ = runner.run( + qps_level=2, + duration_s=1, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=None, + scenario="D(100,10)", + ) + + # We should have two send timestamps. First send is immediate (no pre-wait), then wait ~0.05s + assert len(send_times) == 2 + d1 = send_times[0] - t0 + d2 = send_times[1] - send_times[0] + assert d1 <= 0.01 + assert 0.05 <= d2 <= 0.09 + + diff --git a/tests/openloop/test_midstream_error.py b/tests/openloop/test_midstream_error.py new file mode 100644 index 00000000..4d8999f8 --- /dev/null +++ b/tests/openloop/test_midstream_error.py @@ -0,0 +1,88 @@ +import asyncio +from typing import Tuple + +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserChatResponse, UserResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> Tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +def _ok_chat_resp() -> UserChatResponse: + return UserChatResponse( + status_code=200, + generated_text="abcde", + tokens_received=5, + time_at_first_token=0.02, + num_prefill_tokens=10, + start_time=0.0, + end_time=0.1, + ) + + +def _err_resp() -> UserResponse: + return UserResponse(status_code=500, error_message="mid-stream error") + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_midstream_error_recorded_without_blocking(mock_send): + # First request fails, next two succeed + async def _seq(self, req): + if not hasattr(_seq, "i"): + _seq.i = 0 # type: ignore[attr-defined] + _seq.i += 1 # type: ignore[attr-defined] + await asyncio.sleep(0) + if _seq.i == 1: # type: ignore[attr-defined] + return _err_resp() + return _ok_chat_resp() + + mock_send.side_effect = _seq + runner, aggregated = _build_runner() + + # Three arrivals immediately + with patch.object(OpenLoopRunner, 
"_wait_intervals", return_value=[0.0, 0.0, 0.0]): + total_run_time = runner.run( + qps_level=3, + duration_s=1, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=None, + scenario="D(100,10)", + ) + + assert total_run_time >= 0 + # Two successes, one error + assert aggregated.aggregated_metrics.num_completed_requests == 2 + # Error recorded in frequency map + freq = aggregated.aggregated_metrics.error_codes_frequency + assert 500 in freq and freq[500] == 1 + diff --git a/tests/openloop/test_qps.py b/tests/openloop/test_qps.py new file mode 100644 index 00000000..8227e336 --- /dev/null +++ b/tests/openloop/test_qps.py @@ -0,0 +1,187 @@ +import time +import asyncio +from typing import Any, List + +import pytest +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserResponse + + +class DummyAuth: + def get_credentials(self) -> str: + return "test-token" + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +class DummyResp: + def __init__(self, status_code: int = 200, prompt_tokens: int = 10, completion_tokens: int = 1) -> None: + self.status_code = status_code + self._prompt_tokens = prompt_tokens + self._completion_tokens = completion_tokens + self.text = "OK" + + def json(self) -> Any: + return { + "usage": { + "prompt_tokens": self._prompt_tokens, + "completion_tokens": self._completion_tokens, + } + } + + def close(self) -> None: + return None + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=DummyAuth(), + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +def test_wait_intervals_reproducible_and_count(): + runner, _ = _build_runner() + qps = 10 + duration = 2 + n = qps * duration + intervals_a: List[float] = runner._wait_intervals(qps, duration, random_seed=42, distribution="uniform") + intervals_b: List[float] = runner._wait_intervals(qps, duration, random_seed=42, distribution="uniform") + intervals_c: List[float] = runner._wait_intervals(qps, duration, random_seed=43, distribution="uniform") + assert len(intervals_a) == n + assert intervals_a == intervals_b + assert intervals_a != intervals_c + + +def test_wait_intervals_constant_distribution(): + runner, _ = _build_runner() + qps = 5 + duration = 3 + intervals = runner._wait_intervals(qps, duration, random_seed=123, distribution="constant") + assert all(abs(x - 1.0 / qps) < 1e-9 for x in intervals) + + +def test_wait_intervals_exponential_mean_close(): + runner, _ = _build_runner() + qps = 10 + duration = 100 # enough samples for mean to concentrate + intervals = runner._wait_intervals(qps, duration, random_seed=999, distribution="exponential") + empirical_mean = sum(intervals) / len(intervals) + assert abs(empirical_mean - (1.0 / qps)) < 0.05 # loose tolerance + + +def _ok_resp() -> UserResponse: + return UserResponse( + status_code=200, + start_time=0.0, + end_time=0.1, + time_at_first_token=0.02, + 
num_prefill_tokens=10, + ) + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_run_dispatches_exact_number_of_requests(mock_send): + async def _ok(self, req): + return _ok_resp() + mock_send.side_effect = _ok + runner, aggregated = _build_runner() + qps = 7 + duration = 2 + expected = qps * duration + + # Force zero intervals for a quick run + zero_intervals = [0.0] * expected + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=zero_intervals): + total_run_time = runner.run( + qps_level=qps, + duration_s=duration, + distribution="uniform", + random_seed=42, + max_requests=None, + max_time_s=None, + scenario="D(100,100)", + ) + + assert total_run_time >= 0 + assert len(aggregated.all_request_metrics) == expected + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_run_respects_max_requests(mock_send): + async def _ok(self, req): + return _ok_resp() + mock_send.side_effect = _ok + runner, aggregated = _build_runner() + qps = 50 + duration = 2 + target = qps * duration + max_requests = 30 + + zero_intervals = [0.0] * target + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=zero_intervals): + runner.run( + qps_level=qps, + duration_s=duration, + distribution="uniform", + random_seed=42, + max_requests=max_requests, + max_time_s=None, + scenario="D(100,100)", + ) + + assert len(aggregated.all_request_metrics) == max_requests + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_run_honors_timeout(mock_send): + async def _slow(self, req): + await asyncio.sleep(1.0) + return _ok_resp() + mock_send.side_effect = _slow + runner, aggregated = _build_runner() + qps = 5 + duration = 100 # many intervals + zero_intervals = [0.1] * (qps * 10) + + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=zero_intervals): + start = time.monotonic() + runner.run( + qps_level=qps, + duration_s=duration, + distribution="uniform", + random_seed=42, + max_requests=None, + max_time_s=0.5, # time out early + scenario="D(100,100)", + ) + end = time.monotonic() + + # Should stop in around 0.5s, allow slack + assert (end - start) < 2.0 + # And we should have fewer requests than intervals list length + assert len(aggregated.all_request_metrics) < len(zero_intervals) + + diff --git a/tests/openloop/test_session_lifecycle.py b/tests/openloop/test_session_lifecycle.py new file mode 100644 index 00000000..fa9a397d --- /dev/null +++ b/tests/openloop/test_session_lifecycle.py @@ -0,0 +1,59 @@ +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def 
test_session_per_request_and_metrics(mock_send): + async def _ok(self, req): + return UserResponse(status_code=200, start_time=0.0, end_time=0.01, time_at_first_token=0.001, num_prefill_tokens=10) + + mock_send.side_effect = _ok + + runner, aggregated = _build_runner() + # Two immediate arrivals + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=[0.0, 0.0]): + _ = runner.run( + qps_level=2, + duration_s=1, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=None, + scenario="D(100,10)", + ) + + # Validate two completions were recorded + assert aggregated.aggregated_metrics.num_completed_requests == 2 + + diff --git a/tests/openloop/test_streaming.py b/tests/openloop/test_streaming.py new file mode 100644 index 00000000..ffec36a1 --- /dev/null +++ b/tests/openloop/test_streaming.py @@ -0,0 +1,76 @@ +import asyncio +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserChatResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +def _ok_chat_resp() -> UserChatResponse: + return UserChatResponse( + status_code=200, + generated_text="abcde", + tokens_received=5, + time_at_first_token=0.02, + num_prefill_tokens=10, + start_time=0.0, + end_time=0.1, + ) + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_streaming_ttft_and_tokens(mock_send): + async def _ok(self, req): + # simulate small delay then a successful streaming completion + await asyncio.sleep(0) + return _ok_chat_resp() + + mock_send.side_effect = _ok + runner, aggregated = _build_runner() + + # 2 arrivals immediately + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=[0.0, 0.0]): + total_run_time = runner.run( + qps_level=2, + duration_s=1, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=None, + scenario="D(100,10)", + ) + + assert total_run_time >= 0 + assert aggregated.aggregated_metrics.num_completed_requests >= 1 + m = aggregated.all_request_metrics[0] + assert m.ttft is not None and m.ttft > 0 + assert m.num_output_tokens is not None and m.num_output_tokens > 0 + diff --git a/tests/openloop/test_timeout_semantics.py b/tests/openloop/test_timeout_semantics.py new file mode 100644 index 00000000..11848497 --- /dev/null +++ b/tests/openloop/test_timeout_semantics.py @@ -0,0 +1,64 @@ +import asyncio +import time +from unittest.mock import patch + +from genai_bench.openloop.runner import OpenLoopRunner +from genai_bench.metrics.aggregated_metrics_collector import AggregatedMetricsCollector +from genai_bench.protocol import UserChatRequest, UserResponse + + +class DummySampler: + def __init__(self, model: str = "dummy-model") -> None: + self.model = model + + def sample(self, scenario: str) -> UserChatRequest: + 
return UserChatRequest( + model=self.model, + prompt="Hello", + num_prefill_tokens=10, + max_tokens=10, + additional_request_params={}, + ) + + +def _build_runner() -> tuple[OpenLoopRunner, AggregatedMetricsCollector]: + aggregated = AggregatedMetricsCollector() + runner = OpenLoopRunner( + sampler=DummySampler(), + api_backend="openai", + api_base="https://example.com", + api_model_name="dummy-model", + auth_provider=None, + aggregated_metrics_collector=aggregated, + dashboard=None, + ) + return runner, aggregated + + +@patch.object(OpenLoopRunner, "_send_request", autospec=True) +def test_max_time_timeout(mock_send): + async def _slow(self, req): + await asyncio.sleep(1.0) + return UserResponse(status_code=200, start_time=0.0, end_time=1.0, time_at_first_token=0.5, num_prefill_tokens=10) + + mock_send.side_effect = _slow + runner, aggregated = _build_runner() + + with patch.object(OpenLoopRunner, "_wait_intervals", return_value=[0.0] * 10): + t0 = time.perf_counter() + _ = runner.run( + qps_level=10, + duration_s=10, + distribution="constant", + random_seed=0, + max_requests=None, + max_time_s=0.5, + scenario="D(100,10)", + ) + t1 = time.perf_counter() + + # Should exit near the timeout + assert (t1 - t0) < 2.0 + assert len(aggregated.all_request_metrics) < 10 + + diff --git a/tests/sampling/test_text.py b/tests/sampling/test_text.py index cc185ed9..6e8aa0df 100644 --- a/tests/sampling/test_text.py +++ b/tests/sampling/test_text.py @@ -9,6 +9,7 @@ from genai_bench.sampling.text import TextSampler from genai_bench.scenarios import DatasetScenario, EmbeddingScenario, NormalDistribution from genai_bench.scenarios.text import ReRankScenario +from genai_bench.data.config import DatasetConfig, DatasetSourceConfig class TestTextSampler(unittest.TestCase): @@ -183,6 +184,124 @@ def test_validate_scenario_invalid2(self): with self.assertRaises(ValueError): self.sampler._validate_scenario(invalid_scenario) + def test_synthetic_cached_prefix_and_exact_length(self): + """ + Ensure synthetic prompt generation matches tore-speed semantics: + - Exact input token length equals requested. + - Cached prefix occupies the first synthetic_cached_input_length tokens. + - Unique marker and tail are present after the prefix. 
+ """ + + # Fake tokenizer that understands the synthetic builder pieces + class FakeTokenizer: + pad_token_id = 0 + eos_token_id = 100 + + def __init__(self): + self.base_phrase = "hi," + self.base_tokens = [1, 2] + self.marker_tokens = [8, 9, 10] + self.tail_text = " Write a very long essay about San Francisco" + self.tail_tokens = [3, 4, 5] + + def encode(self, text, add_special_tokens=False): + # Exact matches used by the builder + if text == self.base_phrase: + return list(self.base_tokens) + if text == self.tail_text: + return list(self.tail_tokens) + # Marker text during build contains '-' and digits + if any(ch.isdigit() or ch == '-' for ch in text): + return list(self.marker_tokens) + + # When re-encoding the final prompt for counting, parse known segments + tokens = [] + i = 0 + while i < len(text): + if text.startswith(self.base_phrase, i): + tokens.extend(self.base_tokens) + i += len(self.base_phrase) + elif text.startswith("MARK", i): + tokens.extend(self.marker_tokens) + i += len("MARK") + elif text.startswith(self.tail_text, i): + tokens.extend(self.tail_tokens) + i += len(self.tail_text) + else: + # Skip any other chars (spaces/pads) as 0-cost + i += 1 + return tokens + + def decode(self, tokens, skip_special_tokens=True): + # Reconstruct the text from token ids by chunking known patterns + out = [] + i = 0 + while i < len(tokens): + if tokens[i : i + 2] == self.base_tokens: + out.append(self.base_phrase) + i += 2 + elif tokens[i : i + 3] == self.marker_tokens: + out.append("MARK") + i += 3 + elif tokens[i : i + 3] == self.tail_tokens: + out.append(self.tail_text) + i += 3 + elif tokens[i] == self.pad_token_id: + i += 1 + else: + # Fallback for any stray token + i += 1 + return "".join(out) + + tokenizer = FakeTokenizer() + + # Synthetic config + cached_len = 3000 + target_in = 10000 + target_out = 825 + + ds_cfg = DatasetConfig( + source=DatasetSourceConfig(type="file", path=None, file_format="txt"), + prompt_column=None, + image_column=None, + synthetic=True, + synthetic_input_length=target_in, + synthetic_output_length=target_out, + synthetic_cached_input_length=cached_len, + ) + + sampler = TextSampler( + tokenizer=tokenizer, + model="mock_model", + output_modality="text", + data=["irrelevant"], + dataset_config=ds_cfg, + ) + + # Deterministic scenario D(10000,825) + class Deterministic: + scenario_type = NormalDistribution.scenario_type # irrelevant for validation in this path + def sample(self): + return target_in, target_out + + req = sampler.sample(Deterministic()) + + # Exact token length via tokenizer re-encode of final prompt + self.assertEqual(req.num_prefill_tokens, target_in) + + # Cached prefix check: leading repeats of base phrase + base_repeat = cached_len // len(tokenizer.base_tokens) + expected_prefix = tokenizer.base_phrase * base_repeat + self.assertTrue( + req.prompt.startswith(expected_prefix), + "Prompt does not start with expected cached prefix", + ) + + # Ensure marker and tail exist after the prefix + remaining = req.prompt[len(expected_prefix) :] + self.assertIn("MARK", remaining) + self.assertIn(tokenizer.tail_text, remaining) + def test_sample_text_exact_token_count(self): """ Test that _sample_text returns text with exact number of tokens diff --git a/tests/ui/test_dashboard.py b/tests/ui/test_dashboard.py index a2208e61..b147cac6 100644 --- a/tests/ui/test_dashboard.py +++ b/tests/ui/test_dashboard.py @@ -203,12 +203,12 @@ def test_scatter_plot_spacing_for_different_time_units(): assert label_line_ms is not None, "Could not find label line 
with milliseconds" # Verify the label spacing - assert ( - label_line_s.index("|") == 7 - ), f"Expected 7 spaces for seconds, got: {label_line_s.index('|')}" - assert ( - label_line_ms.index("|") == 9 - ), f"Expected 9 spaces for milliseconds, got: {label_line_ms.index('|')}" + assert label_line_s.index("|") == 7, ( + f"Expected 7 spaces for seconds, got: {label_line_s.index('|')}" + ) + assert label_line_ms.index("|") == 9, ( + f"Expected 9 spaces for milliseconds, got: {label_line_ms.index('|')}" + ) def test_minimal_dashboard_update_scatter_plot_does_not_crash(): diff --git a/tests/user/test_oci_genai_user.py b/tests/user/test_oci_genai_user.py index b7e844bb..309f19e6 100644 --- a/tests/user/test_oci_genai_user.py +++ b/tests/user/test_oci_genai_user.py @@ -46,8 +46,7 @@ def test_chat_grok_format(mock_client_class, test_genai_user): '"content": [{"type": "TEXT", "text": " world"}]}}' ) exclamation_msg = ( - '{"message": {"role": "ASSISTANT", ' - '"content": [{"type": "TEXT", "text": "!"}]}}' + '{"message": {"role": "ASSISTANT", "content": [{"type": "TEXT", "text": "!"}]}}' ) mock_client_instance.chat.return_value.data.events.return_value = iter(