42 changes: 35 additions & 7 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -576,8 +576,18 @@ def forward(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, slot_mapping, k_scales)
value_cache = self.v_cache(value, value_cache, slot_mapping, v_scales)
key_cache = self.k_cache(key,
key_cache,
slot_mapping,
scales=k_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)
value_cache = self.v_cache(value,
value_cache,
slot_mapping,
scales=v_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)

if attn_metadata.is_prompt:
# Prompt run.
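
> Note on the new arguments: the cache write is addressed by `slot_mapping`, and `block_size` tells the (quantized) cache op how slots group into blocks. Below is a minimal, self-contained sketch of that addressing scheme with invented shapes; it is not the vllm_gaudi cache op itself.

```python
import torch

# Invented shapes, for illustration only: slot = block_number * block_size + offset.
num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
key_cache = torch.zeros(num_blocks * block_size, num_kv_heads, head_size)

# Three new tokens land in block 2 at offsets 0..2.
slot_mapping = torch.tensor([2 * block_size + i for i in range(3)])
key = torch.randn(3, num_kv_heads, head_size)

# Scatter the new keys into their slots: the core of a paged cache write.
key_cache.index_copy_(0, slot_mapping, key)

# Block-major view, as used when fetching past context per block.
key_cache_blocked = key_cache.unflatten(0, (-1, block_size))
assert key_cache_blocked.shape == (num_blocks, block_size, num_kv_heads, head_size)
```
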
@@ -744,8 +754,18 @@ def forward_encoder_decoder(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, cross_slot_mapping, k_scales)
value_cache = self.v_cache(value, value_cache, cross_slot_mapping, v_scales)
key_cache = self.k_cache(key,
key_cache,
cross_slot_mapping,
scales=k_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)
value_cache = self.v_cache(value,
value_cache,
cross_slot_mapping,
scales=v_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)

if attn_metadata.is_prompt:
# Prompt run.
@@ -917,7 +937,7 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
attn_metadata: HPUUnifiedAttentionMetadata,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
@@ -930,8 +950,16 @@ def forward(
query = query.unflatten(-1, (-1, self.head_size))
key = key.unflatten(-1, (-1, self.head_size))
value = value.unflatten(-1, (-1, self.head_size))
key_cache = self.k_cache(key, key_cache, attn_metadata.slot_mapping, k_scales)
value_cache = self.v_cache(value, value_cache, attn_metadata.slot_mapping, v_scales)
key_cache = self.k_cache(key,
key_cache,
attn_metadata.slot_mapping,
scales=k_scales,
block_size=attn_metadata.block_size)
value_cache = self.v_cache(value,
value_cache,
attn_metadata.slot_mapping,
scales=v_scales,
block_size=attn_metadata.block_size)
output = unified_attn(
query=query,
key=key,
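
> The updated type hints reflect the new cache layout: the fourth element of the per-layer `kv_cache` tuple is no longer a single value-scale tensor but a pair of tensors. A hedged sketch of that layout; the shapes are invented and the variable names are assumptions based on the annotations and allocations in this diff.

```python
import torch

# Illustrative layout after this change (shapes invented):
# kv_cache = (key_cache, value_cache, key_scales, (v_scales_on_T, v_scales_on_hidden))
num_slots, num_kv_heads, head_size = 32, 2, 16
num_blocks = 4

key_cache = torch.zeros(num_slots, num_kv_heads, head_size)
value_cache = torch.zeros(num_slots, num_kv_heads, head_size)
key_scales = torch.ones(num_slots, num_kv_heads, 1)        # one scale per slot/head for K
v_scales_on_T = torch.ones(num_slots, num_kv_heads, 1)     # V scales along the token axis
v_scales_on_hidden = torch.ones(num_blocks + 1, num_kv_heads, head_size)  # V scales along the hidden axis

kv_cache = (key_cache, value_cache, key_scales, (v_scales_on_T, v_scales_on_hidden))

# Unpacking mirrors the new annotation: the V scales are a 2-tuple.
k_cache, v_cache, k_scales, v_scales = kv_cache
assert isinstance(v_scales, tuple) and len(v_scales) == 2
```
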
7 changes: 4 additions & 3 deletions vllm_gaudi/attention/ops/hpu_paged_attn.py
@@ -98,8 +98,8 @@ def forward_decode(**kwargs) -> torch.Tensor:

@staticmethod
def swap_blocks(
src_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
dst_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
src_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
dst_kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple[torch.Tensor, torch.Tensor]],
src_to_dsts: torch.Tensor,
) -> None:
src_key_cache = src_kv_cache[0]
@@ -117,7 +117,8 @@ def swap_blocks(
if src_key_scales is not None:
cache_ops.swap_blocks(src_key_scales, dst_key_scales, src_to_dsts)
if src_value_scales is not None:
cache_ops.swap_blocks(src_value_scales, dst_value_scales, src_to_dsts)
cache_ops.swap_blocks(src_value_scales[0], dst_value_scales[0], src_to_dsts)
cache_ops.swap_blocks(src_value_scales[1], dst_value_scales[1], src_to_dsts)

@staticmethod
def copy_blocks(
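
> For reference, a standalone sketch of what swapping rows of the paired value-scale tensors amounts to. The real transfer is delegated to `cache_ops.swap_blocks`, so this only illustrates the index arithmetic; shapes are invented and whether an index addresses a block or a slot depends on the cache layout.

```python
import torch

def swap_blocks_ref(src: torch.Tensor, dst: torch.Tensor, src_to_dsts: torch.Tensor) -> None:
    # src_to_dsts holds (src, dst) index pairs; copy the selected rows across caches.
    src_idx, dst_idx = src_to_dsts[:, 0], src_to_dsts[:, 1]
    dst.index_copy_(0, dst_idx, src.index_select(0, src_idx))

# Value scales are now a pair of tensors, so both halves move with the same mapping.
src_v_scales = (torch.rand(32, 1), torch.rand(5, 16))
dst_v_scales = (torch.zeros(32, 1), torch.zeros(5, 16))
src_to_dsts = torch.tensor([[0, 3], [1, 2]])

for s, d in zip(src_v_scales, dst_v_scales):
    swap_blocks_ref(s, d, src_to_dsts)
```
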
5 changes: 3 additions & 2 deletions vllm_gaudi/extension/cache_ops.py
@@ -38,8 +38,9 @@ def copy_blocks(key_caches, value_caches, key_scales, value_scales, block_mappin
value_cache.index_copy_(0, dst, value_cache.index_select(0, src))
if k_scales is not None:
k_scales.index_copy_(0, dst, k_scales.index_select(0, src))
if v_scales is not None:
v_scales.index_copy_(0, dst, v_scales.index_select(0, src))
if v_scales is not None and isinstance(v_scales, tuple):
v_scales[0].index_copy_(0, dst, v_scales[0].index_select(0, src))
v_scales[1].index_copy_(0, dst, v_scales[1].index_select(0, src))

if key_caches[0].device.type == 'hpu':
htorch.core.mark_step()
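
> A hypothetical way to express the K-scale / V-scale branch above more uniformly: K scales stay a single tensor while V scales are now a pair, so a tiny normalizer can drive the same copy loop for both. `_scale_tensors` and `copy_scale_blocks` are not part of the codebase, and the shapes are invented.

```python
import torch

def _scale_tensors(scales):
    # Normalize: None -> nothing, tensor -> 1-tuple, tuple -> itself.
    if scales is None:
        return ()
    return scales if isinstance(scales, tuple) else (scales,)

def copy_scale_blocks(scales, src: torch.Tensor, dst: torch.Tensor) -> None:
    # Same index_copy_/index_select pattern as the diff, applied to each tensor.
    for t in _scale_tensors(scales):
        t.index_copy_(0, dst, t.index_select(0, src))

k_scales = torch.rand(8, 1)                      # single tensor for K
v_scales = (torch.rand(8, 1), torch.rand(3, 4))  # pair for V after this change
src, dst = torch.tensor([0]), torch.tensor([1])
copy_scale_blocks(k_scales, src, dst)
copy_scale_blocks(v_scales, src, dst)
```
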
25 changes: 17 additions & 8 deletions vllm_gaudi/extension/ops.py
@@ -188,7 +188,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
if k_scales is not None:
k_scales_uf = k_scales.unflatten(0, (-1, block_size))
if v_scales is not None:
v_scales_uf = v_scales.unflatten(0, (-1, block_size))
v_scales_uf = (v_scales[0].unflatten(0, (-1, block_size)), v_scales[1])

query_shape = (-1, q_heads, 1, head_size)
query = batch2block(scale * query, block_mapping, batch2block_matmul_op).view(query_shape)
@@ -391,19 +391,28 @@ def _get_all(data, *keys):
return [data.get(k, None) for k in keys]


def _include_past(tensor_str, fn_str, cache_str, args):
all_tensors = _get_all(args, tensor_str, fn_str, cache_str, 'block_list', 'block_size')
if all(t is not None for t in all_tensors):
current, fn, cache, block_list, block_size = all_tensors
past = fn(cache.unflatten(0, (-1, block_size)), block_list)
def _include_past(tensor_str, fn_str, cache_str, scales_str, args):
all_tensors = _get_all(args, tensor_str, fn_str, cache_str, scales_str, 'block_list', 'block_size')
current, fn, cache, scales, block_list, block_size = all_tensors
all_beside_scales = (current, fn, cache, block_list, block_size)
if all(t is not None for t in all_beside_scales):
is_v_scales = scales is not None and isinstance(scales, tuple)
is_k_scales = scales is not None and not is_v_scales
if is_v_scales:
scales_uf = (scales[0].unflatten(0, (-1, block_size)), scales[1])
elif is_k_scales:
scales_uf = scales.unflatten(0, (-1, block_size))
else:
scales_uf = None
past = fn(cache.unflatten(0, (-1, block_size)), **get_kv_fetch_extra_args(blocks=block_list, scales=scales_uf))
past = past.reshape(current.size(0), -1, past.shape[2], past.shape[3])
current = torch.concat((past, current), dim=1)
args[tensor_str] = current


def _get_context(args):
_include_past('key', 'keys_fetch_func', 'key_cache', args)
_include_past('value', 'values_fetch_func', 'value_cache', args)
_include_past('key', 'keys_fetch_func', 'key_cache', 'k_scales', args)
_include_past('value', 'values_fetch_func', 'value_cache', 'v_scales', args)


class LoraMask:
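
> The new `_include_past` distinguishes the two scale formats by type: K scales are a single slot-major tensor, while V scales are a pair whose first element is slot-major and is the only part that needs the per-block unflatten (the same split appears in `flat_pa` above). A sketch of just that dispatch with invented shapes; `get_kv_fetch_extra_args` is assumed to simply forward these backend-specific kwargs to the fetch function.

```python
import torch

def unflatten_scales_per_block(scales, block_size: int):
    """Reshape cached scales into (num_blocks, block_size, ...) form.

    K scales arrive as a single tensor; V scales as a (per_slot, per_channel)
    pair where only the per-slot part is stored slot-major.
    """
    if scales is None:
        return None
    if isinstance(scales, tuple):  # V scales
        return (scales[0].unflatten(0, (-1, block_size)), scales[1])
    return scales.unflatten(0, (-1, block_size))  # K scales

block_size = 4
k_scales = torch.ones(8, 2, 1)
v_scales = (torch.ones(8, 2, 1), torch.ones(3, 2, 16))

print(unflatten_scales_per_block(k_scales, block_size).shape)     # (2, 4, 2, 1)
print(unflatten_scales_per_block(v_scales, block_size)[0].shape)  # (2, 4, 2, 1)
```
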
2 changes: 1 addition & 1 deletion vllm_gaudi/extension/utils.py
@@ -53,7 +53,7 @@ def __init__(self, is_v_cache: bool = False):
# is_v_cache is used in INC FP8 dynamic quantization to identify V cache
self.is_v_cache = is_v_cache

def forward(self, input, cache, slot_mapping, scales=None, **kwargs):
def forward(self, input, cache, slot_mapping, scales=None, block_size=None, is_prompt=False, **kwargs):
# In cross-attention kv cache forward inputs are None in decode
# We don't want to store them in the cache in such case
if input is not None:
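
> The widened signature keeps older call sites working while letting the quantized cache ops receive `block_size` and `is_prompt`. A minimal sketch of a cache module in that shape; it is not the real cache class from `utils.py` and only shows the unquantized write path.

```python
import torch
from torch import nn

class SimpleKVCache(nn.Module):
    """Sketch of a cache writer with the widened forward signature."""

    def __init__(self, is_v_cache: bool = False):
        super().__init__()
        # Only the quantized (INC FP8) variant actually consumes scales / is_v_cache.
        self.is_v_cache = is_v_cache

    def forward(self, input, cache, slot_mapping, scales=None,
                block_size=None, is_prompt=False, **kwargs):
        # In cross-attention decode (and during profiling) there is nothing to store.
        if input is not None and cache is not None:
            cache.index_copy_(0, slot_mapping, input)
        return cache

cache = torch.zeros(16, 2, 8)
writer = SimpleKVCache()
key = torch.randn(3, 2, 8)
cache = writer(key, cache, torch.tensor([4, 5, 6]), scales=None, block_size=4, is_prompt=True)
```
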
31 changes: 22 additions & 9 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -806,8 +806,10 @@ def __init__(
self.graphed_multimodal_buckets: set[Any] = set()
else:
logger.info("Bucketing is OFF.")

self._PAD_SLOT_ID = -1
self._PAD_BLOCK_ID = -1
self._PAD_BLOCK_ID = 0
self._dummy_num_blocks = 0

if self.vllm_config.parallel_config.data_parallel_size > 1 and htorch.utils.internal.is_lazy(
) and not self.model_config.enforce_eager:
@@ -1816,7 +1818,8 @@ def _form_prefill_batch(self, contents):
token_positions = align_and_pad(token_positions, (target_bs, target_seq), itertools.repeat(-1))
token_slots = align_and_pad(token_slots, (target_bs, target_seq), itertools.repeat(-1))
token_groups = align_and_pad(token_groups, (target_bs, target_seq), itertools.repeat(-1))
context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(-1))
# use 0 for padding to avoid dynamic scale calculation issues
context_blocks = align_and_pad(context_blocks, (target_bs, target_blocks), itertools.repeat(0))
context_groups = align_and_pad(context_groups, (target_bs, target_blocks), itertools.repeat(-1))

# TODO: cycle through dummy slots and blocks
@@ -2200,7 +2203,7 @@ def _create_dummy_decode_input_data(self) -> DecodeInputData:
num_dummy_decodes = 1
num_dummy_scheduled_tokens = [1]
context_lens = np.array([128])
block_table_cpu_tensor = torch.zeros([self._PAD_BLOCK_ID], dtype=torch.int32).reshape(1, -1)
block_table_cpu_tensor = torch.zeros([self._dummy_num_blocks], dtype=torch.int32).reshape(1, -1)
return self._create_decode_input_data(num_dummy_decodes, num_dummy_scheduled_tokens, context_lens,
block_table_cpu_tensor)

@@ -4918,14 +4921,22 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
create_dynamic_scales = True
else:
create_dynamic_scales = False
kv_scales_shape = kv_cache_shape[:-1] + (1, )
min_val = torch.finfo(torch.bfloat16).tiny
kv_scales_shape = list(kv_cache_shape)
kv_scales_shape[-1] = 1
key_cache = torch.zeros(kv_cache_shape, dtype=dtype, device=self.device)
key_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) if \
# initialize scale tensor with minimal scale values
key_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) * min_val if \
create_dynamic_scales else None
if v_cache_shape is not None:
value_cache = torch.zeros(v_cache_shape, dtype=dtype, device=self.device)
value_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) if \
create_dynamic_scales else None
value_scales_on_T = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device=self.device) * \
min_val if create_dynamic_scales else None
value_scales_on_hidden = torch.ones(
[num_blocks + 1, kv_cache_spec.num_kv_heads, kv_cache_spec.head_size],
dtype=torch.bfloat16,
device=self.device) * min_val if create_dynamic_scales else None
value_scales = (value_scales_on_T, value_scales_on_hidden) if create_dynamic_scales else None
else:
value_cache = None
value_scales = None
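
> For orientation, a sketch of how the three scale tensors relate to the cache shape, with invented cache dimensions. `torch.finfo(torch.bfloat16).tiny` is the smallest positive normal bf16 value, presumably so that a later max-style dynamic update can only grow the scales.

```python
import torch

# Invented dimensions for illustration.
num_blocks, block_size, num_kv_heads, head_size = 4, 8, 2, 16
kv_cache_shape = (num_blocks * block_size, num_kv_heads, head_size)

min_val = torch.finfo(torch.bfloat16).tiny  # smallest positive normal bf16

# Per-slot scales: same layout as the cache, but one value along the last dim.
kv_scales_shape = list(kv_cache_shape)
kv_scales_shape[-1] = 1

key_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16) * min_val
value_scales_on_T = torch.ones(kv_scales_shape, dtype=torch.bfloat16) * min_val
# Per-channel V scales: one row per block plus one extra row, as in the diff.
value_scales_on_hidden = torch.ones(num_blocks + 1, num_kv_heads, head_size,
                                    dtype=torch.bfloat16) * min_val

value_scales = (value_scales_on_T, value_scales_on_hidden)
print(key_scales.shape, value_scales_on_T.shape, value_scales_on_hidden.shape)
```
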
@@ -4944,8 +4955,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:

if self.enable_bucketing:
self.bucketing_manager.num_hpu_blocks = num_blocks
self._PAD_BLOCK_ID = num_blocks
self._PAD_SLOT_ID = num_blocks * self.block_size

self._PAD_BLOCK_ID = 0
self._PAD_SLOT_ID = -1
self._dummy_num_blocks = num_blocks

if has_kv_transfer_group():
get_kv_transfer_group().register_kv_caches(self.get_kv_caches_4D(kv_caches))
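
> On the padding change: padded block-table entries now point at block 0, a real block whose scales are initialized, so any per-block scale gather stays in range and touches maintained values; per the comment in `_form_prefill_batch` this avoids dynamic scale calculation issues. A toy gather with invented shapes, only to show what a padded block list looks like after the change.

```python
import torch

num_blocks = 4
PAD_SLOT_ID, PAD_BLOCK_ID = -1, 0  # values set in initialize_kv_cache above

# Per-block value scales, one row per block plus one extra (invented shape).
v_scales_on_hidden = torch.full((num_blocks + 1, 2, 16), 1e-3)

# A prefill batch whose context block list was padded with PAD_BLOCK_ID.
context_blocks = torch.tensor([2, 3, PAD_BLOCK_ID, PAD_BLOCK_ID])
gathered = v_scales_on_hidden.index_select(0, context_blocks)
assert gathered.shape == (4, 2, 16)  # padded entries reuse block 0's scales
```
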