diff --git a/docs/configuration/env_variables.md b/docs/configuration/env_variables.md
index 286043987..424d0fdb3 100644
--- a/docs/configuration/env_variables.md
+++ b/docs/configuration/env_variables.md
@@ -20,7 +20,7 @@ This document lists the supported diagnostic and profiling, as well as performan
| Parameter name | Description | Default value |
| ---------------------------- | ------------------------------------------------------------- | ------------- |
| `VLLM_GRAPH_RESERVED_MEM` | Percentage of memory dedicated to HPUGraph capture. | `0.1` |
-| `VLLM_EXPONENTIAL_BUCKETING` | Enables exponential bucket spacing instead of linear spacing. | `true` |
+| `VLLM_EXPONENTIAL_BUCKETING` | Enables exponential bucket spacing instead of linear spacing. | `false` |
| `VLLM_BUCKETING_FROM_FILE` | Enables reading bucket configuration from file | `None` |
## Developer Mode Parameters
@@ -52,29 +52,38 @@ HPU PyTorch bridge environment variables impacting vLLM execution:
-`VLLM_{phase}_{dim}_BUCKET_{param}` is a collection of environment variables configuring ranges of linear bucketing mechanism, where:
+`VLLM_{phase}_{dim}_BUCKET_{param}` is a collection of environment variables configuring the ranges of the linear bucketing mechanism, where:
-- `{phase}` is either `PROMPT` or `DECODE`
-- `{dim}` is either `BS`, `SEQ` or `BLOCK`
-- `{param}` is either `MIN`, `STEP` or `MAX`
+- `{phase}` is in `['PROMPT', 'DECODE']`.
+- `{dim}` is in `['BS', 'QUERY', 'CTX']` for `PROMPT` phase or in `['BS', 'BLOCK']` for `DECODE` phase.
+- `{param}` is in `['MIN', 'STEP', 'MAX', 'PAD_MAX', 'PAD_PERCENT']`.
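+
+For example, the five variables controlling decode batch-size bucketing follow mechanically from this scheme. A minimal sketch, mirroring the name composition in `read_bucket_settings` from `vllm_gaudi/extension/bucketing/linear.py`:
+
+```python
+phase, dim = "decode", "bs"
+params = ["min", "step", "max", "pad_max", "pad_percent"]
+env_vars = [f"VLLM_{phase}_{dim}_BUCKET_{p}".upper() for p in params]
+print(env_vars)
+# ['VLLM_DECODE_BS_BUCKET_MIN', 'VLLM_DECODE_BS_BUCKET_STEP',
+#  'VLLM_DECODE_BS_BUCKET_MAX', 'VLLM_DECODE_BS_BUCKET_PAD_MAX',
+#  'VLLM_DECODE_BS_BUCKET_PAD_PERCENT']
+```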
The following table lists the available variables with their default values:
-
-| Phase | Variable name | Default value |
-| ------ | ------------------------------------------------- | -------------------------------------------- |
-| Prompt | batch size min (`VLLM_PROMPT_BS_BUCKET_MIN`) | `1` |
-| Prompt | batch size step (`VLLM_PROMPT_BS_BUCKET_STEP`) | `1` |
-| Prompt | batch size max (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
-| Prompt | query length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`) | `block_size` |
-| Prompt | query length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`) | `block_size` |
-| Prompt | query length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`) | `max_num_batched_tokens` |
-| Prompt | sequence ctx min (`VLLM_PROMPT_CTX_BUCKET_MIN`) | `0` |
-| Prompt | sequence ctx step (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `1` |
-| Prompt | sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
-| Decode | batch size min (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
-| Decode | batch size step (`VLLM_DECODE_BS_BUCKET_STEP`) | `32` |
-| Decode | batch size max (`VLLM_DECODE_BS_BUCKET_MAX`) | `max_num_seqs` |
-| Decode | block size min (`VLLM_DECODE_BLOCK_BUCKET_MIN`) | `1` |
-| Decode | block size step (`VLLM_DECODE_BLOCK_BUCKET_STEP`) | `block_size` |
-| Decode | block size max (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` by default or `max_blocks` if `VLLM_CONTIGUOUS_PA = True` |
+| Phase | Variable Name | Default Value |
+|--------|------------------------------------------------------------|----------------------------------------------------------------------------------------------------|
+| **Prompt** | **Batch size min** (`VLLM_PROMPT_BS_BUCKET_MIN`) | `1` |
+| | **Batch size step** (`VLLM_PROMPT_BS_BUCKET_STEP`) | `2` |
+| | **Batch size max** (`VLLM_PROMPT_BS_BUCKET_MAX`) | `max_num_prefill_seqs` |
+| | **Batch size max abs padding** (`VLLM_PROMPT_BS_BUCKET_PAD_MAX`) | `16` |
+| | **Batch size max padding %** (`VLLM_PROMPT_BS_BUCKET_PAD_PERCENT`)| `25` |
+| | **Query length min** (`VLLM_PROMPT_QUERY_BUCKET_MIN`) | `block_size` |
+| | **Query length step** (`VLLM_PROMPT_QUERY_BUCKET_STEP`) | `block_size` |
+| | **Query length max** (`VLLM_PROMPT_QUERY_BUCKET_MAX`) | `max_num_batched_tokens` |
+| | **Query length max abs padding** (`VLLM_PROMPT_QUERY_BUCKET_PAD_MAX`) | `max_num_batched_tokens` |
+| | **Query length max padding %** (`VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT`)| `25` |
+| | **Sequence ctx min** (`VLLM_PROMPT_CTX_BUCKET_MIN`) | `0` |
+| | **Sequence ctx step** (`VLLM_PROMPT_CTX_BUCKET_STEP`) | `2` |
+| | **Sequence ctx max** (`VLLM_PROMPT_CTX_BUCKET_MAX`) | `(max_model_len - block_size) // block_size` |
+| | **Sequence ctx max abs padding** (`VLLM_PROMPT_CTX_BUCKET_PAD_MAX`)| `max_num_batched_tokens // block_size` |
+| | **Sequence ctx max padding %** (`VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT`)| `25` |
+| **Decode** | **Batch size min** (`VLLM_DECODE_BS_BUCKET_MIN`) | `1` |
+| | **Batch size step** (`VLLM_DECODE_BS_BUCKET_STEP`) | `2` |
+| | **Batch size max** (`VLLM_DECODE_BS_BUCKET_MAX`) | `max_num_seqs` |
+| | **Batch size max abs padding** (`VLLM_DECODE_BS_BUCKET_PAD_MAX`) | `32` |
+| | **Batch size max padding %** (`VLLM_DECODE_BS_BUCKET_PAD_PERCENT`)| `25` |
+| | **Block size min** (`VLLM_DECODE_BLOCK_BUCKET_MIN`) | `block_size` |
+| | **Block size step** (`VLLM_DECODE_BLOCK_BUCKET_STEP`) | `block_size` |
+| | **Block size max** (`VLLM_DECODE_BLOCK_BUCKET_MAX`) | `max_model_len * max_num_seqs // block_size` (default) or `max_blocks` if `VLLM_CONTIGUOUS_PA=True` |
+| | **Block size max abs padding** (`VLLM_DECODE_BLOCK_BUCKET_PAD_MAX`)| `max_num_batched_tokens * max_num_seqs // block_size` |
+| | **Block size max padding %** (`VLLM_DECODE_BLOCK_BUCKET_PAD_PERCENT`)| `25` |
When a deployed workload does not use the full context a model can handle, we
-recommend you to limit the maximum values upfront, based on the expected input
+recommend limiting the maximum values upfront, based on the expected input
@@ -88,7 +97,7 @@ unnecessary and you can limit the values upfront. It reduces the startup time
-and warm-up. Recommended settings for this case are:
+and warm-up time. Recommended settings for this case are:
-- `--max_model_len`: `3072`, which is the sum of input and output sequences (1+2)*1024.
+- `--max_model_len`: `3072`, which is the sum of the input and output sequence lengths: (1+2)*1024.
-- `VLLM_PROMPT_SEQ_BUCKET_MAX`: `1024`, which is the maximum input token size that you expect to handle.
+- `VLLM_PROMPT_QUERY_BUCKET_MAX`: `1024`, which is the maximum input token size that you expect to handle.
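+
+A minimal sketch of applying this setting from Python, before vLLM reads its configuration (illustrative only; any mechanism that sets the variable in the server's environment works equally well):
+
+```python
+import os
+
+# Cap the warmed-up prompt query length at the expected 1024 input tokens.
+os.environ["VLLM_PROMPT_QUERY_BUCKET_MAX"] = "1024"
+```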
!!! note
If the model config specifies a high `max_model_len`, set it to the sum of `input_tokens` and `output_tokens`, rounded up to a multiple of `block_size` according to actual requirements.
diff --git a/tests/unit_tests/test_bucketing.py b/tests/unit_tests/test_bucketing.py
index b82214b0c..51b5c7c30 100644
--- a/tests/unit_tests/test_bucketing.py
+++ b/tests/unit_tests/test_bucketing.py
@@ -24,24 +24,26 @@ def test_read_bucket_settings(monkeypatch):
monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_MIN", "1")
monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_STEP", "16")
monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_MAX", "64")
- config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128)
- assert config == [1, 16, 64]
+ monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_PAD_MAX", "32")
+ monkeypatch.setenv("VLLM_PROMPT_BS_BUCKET_PAD_PERCENT", "25")
+ config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128, pad_max=64, pad_percent=10)
+ assert config == [1, 16, 64, 32, 25]
def test_read_bucket_settings_empty_flags():
- config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128)
- assert config == [1, 32, 128]
+ config = linear.read_bucket_settings("prompt", "bs", min=1, step=32, max=128, pad_max=64, pad_percent=10)
+ assert config == [1, 32, 128, 64, 10]
def test_warmup_range():
- config = (2, 64, 128)
- result = linear.warmup_range(config)
+ config = (2, 64, 128, 64, 25)
+ result = linear.warmup_range_with_limits(config)
assert result == [2, 4, 8, 16, 32, 64, 128]
def test_warmup_range_with_one():
- config = (1, 64, 128)
- result = linear.warmup_range(config)
+ config = (1, 64, 128, 64, 25)
+ result = linear.warmup_range_with_limits(config)
assert result == [1, 2, 4, 8, 16, 32, 64, 128]
diff --git a/vllm_gaudi/extension/bucketing/common.py b/vllm_gaudi/extension/bucketing/common.py
index 74c8bb14b..468100226 100644
--- a/vllm_gaudi/extension/bucketing/common.py
+++ b/vllm_gaudi/extension/bucketing/common.py
@@ -91,11 +91,8 @@ def read_from_file(self, is_prompt):
def get_bucketing_strategy(self):
strategy = None
# TODO - we can use different strategies for decode and prompt
- use_exponential_bucketing = True if \
- get_config().VLLM_EXPONENTIAL_BUCKETING == None else \
- get_config().VLLM_EXPONENTIAL_BUCKETING
- if use_exponential_bucketing:
+ if get_config().VLLM_EXPONENTIAL_BUCKETING:
from vllm_gaudi.extension.bucketing.exponential import (ExponentialBucketingStrategy)
strategy = ExponentialBucketingStrategy()
else:
@@ -152,6 +149,9 @@ def generate_prompt_buckets(self):
bs_range = strategy.get_range(bs_cfg)
query_range = strategy.get_range(query_cfg)
ctx_range = strategy.get_range(ctx_cfg)
+ logger().debug(f"Prompt BS range: {bs_range}")
+ logger().debug(f"Prompt query range: {query_range}")
+ logger().debug(f"Prompt context range: {ctx_range}")
self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
@@ -195,6 +195,10 @@ def generate_decode_buckets(self):
if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks:
ctx_range.append(self.num_hpu_blocks)
+ logger().debug(f"Decode BS range: {bs_range}")
+ logger().debug(f"Decode query range: {query_range}")
+ logger().debug(f"Decode context range: {ctx_range}")
+
self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py
index d50a8e57a..828895676 100644
--- a/vllm_gaudi/extension/bucketing/linear.py
+++ b/vllm_gaudi/extension/bucketing/linear.py
@@ -1,8 +1,5 @@
-import itertools
-import operator
import os
import math
-from dataclasses import dataclass, field
from typing import List, Tuple
from vllm_gaudi.extension.logger import logger as logger
@@ -13,30 +10,46 @@ class LinearBucketingStrategy:
def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_tokens, max_model_len):
use_merged_prefill = get_config().merged_prefill
- prefix_caching = get_config().prefix_caching
- prompt_bs_bucket_cfg = read_bucket_settings('prompt', 'bs', min=1, step=1, max=max_num_prefill_seqs)
+ prompt_bs_bucket_cfg = read_bucket_settings('prompt',
+ 'bs',
+ min=1,
+ step=2,
+ max=max_num_prefill_seqs,
+ pad_max=16,
+ pad_percent=25)
prompt_query_bucket_cfg = read_bucket_settings('prompt',
'query',
min=block_size,
step=block_size,
- max=max_num_batched_tokens)
+ max=max_num_batched_tokens,
+ pad_max=max_num_batched_tokens,
+ pad_percent=25)
max_ctx = math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size)
- prompt_ctx_bucket_cfg = read_bucket_settings('prompt', 'ctx', min=0, step=1, max=max_ctx)
+ prompt_ctx_bucket_cfg = read_bucket_settings('prompt',
+ 'ctx',
+ min=0,
+ step=2,
+ max=max_ctx,
+ pad_max=max_num_batched_tokens // block_size,
+ pad_percent=25)
if use_merged_prefill:
prev_prompt_bs_bucket_cfg = tuple(prompt_bs_bucket_cfg)
prev_prompt_query_bucket_cfg = tuple(prompt_query_bucket_cfg)
prev_prompt_ctx_bucket_cfg = tuple(prompt_ctx_bucket_cfg)
- prompt_bs_bucket_cfg = (1, 1, 1)
- query_min, query_step, _ = prev_prompt_query_bucket_cfg
- prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens)
+ prompt_bs_bucket_cfg = (1, 1, 1, prev_prompt_bs_bucket_cfg[-2], prev_prompt_bs_bucket_cfg[-1])
+ query_min, query_step, _, query_pad_max, query_pad_percent = prev_prompt_query_bucket_cfg
+ prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens, query_pad_max,
+ query_pad_percent)
prompt_ctx_bucket_cfg = read_bucket_settings('prompt',
'ctx',
min=0,
step=4,
- max=max_ctx * max_num_prefill_seqs)
+ max=max_ctx * max_num_prefill_seqs,
+ pad_max=max_num_batched_tokens // block_size,
+ pad_percent=25)
msg = ('Merged prefill is enabled!\n'
'Overriding prompt bucketing settings!\n'
@@ -45,7 +58,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
f'prompt ctx cfg: {prev_prompt_ctx_bucket_cfg} -> {prompt_ctx_bucket_cfg}\n')
logger().info(msg)
- msg = ("Prompt bucket config (min, step, max_warmup) "
+ msg = ("Prompt bucket config (min, step, max_warmup, pad_max, pad_percent) "
f"bs:{prompt_bs_bucket_cfg}, "
f"query:{prompt_query_bucket_cfg}, "
f"blocks:{prompt_ctx_bucket_cfg}")
@@ -54,15 +67,26 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
return prompt_bs_bucket_cfg, prompt_query_bucket_cfg, prompt_ctx_bucket_cfg
def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_model_len, max_blocks):
- prefix_caching = get_config().prefix_caching
contiguous_pa = get_config().use_contiguous_pa
- decode_bs_bucket_cfg = read_bucket_settings('decode', 'bs', min=1, step=32, max=max_num_seqs)
- decode_query_bucket_cfg = [1, 1, 1]
+ decode_bs_bucket_cfg = read_bucket_settings('decode',
+ 'bs',
+ min=1,
+ step=2,
+ max=max_num_seqs,
+ pad_max=32,
+ pad_percent=25)
+ decode_query_bucket_cfg = [1, 1, 1, 1, 1]
max_decode_blocks = max(math.ceil(max_model_len * max_num_seqs // block_size), block_size)
if contiguous_pa:
max_decode_blocks = max_blocks
- decode_block_bucket_cfg = read_bucket_settings('decode', 'block', min=1, step=block_size, max=max_decode_blocks)
+ decode_block_bucket_cfg = read_bucket_settings('decode',
+ 'block',
+ min=block_size,
+ step=block_size,
+ max=max_decode_blocks,
+ pad_max=max_num_batched_tokens * max_num_seqs // block_size,
+ pad_percent=25)
if decode_block_bucket_cfg[2] > max_blocks:
logger().info(
f'VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} is higher than max_blocks={max_blocks}. Your configuration VLLM_DECODE_BLOCK_BUCKET_MAX={decode_block_bucket_cfg[2]} will be overwritten to VLLM_DECODE_BLOCK_BUCKET_MAX={max_blocks}'
@@ -75,7 +99,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_
)
decode_block_bucket_cfg[0] = decode_block_bucket_min
- msg = ("Decode bucket config (min, step, max_warmup) "
+ msg = ("Decode bucket config (min, step, max_warmup, pad_max, pad_percent) "
f"bs:{decode_bs_bucket_cfg}, "
f"blocks:{decode_block_bucket_cfg}")
logger().info(msg)
@@ -83,7 +107,7 @@ def get_decode_cfgs(self, max_num_seqs, block_size, max_num_batched_tokens, max_
return decode_bs_bucket_cfg, decode_query_bucket_cfg, decode_block_bucket_cfg
def get_range(self, cfg):
- range_for_cfg = warmup_range(cfg)
+ range_for_cfg = warmup_range_with_limits(cfg)
return sorted(range_for_cfg)
@@ -95,7 +119,7 @@ def read_bucket_settings(phase: str, dim: str, **defaults):
-    param is either 'min', 'step' or 'max'
+    param is one of 'min', 'step', 'max', 'pad_max' or 'pad_percent'
example env variable: VLLM_DECODE_BS_BUCKET_STEP=128
"""
- params = ['min', 'step', 'max']
+ params = ['min', 'step', 'max', 'pad_max', 'pad_percent']
env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
default_values = [defaults[p] for p in params]
values = []
@@ -119,35 +143,62 @@ def read_bucket_settings(phase: str, dim: str, **defaults):
return values
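+
+# Illustrative example (mirrors tests/unit_tests/test_bucketing.py): with
+# VLLM_PROMPT_BS_BUCKET_MIN=1, VLLM_PROMPT_BS_BUCKET_STEP=16,
+# VLLM_PROMPT_BS_BUCKET_MAX=64, VLLM_PROMPT_BS_BUCKET_PAD_MAX=32 and
+# VLLM_PROMPT_BS_BUCKET_PAD_PERCENT=25 set in the environment,
+#   read_bucket_settings('prompt', 'bs', min=1, step=32, max=128, pad_max=64, pad_percent=10)
+# returns [1, 16, 64, 32, 25]: each environment variable overrides its default.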
-def warmup_range(config: Tuple[int, int, int]):
- """Generate a warmup range.
+def warmup_range_with_limits(config: Tuple[int, int, int, int, int]) -> List[int]:
+ """Generate a warmup range with absolute and relative padding limits.
- Start from bmin and multiply by 2 until you reach bstep.
- Then, increase the values in the range by the value of bstep until you
- reach bmax.
+    1. Start from `bucket_min` and keep doubling (or incrementing by 1 when the value is 0) until `bucket_step` is reached.
+    2. From there, advance by `bucket_step` up to `bucket_max`, keeping the current bucket only if:
+       a. skipping it would exceed the absolute padding limit `pad_max`,
+       b. or skipping it would exceed the padding ratio limit `pad_percent`,
+       c. or the current bucket is a multiple of `pad_max`.
+    3. Always include `bucket_max` as the last bucket.
Example:
- bmin = 2, bstep = 32, bmax = 64
- => ramp_up = (2, 4, 8, 16)
- => stable = (32, 64)
- => return ramp_up + stable => (2, 4, 8, 16, 32, 64)
+    1. for config = (0, 8, 64, 64, 0)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 24, 32, 40, 48, 56, 64]
+       return [0, 1, 2, 4, 8, 16, 24, 32, 40, 48, 56, 64]
+    2. for config = (0, 8, 64, 64, 50)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 32, 64]  # 24, 40, 48 and 56 are skipped: padding stays within the 50% limit
+       return [0, 1, 2, 4, 8, 16, 32, 64]
+    3. for config = (0, 8, 64, 16, 50)
+       ramp_up = [0, 1, 2, 4, 8]
+       stable = [16, 32, 48, 64]  # 24, 40 and 56 are skipped; 48 and 64 are kept because skipping them would exceed pad_max
+       return [0, 1, 2, 4, 8, 16, 32, 48, 64]
+    4. for config = (16, 16, 128, 32, 25)
+       stable = [16, 32, 48, 64, 80, 96, 128]  # no ramp-up phase; 112 is skipped because its padding stays within both limits
+       return [16, 32, 48, 64, 80, 96, 128]
"""
- bmin, bstep, bmax = config
- add_zero_bucket = bmin == 0
- assert bmin <= bmax, ("Min. batch size cannot be greater than max. "
- "batch size. If you want to skip warmup, "
- "set VLLM_SKIP_WARMUP=true")
- if add_zero_bucket:
- if bmin == 0 and bmax == 0:
- return [0]
- bmin = bstep
- base = itertools.repeat(2)
- ramp_up_acc = itertools.accumulate(base, func=operator.mul, initial=bmin)
- ramp_up_tw = itertools.takewhile(lambda x: x < bstep and x <= bmax, \
- ramp_up_acc)
- stable = range(bstep, bmax + 1, bstep)
- buckets = list(ramp_up_tw) + list(stable)
- buckets = [b for b in buckets if b >= bmin]
- if add_zero_bucket:
- buckets.append(0)
- return list(buckets)
+ bucket_min, bucket_step, bucket_max, pad_max, pad_percent = config
+ assert bucket_min <= bucket_max, ("bucket_min cannot be greater than bucket_max. "
+ "If you want to skip warmup, set VLLM_SKIP_WARMUP=true")
+    assert bucket_step > 0, f"bucket_step must be positive, got: {bucket_step}"
+    assert pad_max > 0, f"pad_max must be positive, got: {pad_max}"
+    assert 0 <= pad_percent <= 50, f"pad_percent must be between 0 and 50, got: {pad_percent}"
+
+ buckets = [bucket_min]
+ current_bucket = bucket_min
+ while current_bucket <= bucket_max:
+ last_bucket = buckets[-1]
+ if current_bucket <= bucket_step:
+ next_bucket = last_bucket * 2
+ if next_bucket == 0:
+ next_bucket += 1
+ if next_bucket <= bucket_max:
+ buckets.append(next_bucket)
+ else:
+ next_bucket = current_bucket + bucket_step
+ max_padding = next_bucket - last_bucket - 1
+ max_padding_ratio = max_padding / next_bucket
+ keep_bucket = (
+ max_padding_ratio > pad_percent / 100.0 # next bucket exceeds padding ratio limit
+ or max_padding > pad_max # next bucket exceeds absolute padding limit
+ or current_bucket % pad_max == 0 # current bucket is a multiple of pad_max
+ )
+ if keep_bucket and current_bucket != last_bucket:
+ buckets.append(current_bucket)
+ current_bucket = next_bucket
+ if buckets[-1] != bucket_max:
+ buckets.append(bucket_max)
+
+ return buckets
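For reviewers: a quick sanity sketch of the new range generation, reusing configurations exercised in tests/unit_tests/test_bucketing.py and in the docstring above (outputs hand-traced against the algorithm; module path as in this diff):

```python
from vllm_gaudi.extension.bucketing import linear

# Config tuples are (min, step, max, pad_max, pad_percent).
print(linear.warmup_range_with_limits((2, 64, 128, 64, 25)))
# [2, 4, 8, 16, 32, 64, 128] -- doubling ramp-up, then 64-wide steps
print(linear.warmup_range_with_limits((0, 8, 64, 16, 50)))
# [0, 1, 2, 4, 8, 16, 32, 48, 64] -- pad_max=16 forces buckets at 48 and 64
```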
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 5321b9809..2293229a4 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -17,22 +17,33 @@ def get_user_flags():
Env('VLLM_PROMPT_BS_BUCKET_MIN', int),
Env('VLLM_PROMPT_BS_BUCKET_STEP', int),
Env('VLLM_PROMPT_BS_BUCKET_MAX', int),
+ Env('VLLM_PROMPT_BS_BUCKET_PAD_MAX', int),
+ Env('VLLM_PROMPT_BS_BUCKET_PAD_PERCENT', int),
Env('VLLM_PROMPT_QUERY_BUCKET_MIN', int),
Env('VLLM_PROMPT_QUERY_BUCKET_STEP', int),
Env('VLLM_PROMPT_QUERY_BUCKET_MAX', int),
+ Env('VLLM_PROMPT_QUERY_BUCKET_PAD_MAX', int),
+ Env('VLLM_PROMPT_QUERY_BUCKET_PAD_PERCENT', int),
Env('VLLM_PROMPT_SEQ_BUCKET_MIN', int),
Env('VLLM_PROMPT_SEQ_BUCKET_STEP', int),
Env('VLLM_PROMPT_SEQ_BUCKET_MAX', int),
+ Env('VLLM_PROMPT_SEQ_BUCKET_PAD_MAX', int),
+ Env('VLLM_PROMPT_SEQ_BUCKET_PAD_PERCENT', int),
Env('VLLM_PROMPT_CTX_BUCKET_MIN', int),
Env('VLLM_PROMPT_CTX_BUCKET_STEP', int),
Env('VLLM_PROMPT_CTX_BUCKET_MAX', int),
+ Env('VLLM_PROMPT_CTX_BUCKET_PAD_MAX', int),
+ Env('VLLM_PROMPT_CTX_BUCKET_PAD_PERCENT', int),
Env('VLLM_DECODE_BS_BUCKET_MIN', int),
Env('VLLM_DECODE_BS_BUCKET_STEP', int),
Env('VLLM_DECODE_BS_BUCKET_MAX', int),
+ Env('VLLM_DECODE_BS_BUCKET_PAD_MAX', int),
+ Env('VLLM_DECODE_BS_BUCKET_PAD_PERCENT', int),
Env('VLLM_DECODE_BLOCK_BUCKET_MIN', int),
Env('VLLM_DECODE_BLOCK_BUCKET_STEP', int),
Env('VLLM_DECODE_BLOCK_BUCKET_MAX', int),
- Env('VLLM_DECODE_BLOCK_BUCKET_LIMIT', int),
+ Env('VLLM_DECODE_BLOCK_BUCKET_PAD_MAX', int),
+ Env('VLLM_DECODE_BLOCK_BUCKET_PAD_PERCENT', int),
Env('VLLM_BUCKETING_FROM_FILE', str),
# Non-vllm flags that are also important to print
@@ -63,7 +74,7 @@ def get_experimental_flags():
def get_features():
supported_attn_impls = ['flex_impl', 'fsdpa_impl', 'naive_impl']
- bucketing_strategies = ['exponential_bucketing', 'linear_bucketing']
+ bucketing_strategies = ['linear_bucketing', 'exponential_bucketing']
features = [
Value('fp32_alibi_biases', True, env_var='VLLM_ALIBI_USE_FLOAT32_BIASES'),
Value('fp32_softmax', ModelType('qwen2')),
@@ -82,8 +93,8 @@ def get_features():
Any(Disabled('prefix_caching'), Enabled('unified_attn')),
env_var='VLLM_CONTIGUOUS_PA'),
Value('use_bucketing', True, env_var='VLLM_ENABLE_BUCKETING'),
- Value('exponential_bucketing', True),
- Value('linear_bucketing', True),
+ Value('exponential_bucketing', False, env_var='VLLM_EXPONENTIAL_BUCKETING', env_var_type=boolean),
+ Value('linear_bucketing', Not(Enabled('exponential_bucketing'))),
ValueFromList('bucketing_strategy', bucketing_strategies),
Value('defrag', Enabled('unified_attn')),
Value('regional_compilation', True, env_var='VLLM_T_COMPILE_REGIONAL_COMPILATION', env_var_type=boolean),
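A note on the reordering above: with `VLLM_EXPONENTIAL_BUCKETING` now defaulting to false, `linear_bucketing` becomes the enabled strategy, and it was moved to the front of `bucketing_strategies`. A hypothetical sketch of the expected resolution, assuming `ValueFromList` selects the first enabled entry (the real `Value`/`ValueFromList` machinery is not shown in this diff):

```python
# Hypothetical stand-in for ValueFromList('bucketing_strategy', strategies);
# assumes the first enabled strategy in list order wins.
def resolve_strategy(exponential_enabled: bool) -> str:
    strategies = ['linear_bucketing', 'exponential_bucketing']
    enabled = {
        'exponential_bucketing': exponential_enabled,      # VLLM_EXPONENTIAL_BUCKETING
        'linear_bucketing': not exponential_enabled,       # Not(Enabled('exponential_bucketing'))
    }
    return next(s for s in strategies if enabled[s])

assert resolve_strategy(False) == 'linear_bucketing'       # new default
assert resolve_strategy(True) == 'exponential_bucketing'
```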