diff --git a/docs/user-guide/run-benchmark.md b/docs/user-guide/run-benchmark.md
index a7c2a56c..fbbefb94 100644
--- a/docs/user-guide/run-benchmark.md
+++ b/docs/user-guide/run-benchmark.md
@@ -180,6 +180,12 @@ For heavier traffic scenarios, like `D(16000,200)` or `D(128000,200)`, use the f
   --num-concurrency 32 \
 ```
 
+To benchmark with prefix caching, you can make a given fraction of each prompt a common prefix with `--prompt-prefix-ratio`. For example, to make the first half of each prompt a common prefix, use:
+
+```shell
+  --prompt-prefix-ratio 0.5 \
+```
+
 ## Distributed Benchmark
 
 If you see the message below in the genai-bench logs, it indicates that a single process is insufficient to generate the desired load.
diff --git a/genai_bench/cli/cli.py b/genai_bench/cli/cli.py
index 1ff09985..29d1c45e 100644
--- a/genai_bench/cli/cli.py
+++ b/genai_bench/cli/cli.py
@@ -122,6 +122,7 @@ def benchmark(
     spawn_rate,
     upload_results,
     namespace,
+    prompt_prefix_ratio,
     # Storage auth options
     storage_provider,
     storage_bucket,
@@ -287,6 +288,7 @@ def benchmark(
         data=data,
         additional_request_params=additional_request_params,
         dataset_config=dataset_config_obj,
+        prompt_prefix_ratio=prompt_prefix_ratio,
     )
 
     # If user did not provide scenarios but provided a dataset, default to dataset mode
diff --git a/genai_bench/cli/option_groups.py b/genai_bench/cli/option_groups.py
index f8ec4b60..10514b83 100644
--- a/genai_bench/cli/option_groups.py
+++ b/genai_bench/cli/option_groups.py
@@ -369,6 +369,14 @@ def server_options(func):
 
 # Group experiment-related options
 def experiment_options(func):
+    func = click.option(
+        "--prompt-prefix-ratio",
+        type=click.FloatRange(0.0, 1.0),
+        default=0.0,
+        help="The ratio of the common prefix length to the overall input "
+        "length. The prefix is prepended to all inputs to test prefix "
+        "caching. Value should be between 0.0 and 1.0.",
+    )(func)
     func = click.option(
         "--experiment-folder-name",
         type=str,
diff --git a/genai_bench/sampling/text.py b/genai_bench/sampling/text.py
index 3f96ae42..ad714a75 100644
--- a/genai_bench/sampling/text.py
+++ b/genai_bench/sampling/text.py
@@ -31,6 +31,7 @@ def __init__(
         model: str,
         output_modality: str,
         data: List[str],
+        prompt_prefix_ratio: float = 0.0,
         additional_request_params: Optional[Dict[str, Any]] = None,
         dataset_config: Optional[DatasetConfig] = None,
         **kwargs,
     ):
@@ -41,6 +42,8 @@ def __init__(
         self.data = data
         self.batch_size = 1  # Default batch size
 
+        self.prompt_prefix_ratio = prompt_prefix_ratio
+        self.prefix = ""
 
     def sample(self, scenario: Optional[Scenario]) -> UserRequest:
         """
@@ -165,6 +168,68 @@ def _validate_scenario(self, scenario: Optional[Scenario]) -> None:
                 f"{type(scenario.scenario_type)}"
             )
 
+    def _sample_prefix(self, current_prefix_length) -> str:
+        """
+        Generates a prefix of current_prefix_length tokens to be
+        prepended to all input prompts.
+ """ + + data_copy = self.data.copy() + + if not self.data: + raise ValueError("Cannot generate prefix from an empty dataset") + + prefix = "" + prefix_tokens_len = 0 + # Generate the prefix + while prefix_tokens_len < current_prefix_length: + random.shuffle(data_copy) + for line in data_copy: + line_tokens = self.tokenizer.encode(line) + num_line_tokens = len(line_tokens) + if prefix_tokens_len + num_line_tokens > current_prefix_length: + remaining_prefix_len = current_prefix_length - prefix_tokens_len + truncated_text = self.tokenizer.decode( + line_tokens[:remaining_prefix_len] + ) + prefix += truncated_text + return prefix + prefix += line + prefix_tokens_len = len(self.tokenizer.encode(prefix)) + + return prefix + + def _get_current_prefix(self, prefix_length: int) -> str: + """ + Returns the prefix for the current prompt of the specified length. + + Args: + current_prefix_length (int): The desired length of the prefix. + """ + + # Prefix of the current prompt being generated + current_prefix: str = self.prefix + + # Get the difference in length between the existing + # prefix and the desired prefix length + + current_prefix_tokens = self.tokenizer.encode(current_prefix) + current_prefix_length = len(current_prefix_tokens) + prefix_length_diff: int = prefix_length - current_prefix_length + + # Generate the prefix if it hasn't been created yet, or add + # to its length if it's not long enough + if prefix_length_diff > 0: + self.prefix += self._sample_prefix(prefix_length_diff) + current_prefix = self.prefix + + elif prefix_length_diff < 0: + # If the prefix is longer than needed, truncate it + current_prefix = self.tokenizer.decode( + current_prefix_tokens[:prefix_length] + ) + return current_prefix + def _sample_text(self, num_input_tokens: Optional[int]) -> str: """ Samples text from a list of lines based on the specified number of @@ -174,16 +239,40 @@ def _sample_text(self, num_input_tokens: Optional[int]) -> str: Args: num_input_tokens (int): The target number of input tokens. + Raises: + ValueError: if the prompt length is shorter than the prefix + length. + Returns: str: A text prompt containing the desired number of tokens. 
""" if not num_input_tokens: return random.choice(self.data) + # Calculate actual prefix length based on ratio or fixed length + current_prefix_length = 0 + if self.prompt_prefix_ratio > 0.0: + current_prefix_length = round(num_input_tokens * self.prompt_prefix_ratio) + data_copy = self.data.copy() - prompt = "" - left_tokens_to_sample = num_input_tokens + if not self.data: + raise ValueError("Cannot sample text from an empty dataset") + + if num_input_tokens <= current_prefix_length: + raise ValueError("Prefix length must be shorter than total input length") + + # Get the prompt prefix + current_prefix: str = self._get_current_prefix(current_prefix_length) + + # Prepend the prefix to all prompts with a randomly picked 4 digits + prompt = f"{current_prefix}{random.randint(1000,9999)}" + + prompt_tokens = self.tokenizer.encode(prompt) + left_tokens_to_sample = num_input_tokens - len(prompt_tokens) + + if left_tokens_to_sample < 0: + return self.tokenizer.decode(prompt_tokens[:num_input_tokens]) while left_tokens_to_sample > 0: random.shuffle(data_copy) for line in data_copy: diff --git a/tests/sampling/test_text.py b/tests/sampling/test_text.py index cc185ed9..758c4bcc 100644 --- a/tests/sampling/test_text.py +++ b/tests/sampling/test_text.py @@ -221,6 +221,11 @@ def mock_encode(text, add_special_tokens=False): # Count actual tokens in result # Need to handle mixed content (original lines + decoded text) total_tokens = 0 + + # All prompts start with 4 numbers, which are 1 token + total_tokens += 1 + result = result[4:] + # Split by our test lines to count tokens properly remaining = result for line in self.test_data: @@ -255,6 +260,88 @@ def test_sample_text_truncation(self): _ = self.sampler._sample_text(requested_tokens) # Verify decode was called with truncated tokens - self.tokenizer.decode.assert_called_with( - line_tokens[:requested_tokens], skip_special_tokens=True + self.tokenizer.decode.assert_called_with(line_tokens[:requested_tokens]) + + def test_sample_chat_prefix_ratio_request(self): + """Test prefix generation using ratio.""" + + # Mock encode to return list with length equal to number of characters in input + def mock_encode(text, add_special_tokens=False): + # ignore space + encoded_text = [1] * len(text.replace(" ", "")) + return encoded_text + + self.tokenizer.encode = mock_encode + + # Mock decode to return the original text + def mock_decode(tokens, skip_special_tokens=True): + if isinstance(tokens, list): + return "a" * len(tokens) # Return 'a' repeated for the token count + return "decoded_text" + + self.tokenizer.decode = mock_decode + + scenario = NormalDistribution( + mean_input_tokens=20, + stddev_input_tokens=0, + mean_output_tokens=20, + stddev_output_tokens=0, + ) + prefix_sampler = TextSampler( + tokenizer=self.tokenizer, + model=self.model, + output_modality=self.output_modality, + data=self.test_data, + prompt_prefix_ratio=0.5, # 50% of 20 tokens = 10 tokens + ) + result = prefix_sampler.sample(scenario) + self.assertIsInstance(result, UserChatRequest) + self.assertEqual(result.model, self.model) + self.assertTrue(isinstance(result.prompt, str)) + self.assertGreater(len(result.prompt), 0) + self.assertTrue(result.prompt.startswith(prefix_sampler.prefix)) + self.assertEqual(len(mock_encode(result.prompt)), 20) + + def test_short_prompt_request(self): + """Test that short prompts are handled correctly.""" + + def mock_encode(text, add_special_tokens=False): + return [1] * len(text) + + self.tokenizer.encode = mock_encode + + # Mock decode to return the 
original text + def mock_decode(tokens): + if isinstance(tokens, list): + return "a" * len(tokens) # Return 'a' repeated for the token count + return "decoded_text" + + self.tokenizer.decode = mock_decode + + self.sampler.data = ["2"] + + # Scenario asks for only 1 input token + scenario = NormalDistribution(1, 0, 1, 0) + + result = self.sampler.sample(scenario) + self.assertIsInstance(result, UserChatRequest) + # The prompt will be the 4-digit number, truncated to 1 char + self.assertEqual(len(result.prompt), 1) + self.assertGreater(len(result.prompt), 0) + + def test_empty_dataset(self): + """Test sampling from an empty dataset.""" + empty_sampler = TextSampler( + tokenizer=self.tokenizer, + model=self.model, + output_modality=self.output_modality, + data=[], + ) + scenario = NormalDistribution(10, 0, 10, 0) + + with self.assertRaises(ValueError) as context: + empty_sampler.sample(scenario) + + self.assertEqual( + str(context.exception), "Cannot sample text from an empty dataset" )
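
For reviewers, here is a rough standalone sketch of the prompt construction that `--prompt-prefix-ratio` enables: a shared prefix covering the requested fraction of the target token count, a random 4-digit salt so each prompt stays distinct after the shared prefix, and filler text to reach the target length. This is not the genai-bench implementation; `build_prompt`, `sample_filler`, the corpus, and the whitespace tokenizer are illustrative stand-ins only.

```python
import random

# Stand-in corpus and whitespace "tokenizer"; the real sampler draws from its
# dataset and uses a Hugging Face tokenizer instead.
CORPUS = [
    "the quick brown fox jumps over the lazy dog",
    "pack my box with five dozen liquor jugs",
    "how vexingly quick daft zebras jump",
]


def encode(text):
    return text.split()  # one token per whitespace-separated word


def decode(tokens):
    return " ".join(tokens)


def sample_filler(target_tokens):
    """Concatenate random corpus lines until roughly target_tokens tokens."""
    pieces, count = [], 0
    while count < target_tokens:
        tokens = encode(random.choice(CORPUS))
        if count + len(tokens) > target_tokens:
            tokens = tokens[: target_tokens - count]  # truncate the last line
        pieces.append(decode(tokens))
        count += len(tokens)
    return " ".join(pieces)


def build_prompt(num_input_tokens, prompt_prefix_ratio, shared_prefix):
    """Shared prefix + random 4-digit salt + per-request filler."""
    prefix_tokens = round(num_input_tokens * prompt_prefix_ratio)
    prefix = decode(encode(shared_prefix)[:prefix_tokens])
    salt = str(random.randint(1000, 9999))  # keeps prompts distinct past the prefix
    used = len(encode(f"{prefix} {salt}"))
    return f"{prefix} {salt} {sample_filler(num_input_tokens - used)}".strip()


if __name__ == "__main__":
    shared = sample_filler(64)  # common prefix reused across every request
    for _ in range(3):
        print(build_prompt(num_input_tokens=32, prompt_prefix_ratio=0.5,
                           shared_prefix=shared))
```

With this stand-in tokenizer each printed prompt is roughly 32 words long and the first 16 words are identical across requests, which is the property a prefix-cache benchmark relies on.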