9 changes: 3 additions & 6 deletions tests/unit_tests/worker/test_hpu_model_runner.py
@@ -398,7 +398,6 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -426,7 +425,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order():
assert fwd_context is not None


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -456,7 +454,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist():
assert fwd_context is not None


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_target_same_as_current():
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -552,7 +549,6 @@ def test_init_kv_cache_without_kv_sharing():
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_valid():
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -588,10 +584,11 @@ def test_init_kv_cache_with_kv_sharing_valid():
assert runner.shared_kv_cache_layers[layer_1] == layer_0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# page size for layer 0's kv_cache_spec is 256KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
page_size = 128 * 8 * 64 * 2 * 2 # 128 for block_size, 2 for K+V, 2 for bfloat16
num_expected_blocks = available_memory // page_size # 20GiB / 256KB = 81920
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], [available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
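A quick arithmetic check of the expected block count above (a sketch, not part of the diff; it assumes the spec encoded by the 128 * 8 * 64 * 2 * 2 factors, i.e. block_size 128, 8 KV heads, head_size 64, K and V, bfloat16 at 2 bytes per element):

GiB = 1 << 30
page_size = 128 * 8 * 64 * 2 * 2            # 262,144 bytes = 256 KiB per block
available_memory = 20 * GiB
print(available_memory // page_size)        # 81920 blocks with KV sharing
print(available_memory // (2 * page_size))  # 40960 blocks if both layers allocated their own cache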
16 changes: 8 additions & 8 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -358,8 +358,7 @@ def __init__(
use_irope: bool = False,
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not currently supported on HPU.")
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if use_irope:
logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
"to global attention for long context.")
@@ -498,12 +497,13 @@ def forward(
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache = HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, slot_mapping)
value_cache = self.v_cache(value, value_cache, slot_mapping)

if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key, key_cache, slot_mapping)
value_cache = self.v_cache(value, value_cache, slot_mapping)
if attn_metadata.is_prompt:
# Prompt run.
query_shape = (batch_size, seq_len, self.num_heads, self.head_size)
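The write-skip above is sufficient because the model runner makes the shared layer's entry in kv_caches alias its target layer's tensors (see the hpu_model_runner.py and utils.py changes below), so the read path needs no special handling. A minimal sketch of that invariant, using hypothetical layer names:

import torch

kv_caches = {"model.layers.0.self_attn.attn": (torch.zeros(4, 16), torch.zeros(4, 16))}
# Set up by initialize_kv_cache_for_kv_sharing: layer 1 reuses layer 0's (K, V) tensors.
kv_caches["model.layers.1.self_attn.attn"] = kv_caches["model.layers.0.self_attn.attn"]

key_cache_0, _ = kv_caches["model.layers.0.self_attn.attn"]
key_cache_0[0, 0] = 1.0  # layer 0 writes its keys
assert kv_caches["model.layers.1.self_attn.attn"][0][0, 0] == 1.0  # layer 1 reads the same cache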
14 changes: 12 additions & 2 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -53,8 +53,8 @@
AsyncModelRunnerOutput, KVConnectorOutput)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.worker.utils import bind_kv_cache
from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch
from vllm.v1.worker.gpu_input_batch import CachedRequestState
from vllm_gaudi.v1.worker.utils import initialize_kv_cache_for_kv_sharing
from vllm_gaudi.v1.worker.hpu_input_batch import InputBatch, CachedRequestState
from vllm.distributed.parallel_state import get_pp_group
from vllm.model_executor.models.interfaces import (supports_eagle3, supports_transcription)
from vllm.model_executor.models.interfaces_base import (VllmModelForPooling, is_pooling_model, is_text_generation_model)
@@ -645,6 +645,8 @@ def __init__(
# Lazy initialization
# self.model: nn.Module # set after load_model
self.kv_caches: list[torch.Tensor] = []
# Maps layers that reuse another layer's KV cache to that target layer's name
self.shared_kv_cache_layers: dict[str, str] = {}
self.inc_initialized_successfully = False
self._is_inc_finalized = False

@@ -885,6 +887,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
use_mla = self.vllm_config.model_config.use_mla
kv_cache_spec: dict[str, KVCacheSpec] = {}
for layer_name, attn_module in forward_ctx.items():
kv_sharing_target_layer_name = getattr(attn_module, 'kv_sharing_target_layer_name', None)
if kv_sharing_target_layer_name is not None:
self.shared_kv_cache_layers[layer_name] = kv_sharing_target_layer_name
continue
if isinstance(attn_module, FusedMoE):
continue

@@ -3800,6 +3806,10 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
for group in kv_cache_config.kv_cache_groups:
layer_names.update(group.layer_names)
assert layer_names == set(kv_caches.keys()), "Some layers are not correctly initialized"

if self.shared_kv_cache_layers:
initialize_kv_cache_for_kv_sharing(self.shared_kv_cache_layers, kv_cache_config.kv_cache_groups, kv_caches)

bind_kv_cache(kv_caches, self.vllm_config.compilation_config.static_forward_context, self.kv_caches)

if self.enable_bucketing:
39 changes: 39 additions & 0 deletions vllm_gaudi/v1/worker/utils.py
@@ -0,0 +1,39 @@
from typing import TYPE_CHECKING
import torch

if TYPE_CHECKING:
from vllm.v1.kv_cache_interface import KVCacheGroupSpec


def initialize_kv_cache_for_kv_sharing(
shared_kv_cache_layers: dict[str, str],
kv_cache_groups: list["KVCacheGroupSpec"],
kv_caches: dict[str, torch.Tensor],
) -> None:
"""
Sets up KV cache sharing by reusing the allocated KV caches in `kv_caches`
for layers that do not allocate their own KV cache, based on the mapping in
`shared_kv_cache_layers`. Also adds these layers to the corresponding KV cache
group, which is needed so that attention metadata is assigned to them later.

Args:
shared_kv_cache_layers: Layer pairings for cross-layer KV sharing.
If an Attention layer `layer_name` is in the keys of this dict, it
means this layer will perform attention using the keys and values
from the KV cache of `shared_kv_cache_layers[layer_name]`.
kv_cache_groups: The KV cache groups of the model.
kv_caches: The allocated kv_caches with layer names as keys.
Note that layers in shared_kv_cache_layers.keys() are not
included here, since kv_caches only contains layers that allocate
their own KV cache.
"""
# Record index of KV cache group for each layer that allocates a KV cache.
layer_to_kv_cache_group_idx: dict[str, int] = {}
for i, kv_cache_group in enumerate(kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
layer_to_kv_cache_group_idx[layer_name] = i

for layer_name, target_layer_name in shared_kv_cache_layers.items():
kv_caches[layer_name] = kv_caches[target_layer_name]
group_idx = layer_to_kv_cache_group_idx[target_layer_name]
kv_cache_groups[group_idx].layer_names.append(layer_name)