diff --git a/docs/configuration/env_variables.md b/docs/configuration/env_variables.md
index 286043987..424d0fdb3 100644
--- a/docs/configuration/env_variables.md
+++ b/docs/configuration/env_variables.md
@@ -20,7 +20,7 @@ This document lists the supported diagnostic and profiling, as well as performan
 | Parameter name | Description | Default value |
 | ---------------------------- | ------------------------------------------------------------- | ------------- |
 | `VLLM_GRAPH_RESERVED_MEM` | Percentage of memory dedicated to HPUGraph capture. | `0.1` |
-| `VLLM_EXPONENTIAL_BUCKETING` | Enables exponential bucket spacing instead of linear spacing. | `true` |
+| `VLLM_EXPONENTIAL_BUCKETING` | Enables exponential bucket spacing instead of linear spacing. | `false` |
 | `VLLM_BUCKETING_FROM_FILE` | Enables reading bucket configuration from file | `None` |
 
 ## Developer Mode Parameters
@@ -52,29 +52,38 @@ HPU PyTorch bridge environment variables impacting vLLM execution:
 `VLLM_{phase}_{dim}_BUCKET_{param}` is a collection of environment variables configuring ranges of linear bucketing mechanism, where:
 
-- `{phase}` is either `PROMPT` or `DECODE`
-- `{dim}` is either `BS`, `SEQ` or `BLOCK`
-- `{param}` is either `MIN`, `STEP` or `MAX`
+- `{phase}` is one of `PROMPT` or `DECODE`.
+- `{dim}` is one of `BS`, `QUERY`, `CTX` for the `PROMPT` phase, or `BS`, `BLOCK` for the `DECODE` phase.
+- `{param}` is one of `MIN`, `STEP`, `MAX`, `PAD_MAX`, `PAD_PERCENT`.
 
 The following table lists the available variables with their default values:
-
-| Phase | Variable name | Default value |
-| ------ | ------------------------------------------------- | -------------------------------------------- |
-| Prompt | batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`) | `1` |
-| Prompt | batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`) | `1` |
-| Prompt | batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
-| Prompt | query length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`) | `block_size` |
-| Prompt | query length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`) | `block_size` |
-| Prompt | query length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`) | `max_num_batched_tokens` |
-| Prompt | sequence ctx min (`VLLM_PROMPT_CTX_BUCKET_MIN`) | `0` |
-| Prompt | sequence ctx step (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `1` |
-| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
-| Decode | batch size min (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
-| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `32` |
-| Decode | batch size max (`VLLM_DECODE_BS_BUCKET_MAX`) | `max_num_seqs` |
-| Decode | block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`) | `1` |
-| Decode | block size step (`VLLM_DECODE_BLOCK_BUCKET_STEP`) | `block_size` |
-| Decode | block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` by default or `max_blocks` if `VLLM_CONTIGUOUS_PA = True` |
+| Phase | Variable Name | Default Value |
+|--------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------|
+| **Prompt** | **Batch size min** (`VLLM_PROMPT_BS_BUCKET_MIN`) | `1` |
+| | **Batch size step** (`VLLM_PROMPT_BS_BUCKET_STEP`) | `2` |
+| | **Batch size max** (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
+| | **Batch size max abs padding** (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `16` |
+| | **Batch size max padding %** (`VLLM_PROMPT_BS_BUCKET_PAD_PERCENT`) | `25` |
+| | **Query length min** (`VLLM_PROMPT_QUERY_BUCKET_MIN`) | `block_size` |
+| | **Query length step** (`VLLM_PROMPT_QUERY_BUCKET_STEP`) | `block_size` |
+| | **Query length max** (`VLLM_PROMPT_QUERY_BUCKET_MAX`) | `max_num_batched_tokens` |
+| | **Query length max abs padding** (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens` |
+| | **Query length max padding %** (`VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT`) | `25` |
+| | **Sequence ctx min** (`VLLM_PROMPT_CTX_BUCKET_MIN`) | `0` |
+| | **Sequence ctx step** (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `2` |
+| | **Sequence ctx max** (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
+| | **Sequence ctx max abs padding** (`VLLM_PROMPT_CTX_BUCKET_PAD_MAX`) | `max_num_batched_tokens // block_size` |
+| | **Sequence ctx max padding %** (`VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT`) | `25` |
+| **Decode** | **Batch size min** (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
+| | **Batch size step** (`VLLM_DECODE_BS_BUCKET_STEP`) | `2` |
+| | **Batch size max** (`VLLM_DECODE_BS_BUCKET_MAX`) | `max_num_seqs` |
+| | **Batch size max abs padding** (`VLLM_DECODE_BS_BUCKET_PAD_MAX`) | `32` |
+| | **Batch size max padding %** (`VLLM_DECODE_BS_BUCKET_PAD_PERCENT`) | `25` |
+| | **Block size min** (`VLLM_DECODE_BLOCK_BUCKET_MIN`) | `block_size` |
+| | **Block size step** (`VLLM_DECODE_BLOCK_BUCKET_STEP`) | `block_size` |
+| | **Block size max** (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` (default) or `max_blocks` if `VLLM_CONTIGUOUS_PA=True` |
+| | **Block size max abs padding** (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`) | `max_num_batched_tokens * max_num_seqs // block_size` |
+| | **Block size max padding %** (`VLLM_DECODE_BLOCK_BUCKET_PAD_PERCENT`) | `25` |
 
 When a deployed workload does not use the full context a model can handle, we
 recommend you to limit the maximum values upfront, based on the expected input
@@ -88,7 +97,7 @@ unnecessary and you can limit the values upfront. It reduces the startup time
 and warm-up. Recommended settings for this case are:
 
 - `--max_model_len`: `3072`, which is the sum of input and output sequences (1+2)*1024.
-- `VLLM_PROMPT_SEQ_BUCKET_MAX`: `1024`, which is the maximum input token size that you expect to handle.
+- `VLLM_PROMPT_QUERY_BUCKET_MAX`: `1024`, which is the maximum input token size that you expect to handle.
 
 !!! note
     If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements.
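The five `{param}` knobs compose mechanically, so a short illustration may help. The `bucket_env_vars` helper below is hypothetical, written only for this example; the variable names it produces and the `1024` cap come from the table and the recommendation above:

```python
import os

# Hypothetical helper: expands the VLLM_{phase}_{dim}_BUCKET_{param} scheme.
def bucket_env_vars(phase: str, dim: str) -> list:
    params = ['MIN', 'STEP', 'MAX', 'PAD_MAX', 'PAD_PERCENT']
    return [f'VLLM_{phase}_{dim}_BUCKET_{p}' for p in params]

print(bucket_env_vars('PROMPT', 'QUERY'))
# ['VLLM_PROMPT_QUERY_BUCKET_MIN', 'VLLM_PROMPT_QUERY_BUCKET_STEP',
#  'VLLM_PROMPT_QUERY_BUCKET_MAX', 'VLLM_PROMPT_QUERY_BUCKET_PAD_MAX',
#  'VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT']

# The cap recommended above for a 1k-input/2k-output workload:
os.environ['VLLM_PROMPT_QUERY_BUCKET_MAX'] = '1024'
```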
diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py
index b82214b0c..51b5c7c30 100644
--- a/tests/unit_tests/test_bucketing.py
+++ b/tests/unit_tests/test_bucketing.py
@@ -24,24 +24,26 @@ def test_read_bucket_settings(monkeypatch):
     monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_MIN", "1")
     monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_STEP", "16")
     monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_MAX", "64")
-    config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128)
-    assert config == [1, 16, 64]
+    monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_PAD_MAX", "32")
+    monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_PAD_PERCENT", "25")
+    config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128, pad_max=64, pad_percent=10)
+    assert config == [1, 16, 64, 32, 25]
 
 
 def test_read_bucket_settings_empty_flags():
-    config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128)
-    assert config == [1, 32, 128]
+    config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128, pad_max=64, pad_percent=10)
+    assert config == [1, 32, 128, 64, 10]
 
 
 def test_warmup_range():
-    config = (2, 64, 128)
-    result = linear.warmup_range(config)
+    config = (2, 64, 128, 64, 25)
+    result = linear.warmup_range_with_limits(config)
     assert result == [2, 4, 8, 16, 32, 64, 128]
 
 
 def test_warmup_range_with_one():
-    config = (1, 64, 128)
-    result = linear.warmup_range(config)
+    config = (1, 64, 128, 64, 25)
+    result = linear.warmup_range_with_limits(config)
     assert result == [1, 2, 4, 8, 16, 32, 64, 128]
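The padding limits themselves are only exercised implicitly by the tests above. A possible companion test (not part of this diff) that pins them down, using the configs from the `warmup_range_with_limits` docstring in linear.py below:

```python
def test_warmup_range_with_padding_limits():
    # pad_percent=0: any padding at all forces a bucket to be kept,
    # so every stable step survives.
    config = (0, 8, 64, 64, 0)
    assert linear.warmup_range_with_limits(config) == [0, 1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]

    # pad_max=16: buckets are pruned whenever skipping them keeps the
    # worst-case padding within both limits.
    config = (0, 8, 64, 16, 50)
    assert linear.warmup_range_with_limits(config) == [0, 1, 2, 4, 8, 16, 32, 48, 64]
```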
logger().debug(f"Prompt BS range: {bs_range}") + logger().debug(f"Prompt query range: {query_range}") + logger().debug(f"Prompt context range: {ctx_range}") self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, @@ -195,6 +195,10 @@ def generate_decode_buckets(self): if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks: ctx_range.append(self.num_hpu_blocks) + logger().debug(f"Decode BS range: {bs_range}") + logger().debug(f"Decode query range: {query_range}") + logger().debug(f"Decode context range: {ctx_range}") + self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len, self.max_num_seqs, self.max_num_prefill_seqs, self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks, diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py index d50a8e57a..828895676 100644 --- a/vllm_gaudi/extension/bucketing/linear.py +++ b/vllm_gaudi/extension/bucketing/linear.py @@ -1,8 +1,5 @@ -import itertools -import operator import os import math -from dataclasses import dataclass, field from typing import List, Tuple from vllm_gaudi.extension.logger import logger as logger @@ -13,30 +10,46 @@ class LinearBucketingStrategy: def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len): use_merged_prefill = get_config().merged_prefill - prefix_caching = get_config().prefix_caching - prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=1, max=max_num_prefill_seqs) + prompt_bs_bucket_cfg = read_bucket_settings('prompt', + 'bs', + min=1, + step=2, + max=max_num_prefill_seqs, + pad_max=16, + pad_percent=25) prompt_query_bucket_cfg = read_bucket_settings('prompt', 'query', min=block_size, step=block_size, - max=max_num_batched_tokens) + max=max_num_batched_tokens, + pad_max=max_num_batched_tokens, + pad_percent=25) max_ctx = math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size) - prompt_ctx_bucket_cfg = read_bucket_settings('prompt', 'ctx', min=0, step=1, max=max_ctx) + prompt_ctx_bucket_cfg = read_bucket_settings('prompt', + 'ctx', + min=0, + step=2, + max=max_ctx, + pad_max=max_num_batched_tokens // block_size, + pad_percent=25) if use_merged_prefill: prev_prompt_bs_bucket_cfg = tuple(prompt_bs_bucket_cfg) prev_prompt_query_bucket_cfg = tuple(prompt_query_bucket_cfg) prev_prompt_ctx_bucket_cfg = tuple(prompt_ctx_bucket_cfg) - prompt_bs_bucket_cfg = (1, 1, 1) - query_min, query_step, _ = prev_prompt_query_bucket_cfg - prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens) + prompt_bs_bucket_cfg = (1, 1, 1, prev_prompt_bs_bucket_cfg[-2], prev_prompt_bs_bucket_cfg[-1]) + query_min, query_step, _, query_pad_max, query_pad_percent = prev_prompt_query_bucket_cfg + prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens, query_pad_max, + query_pad_percent) prompt_ctx_bucket_cfg = read_bucket_settings('prompt', 'ctx', min=0, step=4, - max=max_ctx * max_num_prefill_seqs) + max=max_ctx * max_num_prefill_seqs, + pad_max=max_num_batched_tokens // block_size, + pad_percent=25) msg = ('Merged prefill is enabled!\n' 'Overriding prompt bucketing settings!\n' @@ -45,7 +58,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke f'prompt ctx cfg: {prev_prompt_ctx_bucket_cfg} -> {prompt_ctx_bucket_cfg}\n') logger().info(msg) - msg = ("Prompt bucket config (min, step, max_warmup) " + msg = 
("Prompt bucket config (min, step, max_warmup, pad_max, pad_percent) " f"bs:{prompt_bs_bucket_cfg}, " f"query:{prompt_query_bucket_cfg}, " f"blocks:{prompt_ctx_bucket_cfg}") @@ -54,15 +67,26 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke return prompt_bs_bucket_cfg, prompt_query_bucket_cfg, prompt_ctx_bucket_cfg def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_model_len, max_blocks): - prefix_caching = get_config().prefix_caching contiguous_pa = get_config().use_contiguous_pa - decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, step=32, max=max_num_seqs) - decode_query_bucket_cfg = [1, 1, 1] + decode_bs_bucket_cfg = read_bucket_settings('decode', + 'bs', + min=1, + step=2, + max=max_num_seqs, + pad_max=32, + pad_percent=25) + decode_query_bucket_cfg = [1, 1, 1, 1, 1] max_decode_blocks = max(math.ceil(max_model_len * max_num_seqs // block_size), block_size) if contiguous_pa: max_decode_blocks = max_blocks - decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks) + decode_block_bucket_cfg = read_bucket_settings('decode', + 'block', + min=block_size, + step=block_size, + max=max_decode_blocks, + pad_max=max_num_batched_tokens * max_num_seqs // block_size, + pad_percent=25) if decode_block_bucket_cfg[2] > max_blocks: logger().info( f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}' @@ -75,7 +99,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ ) decode_block_bucket_cfg[0] = decode_block_bucket_min - msg = ("Decode bucket config (min, step, max_warmup) " + msg = ("Decode bucket config (min, step, max_warmup, pad_max, pad_percent) " f"bs:{decode_bs_bucket_cfg}, " f"blocks:{decode_block_bucket_cfg}") logger().info(msg) @@ -83,7 +107,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_ return decode_bs_bucket_cfg, decode_query_bucket_cfg, decode_block_bucket_cfg def get_range(self, cfg): - range_for_cfg = warmup_range(cfg) + range_for_cfg = warmup_range_with_limits(cfg) return sorted(range_for_cfg) @@ -95,7 +119,7 @@ def read_bucket_settings(phase: str, dim: str, **defaults): param is either 'min', 'step' or 'max' example env variable: VLLM_DECODE_BS_BUCKET_STEP=128 """ - params = ['min', 'step', 'max'] + params = ['min', 'step', 'max', 'pad_max', 'pad_percent'] env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params] default_values = [defaults[p] for p in params] values = [] @@ -119,35 +143,62 @@ def read_bucket_settings(phase: str, dim: str, **defaults): return values -def warmup_range(config: Tuple[int, int, int]): - """Generate a warmup range. +def warmup_range_with_limits(config: Tuple[int, int, int, int, int]) -> List[int]: + """Generate a warmup range with absolute and relative padding limits. - Start from bmin and multiply by 2 until you reach bstep. - Then, increase the values in the range by the value of bstep until you - reach bmax. + 1. Starts from `bucket_min` and multiply by 2 (or +1 for 0) till to `bucket_step`. + 2. Add `bucket_step` to the values till to `bucket_max` and choose current bucket if: + a. the next bucket exceeds the absolute padding limit `pad_max`, + b. or the next bucket exceeds the padding ratio limit `pad_percent`, + c. 
@@ -119,35 +143,62 @@
-def warmup_range(config: Tuple[int, int, int]):
-    """Generate a warmup range.
+def warmup_range_with_limits(config: Tuple[int, int, int, int, int]) -> List[int]:
+    """Generate a warmup range with absolute and relative padding limits.
 
-    Start from bmin and multiply by 2 until you reach bstep.
-    Then, increase the values in the range by the value of bstep until you
-    reach bmax.
+    1. Ramp-up: start from `bucket_min` and keep doubling (a 0 start is bumped
+       to 1) until `bucket_step` is reached.
+    2. Stable: advance in increments of `bucket_step` up to `bucket_max` and
+       keep the current bucket only if:
+       a. skipping it would push the worst-case padding above the absolute
+          limit `pad_max`,
+       b. or above the relative limit `pad_percent` of the next bucket,
+       c. or the current bucket is a multiple of `pad_max`.
+    3. Always include `bucket_max` as the last bucket.
 
     Example:
-    bmin = 2, bstep = 32, bmax = 64
-    => ramp_up = (2, 4, 8, 16)
-    => stable = (32, 64)
-    => return ramp_up + stable => (2, 4, 8, 16, 32, 64)
+    1. for config = (0, 8, 64, 64, 0)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 24, 32, 40, 48, 56, 64]
+       return [0, 1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
+    2. for config = (0, 8, 64, 64, 50)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 32, 64]  # 24, 40, 48 and 56 are skipped: padding stays within the 50% ratio limit
+       return [0, 1, 2, 4, 8, 16, 32, 64]
+    3. for config = (0, 8, 64, 16, 50)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 32, 48, 64]  # 24, 40, 56 are skipped due to absolute padding limit
+       return [0, 1, 2, 4, 8, 16, 32, 48, 64]
+    4. for config = (16, 16, 128, 32, 25)
+       ramp_up = [16, 32]  # ramp-up ends immediately, since min == step
+       stable = [48, 64, 80, 96, 128]  # 112 is skipped: padding stays within both limits
+       return [16, 32, 48, 64, 80, 96, 128]
     """
-    bmin, bstep, bmax = config
-    add_zero_bucket = bmin == 0
-    assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
-                          "batch size. If you want to skip warmup, "
-                          "set VLLM_SKIP_WARMUP=true")
-    if add_zero_bucket:
-        if bmin == 0 and bmax == 0:
-            return [0]
-        bmin = bstep
-    base = itertools.repeat(2)
-    ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
-    ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
-        ramp_up_acc)
-    stable = range(bstep, bmax + 1, bstep)
-    buckets = list(ramp_up_tw) + list(stable)
-    buckets = [b for b in buckets if b >= bmin]
-    if add_zero_bucket:
-        buckets.append(0)
-    return list(buckets)
+    bucket_min, bucket_step, bucket_max, pad_max, pad_percent = config
+    assert bucket_min <= bucket_max, ("bucket_min cannot be greater than bucket_max. "
+                                      "If you want to skip warmup, set VLLM_SKIP_WARMUP=true")
+    assert bucket_step > 0, f"bucket_step must be positive, got: {bucket_step}"
+    assert 0 <= pad_percent <= 50, f"pad_percent must be between 0 and 50, got: {pad_percent}"
+
+    buckets = [bucket_min]
+    current_bucket = bucket_min
+    while current_bucket <= bucket_max:
+        last_bucket = buckets[-1]
+        if current_bucket <= bucket_step:
+            # Ramp-up phase: double the last bucket (a 0 start is bumped to 1).
+            next_bucket = last_bucket * 2
+            if next_bucket == 0:
+                next_bucket += 1
+            if next_bucket <= bucket_max:
+                buckets.append(next_bucket)
+        else:
+            next_bucket = current_bucket + bucket_step
+            max_padding = next_bucket - last_bucket - 1
+            max_padding_ratio = max_padding / next_bucket
+            keep_bucket = (
+                max_padding_ratio > pad_percent / 100.0  # skipping would exceed the padding ratio limit
+                or max_padding > pad_max  # skipping would exceed the absolute padding limit
+                or current_bucket % pad_max == 0  # current bucket is a multiple of pad_max
+            )
+            if keep_bucket and current_bucket != last_bucket:
+                buckets.append(current_bucket)
+        current_bucket = next_bucket
+    if buckets[-1] != bucket_max:
+        buckets.append(bucket_max)
+
+    return buckets
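The stable-phase arithmetic is the subtle part of the new function, so here is one keep/skip decision worked by hand. It reproduces docstring example 4, `config = (16, 16, 128, 32, 25)`, at `current_bucket = 112`, where the last kept bucket is 96:

```python
last_bucket, next_bucket, pad_max, pad_percent = 96, 128, 32, 25

max_padding = next_bucket - last_bucket - 1    # 31: worst case if 112 is skipped
max_padding_ratio = max_padding / next_bucket  # 31 / 128 ~= 0.242

keep_bucket = (max_padding_ratio > pad_percent / 100.0  # 0.242 > 0.25 -> False
               or max_padding > pad_max                 # 31 > 32 -> False
               or 112 % pad_max == 0)                   # 112 % 32 = 16 -> False

assert not keep_bucket  # 112 is skipped; bucket_max=128 is still appended last
```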
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 5321b9809..2293229a4 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -17,22 +17,33 @@ def get_user_flags():
         Env('VLLM_PROMPT_BS_BUCKET_MIN', int),
         Env('VLLM_PROMPT_BS_BUCKET_STEP', int),
         Env('VLLM_PROMPT_BS_BUCKET_MAX', int),
+        Env('VLLM_PROMPT_BS_BUCKET_PAD_MAX', int),
+        Env('VLLM_PROMPT_BS_BUCKET_PAD_PERCENT', int),
         Env('VLLM_PROMPT_QUERY_BUCKET_MIN', int),
         Env('VLLM_PROMPT_QUERY_BUCKET_STEP', int),
         Env('VLLM_PROMPT_QUERY_BUCKET_MAX', int),
+        Env('VLLM_PROMPT_QUERY_BUCKET_PAD_MAX', int),
+        Env('VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT', int),
         Env('VLLM_PROMPT_SEQ_BUCKET_MIN', int),
         Env('VLLM_PROMPT_SEQ_BUCKET_STEP', int),
         Env('VLLM_PROMPT_SEQ_BUCKET_MAX', int),
+        Env('VLLM_PROMPT_SEQ_BUCKET_PAD_MAX', int),
+        Env('VLLM_PROMPT_SEQ_BUCKET_PAD_PERCENT', int),
         Env('VLLM_PROMPT_CTX_BUCKET_MIN', int),
         Env('VLLM_PROMPT_CTX_BUCKET_STEP', int),
         Env('VLLM_PROMPT_CTX_BUCKET_MAX', int),
+        Env('VLLM_PROMPT_CTX_BUCKET_PAD_MAX', int),
+        Env('VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT', int),
         Env('VLLM_DECODE_BS_BUCKET_MIN', int),
         Env('VLLM_DECODE_BS_BUCKET_STEP', int),
         Env('VLLM_DECODE_BS_BUCKET_MAX', int),
+        Env('VLLM_DECODE_BS_BUCKET_PAD_MAX', int),
+        Env('VLLM_DECODE_BS_BUCKET_PAD_PERCENT', int),
         Env('VLLM_DECODE_BLOCK_BUCKET_MIN', int),
         Env('VLLM_DECODE_BLOCK_BUCKET_STEP', int),
         Env('VLLM_DECODE_BLOCK_BUCKET_MAX', int),
-        Env('VLLM_DECODE_BLOCK_BUCKET_LIMIT', int),
+        Env('VLLM_DECODE_BLOCK_BUCKET_PAD_MAX', int),
+        Env('VLLM_DECODE_BLOCK_BUCKET_PAD_PERCENT', int),
         Env('VLLM_BUCKETING_FROM_FILE', str),
 
         # Non-vllm flags that are also important to print
@@ -63,7 +74,7 @@ def get_experimental_flags():
 
 def get_features():
     supported_attn_impls = ['flex_impl', 'fsdpa_impl', 'naive_impl']
-    bucketing_strategies = ['exponential_bucketing', 'linear_bucketing']
+    bucketing_strategies = ['linear_bucketing', 'exponential_bucketing']
     features = [
         Value('fp32_alibi_biases', True, env_var='VLLM_ALIBI_USE_FLOAT32_BIASES'),
         Value('fp32_softmax', ModelType('qwen2')),
@@ -82,8 +93,8 @@ def get_features():
             Any(Disabled('prefix_caching'), Enabled('unified_attn')),
             env_var='VLLM_CONTIGUOUS_PA'),
         Value('use_bucketing', True, env_var='VLLM_ENABLE_BUCKETING'),
-        Value('exponential_bucketing', True),
-        Value('linear_bucketing', True),
+        Value('exponential_bucketing', False, env_var='VLLM_EXPONENTIAL_BUCKETING', env_var_type=boolean),
+        Value('linear_bucketing', Not(Enabled('exponential_bucketing'))),
         ValueFromList('bucketing_strategy', bucketing_strategies),
         Value('defrag', Enabled('unified_attn')),
         Value('regional_compilation', True, env_var='VLLM_T_COMPILE_REGIONAL_COMPILATION', env_var_type=boolean),
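Taken together, the two `Value` changes make the strategies mutually exclusive and driven by one flag. A rough sketch of the resulting resolution; the real `Value`/`Not`/`Enabled` machinery lives elsewhere in `features.py`, and the string-to-boolean parsing below only mimics the outcome:

```python
def resolve_bucketing_strategy(env: dict) -> str:
    # VLLM_EXPONENTIAL_BUCKETING now defaults to False when unset.
    exponential = str(env.get('VLLM_EXPONENTIAL_BUCKETING', 'false')).lower() == 'true'
    # linear_bucketing mirrors Not(Enabled('exponential_bucketing')).
    return 'exponential_bucketing' if exponential else 'linear_bucketing'

assert resolve_bucketing_strategy({}) == 'linear_bucketing'  # new default
assert resolve_bucketing_strategy({'VLLM_EXPONENTIAL_BUCKETING': 'true'}) == 'exponential_bucketing'
```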