From db8e00679c3bc243b32532c13c2549895699e14c Mon Sep 17 00:00:00 2001
From: Dudi Lester
Date: Mon, 29 Dec 2025 17:10:12 +0200
Subject: [PATCH 1/4] [GAUDISW-244752] add dynamic scale for V-Cache on Hidden dim

Signed-off-by: Dudi Lester
---
 vllm_gaudi/attention/backends/hpu_attn.py  | 12 ++++++++-
 vllm_gaudi/attention/ops/hpu_paged_attn.py |  4 +--
 vllm_gaudi/extension/cache_ops.py          |  2 +-
 vllm_gaudi/extension/ops.py                | 23 ++++++++++------
 vllm_gaudi/v1/worker/hpu_model_runner.py   | 31 +++++++++++++++++-----
 5 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index f619be467..16af4cfe4 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -573,6 +573,11 @@ def forward(
         key_cache, value_cache, k_scales, v_scales = \
             HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

+        # reset the Value scales for the Hidden dim for the new sequence on prompt
+        if attn_metadata.is_prompt and value_cache is not None and isinstance(value_cache, tuple) \
+                and attn_metadata.block_list is not None:
+            v_scales[1].index_fill_(0, attn_metadata.block_list, torch.finfo(torch.bfloat16).tiny)
+
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
@@ -741,6 +746,11 @@ def forward_encoder_decoder(
         key_cache, value_cache, k_scales, v_scales = \
             HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

+        # reset the Value scales for the Hidden dim for the new sequence on prompt
+        if attn_metadata.is_prompt and value_cache is not None and isinstance(value_cache, tuple) \
+                and attn_metadata.block_list is not None:
+            v_scales[1].index_fill_(0, attn_metadata.block_list, torch.finfo(torch.bfloat16).tiny)
+
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
@@ -917,7 +927,7 @@ def forward(
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
-        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         attn_metadata: HPUUnifiedAttentionMetadata,
         output: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
diff --git a/vllm_gaudi/attention/ops/hpu_paged_attn.py b/vllm_gaudi/attention/ops/hpu_paged_attn.py
index b28fd68f4..7fe187f6d 100644
--- a/vllm_gaudi/attention/ops/hpu_paged_attn.py
+++ b/vllm_gaudi/attention/ops/hpu_paged_attn.py
@@ -98,8 +98,8 @@ def forward_decode(**kwargs) -> torch.Tensor:

     @staticmethod
     def swap_blocks(
-        src_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
-        dst_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
+        src_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
+        dst_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
         src_to_dsts: torch.Tensor,
     ) -> None:
         src_key_cache = src_kv_cache[0]
diff --git a/vllm_gaudi/extension/cache_ops.py b/vllm_gaudi/extension/cache_ops.py
index cc3190a04..09b0e84f2 100644
--- a/vllm_gaudi/extension/cache_ops.py
+++ b/vllm_gaudi/extension/cache_ops.py
@@ -39,7 +39,7 @@ def copy_blocks(key_caches, value_caches, key_scales, value_scales, block_mappin
         if k_scales is not None:
             k_scales.index_copy_(0, dst, k_scales.index_select(0, src))
         if v_scales is not None:
-            v_scales.index_copy_(0, dst, v_scales.index_select(0, src))
+            v_scales[0].index_copy_(0, dst, v_scales[0].index_select(0, src))

     if key_caches[0].device.type == 'hpu':
         htorch.core.mark_step()
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index a40388d3e..9e448471a 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -186,7 +186,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
     if k_scales is not None:
         k_scales_uf = k_scales.unflatten(0, (-1, block_size))
     if v_scales is not None:
-        v_scales_uf = v_scales.unflatten(0, (-1, block_size))
+        v_scales_uf = (v_scales[0].unflatten(0, (-1, block_size)), v_scales[1])

     query_shape = (-1, q_heads, 1, head_size)
     query = batch2block(scale * query, block_mapping, batch2block_matmul_op).view(query_shape)
@@ -389,19 +389,26 @@ def _get_all(data, *keys):
     return [data.get(k, None) for k in keys]


-def _include_past(tensor_str, fn_str, cache_str, args):
-    all_tensors = _get_all(args, tensor_str, fn_str, cache_str, 'block_list', 'block_size')
-    if all(t is not None for t in all_tensors):
-        current, fn, cache, block_list, block_size = all_tensors
-        past = fn(cache.unflatten(0, (-1, block_size)), block_list)
+def _include_past(tensor_str, fn_str, cache_str, scales_str, args):
+    all_tensors = _get_all(args, tensor_str, fn_str, cache_str, scales_str, 'block_list', 'block_size')
+    current, fn, cache, scales, block_list, block_size = all_tensors
+    all_beside_scales = (current, fn, cache, block_list, block_size)
+    if all(t is not None for t in all_beside_scales):
+        if scales is not None and isinstance(scales, tuple):
+            scales_uf = (scales[0].unflatten(0, (-1, block_size)), scales[1])
+        elif scales is not None:
+            scales_uf = scales.unflatten(0, (-1, block_size))
+        else:
+            scales_uf = None
+        past = fn(cache.unflatten(0, (-1, block_size)), **get_kv_fetch_extra_args(blocks=block_list, scales=scales_uf))
         past = past.reshape(current.size(0), -1, past.shape[2], past.shape[3])
         current = torch.concat((past, current), dim=1)
     args[tensor_str] = current


 def _get_context(args):
-    _include_past('key', 'keys_fetch_func', 'key_cache', args)
-    _include_past('value', 'values_fetch_func', 'value_cache', args)
+    _include_past('key', 'keys_fetch_func', 'key_cache', 'k_scales', args)
+    _include_past('value', 'values_fetch_func', 'value_cache', 'v_scales', args)


 class LoraMask:
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 1566699fc..3626a0cd2 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -808,8 +808,14 @@ def __init__(
             self.graphed_multimodal_buckets: set[Any] = set()
         else:
             logger.info("Bucketing is OFF.")
+<<<<<<< HEAD
         self._PAD_SLOT_ID = -1
         self._PAD_BLOCK_ID = -1
+=======
+        # set an out of range value for the padding tokens so that they are ignored when inserting into the KV cache.
+        self._PAD_SLOT_ID = torch.iinfo(torch.int32).max
+        self._PAD_BLOCK_ID = torch.iinfo(torch.int32).max
+>>>>>>> 500c8ba ([GAUDISW-244752] add dynamic scale for V-Cache on Hiddden dim)

         if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy(
         ) and not self.model_config.enforce_eager:
@@ -1798,9 +1804,12 @@ def _form_prefill_batch(self, contents):
         else:
             token_positions = align_and_pad(token_positions, (target_bs, target_seq), itertools.repeat(-1))

-        token_slots = align_and_pad(token_slots, (target_bs, target_seq), itertools.repeat(-1))
+        # set an out of range value for the padding tokens so that they are ignored when inserting into the KV cache.
+        token_slots = align_and_pad(token_slots, (target_bs, target_seq),
+                                    itertools.repeat(torch.iinfo(torch.int32).max))
         token_groups = align_and_pad(token_groups, (target_bs, target_seq), itertools.repeat(-1))
-        context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(-1))
+        # use 0 for padding context blocks to avoid dynamic scale calculation issues
+        context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(0))
         context_groups = align_and_pad(context_groups, (target_bs, target_blocks), itertools.repeat(-1))

         # TODO: cycle through dummy slots and blocks
@@ -2184,7 +2193,7 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
             num_dummy_decodes = 1
             num_dummy_scheduled_tokens = [1]
             context_lens = np.array([128])
-            block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], dtype=torch.int32).reshape(1, -1)
+            block_table_cpu_tensor = torch.zeros([context_lens], dtype=torch.int32).reshape(1, -1)
             return self._create_decode_input_data(num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
                                                   block_table_cpu_tensor)

@@ -4814,14 +4823,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
                     create_dynamic_scales = True
                 else:
                     create_dynamic_scales = False
-                kv_scales_shape = kv_cache_shape[:-1] + (1, )
+                min_val = torch.finfo(torch.bfloat16).tiny
+                kv_scales_shape = list(kv_cache_shape)
+                kv_scales_shape[-1] = 1
                 key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device)
-                key_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) if \
+                # initialize scale tensor with minimal scale values
+                key_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) * min_val if \
                     create_dynamic_scales else None
                 if v_cache_shape is not None:
                     value_cache = torch.zeros(v_cache_shape, dtype=dtype, device=self.device)
-                    value_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) if \
-                        create_dynamic_scales else None
+                    value_scales_on_T = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) * \
+                        min_val if create_dynamic_scales else None
+                    value_scales_on_hidden = torch.ones(
+                        [num_blocks + 1, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size],
+                        dtype=torch.bfloat16,
+                        device=self.device) * min_val if create_dynamic_scales else None
+                    value_scales = (value_scales_on_T, value_scales_on_hidden) if create_dynamic_scales else None
                 else:
                     value_cache = None
                     value_scales = None
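
Note on the mechanism patch 1 introduces: the V cache now carries two dynamic scale tensors, one scale per cached token along the time dimension and one scale per block along the hidden dimension, and the hidden-dim scales of a sequence's blocks are reset to the smallest normal bfloat16 value when a new prompt is written. The sketch below is illustrative only; the shapes and the names num_blocks, block_size and block_list are assumptions following the allocation in initialize_kv_cache and the index_fill_ call above, not part of the patch.

import torch

# Illustrative shapes only; the real allocation happens in initialize_kv_cache.
num_blocks, block_size, num_kv_heads, head_size = 16, 128, 8, 64
min_val = torch.finfo(torch.bfloat16).tiny

# Per-token V scale along the time dim (last dim collapsed to 1).
v_scales_on_t = torch.ones(num_blocks * block_size, num_kv_heads, 1, dtype=torch.bfloat16) * min_val
# Per-block V scale along the hidden dim, one row per block plus one extra row, as in the patch.
v_scales_on_hidden = torch.ones(num_blocks + 1, num_kv_heads, head_size, dtype=torch.bfloat16) * min_val
v_scales = (v_scales_on_t, v_scales_on_hidden)

# At prompt time the hidden-dim scales of the blocks assigned to the new
# sequence are reset, so maxima left over from a previous sequence cannot leak in.
block_list = torch.tensor([3, 7, 12])
v_scales[1].index_fill_(0, block_list, min_val)
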
OFF.") -<<<<<<< HEAD - self._PAD_SLOT_ID = -1 - self._PAD_BLOCK_ID = -1 -======= - # set an out of range value for the padding tokens so that they are ignored when inserting into the KV cache. - self._PAD_SLOT_ID = torch.iinfo(torch.int32).max - self._PAD_BLOCK_ID = torch.iinfo(torch.int32).max ->>>>>>> 500c8ba ([GAUDISW-244752] add dynamic scale for V-Cache on Hiddden dim) + + self._PAD_SLOT_ID = 0 + self._PAD_BLOCK_ID = 0 if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy( ) and not self.model_config.enforce_eager: @@ -1786,7 +1781,7 @@ def _form_prefill_batch(self, contents): # For models with multimodal support, we may want to get embeddings # for the valid tokens before padding. # This would require getting multimodal input embeddings here as well - token_ids = align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(-1)) + token_ids = align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(0)) # Update query_lens and context_lens after padding query_lens.extend([0] * (target_bs - len(query_lens))) context_lens.extend([0] * (target_bs - len(context_lens))) @@ -1804,11 +1799,9 @@ def _form_prefill_batch(self, contents): else: token_positions = align_and_pad(token_positions, (target_bs, target_seq), itertools.repeat(-1)) - # set an out of range value for the padding tokens so that they are ignored when inserting into the KV cache. - token_slots = align_and_pad(token_slots, (target_bs, target_seq), - itertools.repeat(torch.iinfo(torch.int32).max)) token_groups = align_and_pad(token_groups, (target_bs, target_seq), itertools.repeat(-1)) - # use 0 for padding context blocks to avoid dynamic scale calculation issues + # use 0 for padding to avoid dynamic scale calculation issues + token_slots = align_and_pad(token_slots, (target_bs, target_seq), itertools.repeat(0)) context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(0)) context_groups = align_and_pad(context_groups, (target_bs, target_blocks), itertools.repeat(-1)) @@ -4857,8 +4850,9 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: if self.enable_bucketing: self.bucketing_manager.num_hpu_blocks = num_blocks - self._PAD_BLOCK_ID = num_blocks - self._PAD_SLOT_ID = num_blocks * self.block_size + + self._PAD_BLOCK_ID = 0 + self._PAD_SLOT_ID = 0 if has_kv_transfer_group(): get_kv_transfer_group().register_kv_caches(self.get_kv_caches_4D(kv_caches)) From b3aaad196aac06ebe3b632c97b83ed70fd4e0e25 Mon Sep 17 00:00:00 2001 From: Dudi Lester Date: Tue, 6 Jan 2026 09:51:52 +0200 Subject: [PATCH 3/4] fix _create_dummy_decode_input_data to num_blocks value Signed-off-by: Dudi Lester --- vllm_gaudi/v1/worker/hpu_model_runner.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ee3a59cbb..feb38f63d 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -811,6 +811,7 @@ def __init__( self._PAD_SLOT_ID = 0 self._PAD_BLOCK_ID = 0 + self._dummy_num_blocks = 0 if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy( ) and not self.model_config.enforce_eager: @@ -2186,7 +2187,7 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData: num_dummy_decodes = 1 num_dummy_scheduled_tokens = [1] context_lens = np.array([128]) - block_table_cpu_tensor = torch.zeros([context_lens], dtype=torch.int32).reshape(1, -1) + 
From b3aaad196aac06ebe3b632c97b83ed70fd4e0e25 Mon Sep 17 00:00:00 2001
From: Dudi Lester
Date: Tue, 6 Jan 2026 09:51:52 +0200
Subject: [PATCH 3/4] fix _create_dummy_decode_input_data to use the num_blocks value

Signed-off-by: Dudi Lester
---
 vllm_gaudi/v1/worker/hpu_model_runner.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index ee3a59cbb..feb38f63d 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -811,6 +811,7 @@ def __init__(

         self._PAD_SLOT_ID = 0
         self._PAD_BLOCK_ID = 0
+        self._dummy_num_blocks = 0

         if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy(
         ) and not self.model_config.enforce_eager:
@@ -2186,7 +2187,7 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
             num_dummy_decodes = 1
             num_dummy_scheduled_tokens = [1]
             context_lens = np.array([128])
-            block_table_cpu_tensor = torch.zeros([context_lens], dtype=torch.int32).reshape(1, -1)
+            block_table_cpu_tensor = torch.zeros([self._dummy_num_blocks], dtype=torch.int32).reshape(1, -1)
             return self._create_decode_input_data(num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
                                                   block_table_cpu_tensor)

@@ -4853,6 +4854,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

         self._PAD_BLOCK_ID = 0
         self._PAD_SLOT_ID = 0
+        self._dummy_num_blocks = num_blocks

         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(self.get_kv_caches_4D(kv_caches))
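
For context, the dummy decode path now sizes its block table from the real block count saved at KV-cache initialization, and the zero-filled table makes every entry reference block 0. A tiny example (the num_blocks value here is made up):

import torch

dummy_num_blocks = 6  # assumed; set from num_blocks in initialize_kv_cache
block_table = torch.zeros([dummy_num_blocks], dtype=torch.int32).reshape(1, -1)
print(block_table.shape)  # torch.Size([1, 6]) -> one dummy request, all entries pointing at block 0
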
From 89f051a80bc4179a96118dc8453e81d4da78064e Mon Sep 17 00:00:00 2001
From: Dudi Lester
Date: Sun, 18 Jan 2026 16:52:40 +0200
Subject: [PATCH 4/4] Add block_size and is_prompt to VLLMKVCache forward call

Signed-off-by: Dudi Lester
---
 vllm_gaudi/attention/backends/hpu_attn.py | 50 +++++++++++++++--------
 vllm_gaudi/extension/utils.py             |  2 +-
 vllm_gaudi/v1/worker/hpu_model_runner.py  |  8 ++--
 3 files changed, 39 insertions(+), 21 deletions(-)

diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index 16af4cfe4..b4ac4617c 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -573,16 +573,21 @@ def forward(
         key_cache, value_cache, k_scales, v_scales = \
             HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

-        # reset the Value scales for the Hidden dim for the new sequence on prompt
-        if attn_metadata.is_prompt and value_cache is not None and isinstance(value_cache, tuple) \
-                and attn_metadata.block_list is not None:
-            v_scales[1].index_fill_(0, attn_metadata.block_list, torch.finfo(torch.bfloat16).tiny)
-
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
-        key_cache = self.k_cache(key, key_cache, slot_mapping, k_scales)
-        value_cache = self.v_cache(value, value_cache, slot_mapping, v_scales)
+        key_cache = self.k_cache(key,
+                                 key_cache,
+                                 slot_mapping,
+                                 scales=k_scales,
+                                 block_size=attn_metadata.block_size,
+                                 is_prompt=attn_metadata.is_prompt)
+        value_cache = self.v_cache(value,
+                                   value_cache,
+                                   slot_mapping,
+                                   scales=v_scales,
+                                   block_size=attn_metadata.block_size,
+                                   is_prompt=attn_metadata.is_prompt)

         if attn_metadata.is_prompt:
             # Prompt run.
@@ -746,16 +751,21 @@ def forward_encoder_decoder(
         key_cache, value_cache, k_scales, v_scales = \
             HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

-        # reset the Value scales for the Hidden dim for the new sequence on prompt
-        if attn_metadata.is_prompt and value_cache is not None and isinstance(value_cache, tuple) \
-                and attn_metadata.block_list is not None:
-            v_scales[1].index_fill_(0, attn_metadata.block_list, torch.finfo(torch.bfloat16).tiny)
-
         # Reshape the input keys and values and store them in the cache.
         # If kv_cache is not provided, the new key and value tensors are
         # not cached. This happens during the initial memory profiling run.
-        key_cache = self.k_cache(key, key_cache, cross_slot_mapping, k_scales)
-        value_cache = self.v_cache(value, value_cache, cross_slot_mapping, v_scales)
+        key_cache = self.k_cache(key,
+                                 key_cache,
+                                 cross_slot_mapping,
+                                 scales=k_scales,
+                                 block_size=attn_metadata.block_size,
+                                 is_prompt=attn_metadata.is_prompt)
+        value_cache = self.v_cache(value,
+                                   value_cache,
+                                   cross_slot_mapping,
+                                   scales=v_scales,
+                                   block_size=attn_metadata.block_size,
+                                   is_prompt=attn_metadata.is_prompt)

         if attn_metadata.is_prompt:
             # Prompt run.
@@ -940,8 +950,16 @@ def forward(
         query = query.unflatten(-1, (-1, self.head_size))
         key = key.unflatten(-1, (-1, self.head_size))
         value = value.unflatten(-1, (-1, self.head_size))
-        key_cache = self.k_cache(key, key_cache, attn_metadata.slot_mapping, k_scales)
-        value_cache = self.v_cache(value, value_cache, attn_metadata.slot_mapping, v_scales)
+        key_cache = self.k_cache(key,
+                                 key_cache,
+                                 attn_metadata.slot_mapping,
+                                 scales=k_scales,
+                                 block_size=attn_metadata.block_size)
+        value_cache = self.v_cache(value,
+                                   value_cache,
+                                   attn_metadata.slot_mapping,
+                                   scales=v_scales,
+                                   block_size=attn_metadata.block_size)
         output = unified_attn(
             query=query,
             key=key,
diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py
index bcdd05b21..ede1c10dd 100644
--- a/vllm_gaudi/extension/utils.py
+++ b/vllm_gaudi/extension/utils.py
@@ -53,7 +53,7 @@ def __init__(self, is_v_cache: bool = False):
         # is_v_cache is used in INC FP8 dynamic quantization to identify V cache
         self.is_v_cache = is_v_cache

-    def forward(self, input, cache, slot_mapping, scales=None, **kwargs):
+    def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
         # In cross-attention kv cache forward inputs are None in decode
         # We don't want to store them in the cache in such case
         if input is not None:
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index feb38f63d..8e800c33d 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -809,7 +809,7 @@ def __init__(
         else:
             logger.info("Bucketing is OFF.")

-        self._PAD_SLOT_ID = 0
+        self._PAD_SLOT_ID = -1
         self._PAD_BLOCK_ID = 0
         self._dummy_num_blocks = 0

@@ -1782,7 +1782,7 @@ def _form_prefill_batch(self, contents):
         # For models with multimodal support, we may want to get embeddings
         # for the valid tokens before padding.
         # This would require getting multimodal input embeddings here as well
-        token_ids = align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(0))
+        token_ids = align_and_pad(contents.token_ids, (target_bs, target_seq), itertools.repeat(-1))
         # Update query_lens and context_lens after padding
         query_lens.extend([0] * (target_bs - len(query_lens)))
         context_lens.extend([0] * (target_bs - len(context_lens)))
@@ -1800,9 +1800,9 @@ def _form_prefill_batch(self, contents):
         else:
             token_positions = align_and_pad(token_positions, (target_bs, target_seq), itertools.repeat(-1))

+        token_slots = align_and_pad(token_slots, (target_bs, target_seq), itertools.repeat(-1))
         token_groups = align_and_pad(token_groups, (target_bs, target_seq), itertools.repeat(-1))
         # use 0 for padding to avoid dynamic scale calculation issues
-        token_slots = align_and_pad(token_slots, (target_bs, target_seq), itertools.repeat(0))
         context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(0))
         context_groups = align_and_pad(context_groups, (target_bs, target_blocks), itertools.repeat(-1))

@@ -4853,7 +4853,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
             self.bucketing_manager.num_hpu_blocks = num_blocks

         self._PAD_BLOCK_ID = 0
-        self._PAD_SLOT_ID = 0
+        self._PAD_SLOT_ID = -1
         self._dummy_num_blocks = num_blocks

         if has_kv_transfer_group():
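
Patch 4 threads block_size and is_prompt through the VLLMKVCache forward call so the cache-write op itself has what it needs for the dynamic-scale handling. The toy module below only borrows that call shape; its body is a plain index_copy_ stand-in that skips the -1 padding slots, not the HPU implementation in vllm_gaudi/extension/utils.py, and the cache layout is assumed.

import torch

class ToyKVCache(torch.nn.Module):
    """Stand-in with the same call shape as VLLMKVCache.forward after patch 4."""

    def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
        if input is None:  # cross-attention decode passes no new tokens
            return cache
        flat = input.reshape(-1, *input.shape[2:])
        slots = slot_mapping.reshape(-1)
        valid = slots >= 0  # padded positions use slot -1 and are skipped
        # The real op would also update the dynamic scales here, using
        # block_size and is_prompt; this sketch only writes the values.
        cache.index_copy_(0, slots[valid], flat[valid])
        return cache

cache = torch.zeros(8 * 4, 2, 16)   # (num_blocks * block_size, kv_heads, head_size), assumed layout
key = torch.rand(1, 3, 2, 16)       # (batch, tokens, kv_heads, head_size)
slots = torch.tensor([[0, 1, -1]])  # last position is padding
cache = ToyKVCache()(key, cache, slots, block_size=4, is_prompt=True)
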