diff --git a/.cd/README.md b/.cd/README.md
index 0bceb6fe..270e1372 100644
--- a/.cd/README.md
+++ b/.cd/README.md
@@ -78,6 +78,7 @@ cd vllm-gaudi/.cd/
 - `VLLM_DECODE_BS_BUCKET_STEP` - Determines the batch size step for decode operations, influencing how batches are grouped and processed.
 - `VLLM_PROMPT_BS_BUCKET_STEP` - Sets the batch size step for prompt processing, impacting how prompt batches are handled.
 - `VLLM_PROMPT_SEQ_BUCKET_STEP` - Controls the step size for prompt sequence allocation, affecting how sequences are bucketed for processing.
+- `VLLM_PROMPT_CTX_BUCKET_STEP` - Controls the step size for prompt ctx allocation, affecting how ctx blocks are bucketed for processing.
 
 **Example usage:**
 
diff --git a/docs/configuration/env_vars.md b/docs/configuration/env_vars.md
index ac62d2a1..8451c7a6 100644
--- a/docs/configuration/env_vars.md
+++ b/docs/configuration/env_vars.md
@@ -53,6 +53,9 @@ Additionally, there are HPU PyTorch Bridge environment variables impacting vLLM
     - sequence length min (`VLLM_PROMPT_SEQ_BUCKET_MIN`): `block_size`
     - sequence length step (`VLLM_PROMPT_SEQ_BUCKET_STEP`): `block_size`
     - sequence length max (`VLLM_PROMPT_SEQ_BUCKET_MAX`): `max_model_len`
+    - sequence ctx min (`VLLM_PROMPT_CTX_BUCKET_MIN`): `0`
+    - sequence ctx step (`VLLM_PROMPT_CTX_BUCKET_STEP`): `1`
+    - sequence ctx max (`VLLM_PROMPT_CTX_BUCKET_MAX`): `(max_model_len - block_size) // block_size`
 - Decode:
     - batch size min (`VLLM_DECODE_BS_BUCKET_MIN`): `1`
     - batch size step (`VLLM_DECODE_BS_BUCKET_STEP`): `32`
diff --git a/vllm_gaudi/extension/bucketing/linear.py b/vllm_gaudi/extension/bucketing/linear.py
index 6d15724a..e32c6647 100644
--- a/vllm_gaudi/extension/bucketing/linear.py
+++ b/vllm_gaudi/extension/bucketing/linear.py
@@ -22,7 +22,7 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
                                                    step=block_size,
                                                    max=max_model_len)
     max_ctx = math.ceil((max_model_len - prompt_query_bucket_cfg[0]) // block_size)
-    prompt_ctx_bucket_cfg = [0, 1, max_ctx]
+    prompt_ctx_bucket_cfg = read_bucket_settings('prompt', 'ctx', min=0, step=1, max=max_ctx)
 
     if use_merged_prefill:
         prev_prompt_bs_bucket_cfg = tuple(prompt_bs_bucket_cfg)
@@ -32,7 +32,11 @@ def get_prompt_cfgs(self, max_num_prefill_seqs, block_size, max_num_batched_toke
         prompt_bs_bucket_cfg = (1, 1, 1)
         query_min, query_step, _ = prev_prompt_query_bucket_cfg
         prompt_query_bucket_cfg = (query_min, query_step * 4, max_num_batched_tokens)
-        prompt_ctx_bucket_cfg = (0, 4, max_ctx * max_num_prefill_seqs)
+        prompt_ctx_bucket_cfg = read_bucket_settings('prompt',
+                                                     'ctx',
+                                                     min=0,
+                                                     step=4,
+                                                     max=max_ctx * max_num_prefill_seqs)
 
         msg = ('Merged prefill is enabled!\n'
                'Overriding prompt bucketing settings!\n'
diff --git a/vllm_gaudi/extension/features.py b/vllm_gaudi/extension/features.py
index 609dfdcd..d8e1619c 100644
--- a/vllm_gaudi/extension/features.py
+++ b/vllm_gaudi/extension/features.py
@@ -23,6 +23,9 @@ def get_user_flags():
         Env('VLLM_PROMPT_SEQ_BUCKET_STEP', int),
         Env('VLLM_PROMPT_SEQ_BUCKET_MAX', int),
         Env('VLLM_PROMPT_SEQ_BUCKET_LIMIT', int),
+        Env('VLLM_PROMPT_CTX_BUCKET_MIN', int),
+        Env('VLLM_PROMPT_CTX_BUCKET_STEP', int),
+        Env('VLLM_PROMPT_CTX_BUCKET_MAX', int),
         Env('VLLM_DECODE_BS_BUCKET_MIN', int),
         Env('VLLM_DECODE_BS_BUCKET_STEP', int),
         Env('VLLM_DECODE_BS_BUCKET_MAX', int),
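
For context on the change: the sketch below illustrates how the new `VLLM_PROMPT_CTX_BUCKET_{MIN,STEP,MAX}` variables are expected to override the defaults passed to `read_bucket_settings('prompt', 'ctx', ...)`. This is a minimal, hypothetical sketch based on the `VLLM_{phase}_{dim}_BUCKET_{param}` naming pattern visible in the diff, not the actual implementation of `read_bucket_settings` in `vllm_gaudi`.

```python
import os

# Hypothetical sketch: map VLLM_{PHASE}_{DIM}_BUCKET_{MIN,STEP,MAX} env vars
# onto the (min, step, max) triple, falling back to the supplied defaults.
def read_bucket_settings_sketch(phase: str, dim: str, **defaults):
    params = ['min', 'step', 'max']
    env_vars = [f'VLLM_{phase}_{dim}_BUCKET_{p}'.upper() for p in params]
    return [int(os.environ.get(var, defaults[p])) for var, p in zip(env_vars, params)]

# With no overrides this reproduces the documented defaults [0, 1, max_ctx];
# exporting e.g. VLLM_PROMPT_CTX_BUCKET_STEP=2 would change only the step.
max_ctx = 127  # illustrative value of (max_model_len - block_size) // block_size
print(read_bucket_settings_sketch('prompt', 'ctx', min=0, step=1, max=max_ctx))
```

Note that the defaults kept in `linear.py` (`min=0, step=1, max=max_ctx`) match the previously hard-coded `[0, 1, max_ctx]`, so behavior is unchanged unless the new environment variables are set.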