Closed
31 commits
9d8e272  Pick model runner change related to PR30475. (libinta, Dec 23, 2025)
49d7633  add qwen3_vl.py functions (libinta, Dec 23, 2025)
bdff63f  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Dec 23, 2025)
c6526de  precomit fix (libinta, Dec 24, 2025)
7c6329e  precommit fix and fix use_window_sdpa (libinta, Dec 25, 2025)
bff3cf5  Update qwen3_vl.py (iboiko-habana, Dec 29, 2025)
625d9c2  Update qwen3_vl.py (iboiko-habana, Dec 29, 2025)
568b4eb  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Dec 29, 2025)
495643a  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Dec 30, 2025)
327a9cc  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Dec 30, 2025)
bb3ac24  Update qwen3_vl.py (iboiko-habana, Dec 30, 2025)
fe67f98  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Jan 5, 2026)
8a9efd1  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Jan 5, 2026)
a394b9a  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Jan 8, 2026)
6502061  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Jan 15, 2026)
48a96db  fix test failure (libinta, Jan 15, 2026)
0171641  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Jan 15, 2026)
db10548  fix precommit issue (libinta, Jan 15, 2026)
40d7635  Update interfaces.py for precommit fix (libinta, Jan 15, 2026)
e23e6d2  Update hpu_model_runner.py to match with upstream for MultiModalBudget (libinta, Jan 16, 2026)
46facad  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Jan 19, 2026)
4089adf  Update qwen3_vl.py for precommit fix (libinta, Jan 19, 2026)
79d90a4  Update qwen3_vl.py for precommit fix (libinta, Jan 19, 2026)
e370a49  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Jan 21, 2026)
0321b6c  Interleaved sliding window fix (#805) (rsmyrek, Jan 21, 2026)
0df1f20  Merge branch 'main' into libinta/remove_gather_scatter (iboiko-habana, Jan 21, 2026)
3108ed8  [GAUDISW-245665] fix diverge from vllm in multiModalBudget (#837) (linoybu, Jan 21, 2026)
2d1a7a7  KV cache sharing for HPU (#834) (jakub-sochacki, Jan 21, 2026)
5fdf237  Merge branch 'main' into libinta/remove_gather_scatter (libinta, Jan 21, 2026)
07f40c9  add back warmup with ratio and video warmup (libinta, Jan 21, 2026)
7252043  Merge branch 'releases/v0.14.0' into libinta/remove_gather_scatter (skaulintel, Jan 21, 2026)
9 changes: 3 additions & 6 deletions tests/unit_tests/worker/test_hpu_model_runner.py
@@ -390,7 +390,6 @@ def test_reload_weights_before_load_model(model_runner):
model_runner.reload_weights()


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_config: None):
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -418,7 +417,6 @@ def test_init_kv_cache_with_kv_sharing_invalid_target_layer_order(default_vllm_c
assert fwd_context is not None


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_config: None):
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -448,7 +446,6 @@ def test_init_kv_cache_with_kv_sharing_target_layer_not_exist(default_vllm_confi
assert fwd_context is not None


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_target_same_as_current(default_vllm_config: None):
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -544,7 +541,6 @@ def test_init_kv_cache_without_kv_sharing(default_vllm_config: None):
assert kv_cache_config.kv_cache_groups[0].layer_names[1] == layer_1


@pytest.mark.xfail(reason="KV sharing doesn't currently work on HPU")
def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config: None):
torch.set_default_dtype(torch.bfloat16)
layer_0 = "model.layers.0.self_attn.attn"
@@ -580,10 +576,11 @@ def test_init_kv_cache_with_kv_sharing_valid(default_vllm_config: None):
assert runner.shared_kv_cache_layers[layer_1] == layer_0

available_memory = 20 * GiB_bytes
# page size for layer 0's kv_cache_spec is 32KB
# page size for layer 0's kv_cache_spec is 256KB
# with KV sharing, we can allocate (available_mem//page_size//1) blocks
# which is twice as many as without KV sharing
num_expected_blocks = 655360 # 20GB / 32KB
page_size = 128 * 8 * 64 * 2 * 2 # 128 for block_size, 2 for K+V, 2 for bfloat16
num_expected_blocks = available_memory / page_size # 20GB / 256KB
kv_cache_config = get_kv_cache_configs(vllm_config, [kv_cache_spec], [available_memory])[0]
assert kv_cache_config.num_blocks == num_expected_blocks
assert len(kv_cache_config.kv_cache_tensors) == 1
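For reference, the expected block count the updated test now derives at runtime can be checked standalone. A minimal sketch of the arithmetic, assuming the test's block_size of 128, 8 KV heads, a head size of 64, bfloat16 (2 bytes), K and V stored per block, and integer division for a whole block count:

# Standalone check of the block-count arithmetic used in the test above.
GiB_bytes = 1 << 30

# block_size * num_kv_heads * head_size * 2 (K and V) * 2 bytes (bfloat16)
page_size = 128 * 8 * 64 * 2 * 2           # 262144 bytes = 256 KiB
available_memory = 20 * GiB_bytes          # 20 GiB, as in the test

num_expected_blocks = available_memory // page_size
print(num_expected_blocks)                 # 81920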
36 changes: 19 additions & 17 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -430,8 +430,10 @@ def __init__(
use_irope: bool = False,
) -> None:
super(AttentionImpl, self).__init__()
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
if kv_sharing_target_layer_name is not None:
raise NotImplementedError("KV sharing is not currently supported on HPU.")
logger.info("[KV sharing] HPUAttentionImpl initialized with kv_sharing_target_layer_name: %s",
self.kv_sharing_target_layer_name)
if use_irope:
logger.warning_once("Using irope in HPU is not supported yet, it will fall back "
"to global attention for long context.")
@@ -572,22 +574,22 @@ def forward(
if kv_cache is not None and isinstance(kv_cache, tuple):
key_cache, value_cache, k_scales, v_scales = \
HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)

# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key,
key_cache,
slot_mapping,
scales=k_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)
value_cache = self.v_cache(value,
value_cache,
slot_mapping,
scales=v_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
key_cache = self.k_cache(key,
key_cache,
slot_mapping,
scales=k_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)
value_cache = self.v_cache(value,
value_cache,
slot_mapping,
scales=v_scales,
block_size=attn_metadata.block_size,
is_prompt=attn_metadata.is_prompt)

if attn_metadata.is_prompt:
# Prompt run.
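The new guard above is the heart of the KV-sharing path: a layer whose kv_sharing_target_layer_name is set skips writing its own K/V cache and attends over the cache populated by its target layer. A simplified, self-contained sketch of that control flow (the helper name and dict-based cache are illustrative assumptions, not the HPU implementation):

def write_or_reuse_kv_cache(layer_name, target_layer_name, key, value, kv_caches):
    """Return the cached (key, value) this layer should attend over."""
    if target_layer_name is None:
        # Normal layer: materialize K/V into its own cache entry.
        kv_caches[layer_name] = (key, value)
        return kv_caches[layer_name]
    # KV-sharing layer: skip the write and reuse the target layer's cache,
    # which was populated earlier in the same forward pass.
    return kv_caches[target_layer_name]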
1 change: 1 addition & 0 deletions vllm_gaudi/extension/ops.py
@@ -366,6 +366,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
padding_side
]

Reviewer comment (Contributor): didn't change this file in the original PR.

args += [window_size] if window_size else []
attn_weights = fsdpa_op(*args)

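The one-line addition appends window_size to the positional argument list only when a sliding window is configured, leaving the call unchanged otherwise. A tiny standalone sketch of that pattern (the callee is a stand-in, not the real fsdpa_op):

def call_with_optional_window(fn, base_args, window_size=None):
    args = list(base_args)
    # Append the extra positional argument only when a window is configured.
    args += [window_size] if window_size else []
    return fn(*args)

print(call_with_optional_window(lambda *a: a, (1, 2)))        # (1, 2)
print(call_with_optional_window(lambda *a: a, (1, 2), 128))   # (1, 2, 128)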
3 changes: 1 addition & 2 deletions vllm_gaudi/models/interfaces.py
@@ -1,7 +1,6 @@
from collections.abc import Callable
import torch
from torch import Tensor
from vllm.model_executor.models.interfaces import SupportsMultiModal


def _embed_text_input_ids(
@@ -38,4 +37,4 @@ def _embed_text_input_ids(
return embed_input_ids(input_ids)


SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
#SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
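The commented-out assignment was a class-level monkeypatch: binding a module function onto SupportsMultiModal so every model inheriting from it picks up the override. A minimal standalone sketch of that pattern (toy class and function names, not the vLLM interfaces):

class Base:
    def embed(self):
        return "upstream behavior"

def _patched_embed(self):
    return "plugin behavior"

# Assigning to the class attribute rebinds the method for all instances and
# subclasses; commenting the assignment out restores the upstream behavior.
Base.embed = _patched_embed
print(Base().embed())  # plugin behavior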
89 changes: 85 additions & 4 deletions vllm_gaudi/models/qwen3_vl.py
@@ -1,6 +1,9 @@
from torch import nn
from vllm.model_executor.layers.activation import get_act_fn
import torch
from .utils import _merge_multimodal_embeddings
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.interfaces import _require_is_multimodal
from vllm.model_executor.models.qwen3_vl import (
Qwen3VLForConditionalGeneration,
Qwen3_VisionTransformer,
@@ -65,9 +68,9 @@ def __init__(
)

depth = vision_config.depth
norm_layer = lambda d: nn.LayerNorm(d, eps=norm_eps)
norm_layer = lambda d: torch.nn.LayerNorm(d, eps=norm_eps)

self.blocks = nn.ModuleList([
self.blocks = torch.nn.ModuleList([
HPUQwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
@@ -97,3 +100,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)

def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)

(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)

multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, dim=0)
multimodal_embeddings_multiscale = torch.split(multimodal_embeddings_multiscale, visual_lens, dim=0)

deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))

deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], self.deepstack_num_level,
self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)

return deepstack_input_embeds, multimodal_embeddings

def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

is_multimodal = _require_is_multimodal(is_multimodal)

if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None

inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

if deepstack_input_embeds is not None:
self._set_deepstack_input_embeds(deepstack_input_embeds)

return inputs_embeds
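_compute_deepstack_embeds splits each concatenated visual feature into a main slice of width visual_dim and a multiscale slice of width multiscale_dim, then reshapes the multiscale part into one slice per deepstack level. A shape-only sketch with toy sizes (the dimensions are illustrative assumptions, not the model's real ones):

import torch

visual_dim, deepstack_num_level = 16, 3
multiscale_dim = visual_dim * deepstack_num_level        # 48
num_tokens = 5

mm_embeds = torch.randn(num_tokens, visual_dim + multiscale_dim)
main, multiscale = torch.split(mm_embeds, [visual_dim, multiscale_dim], dim=-1)
print(main.shape, multiscale.shape)      # torch.Size([5, 16]) torch.Size([5, 48])

# One (num_tokens, visual_dim) slice per deepstack level, levels first:
levels = multiscale.view(num_tokens, deepstack_num_level, visual_dim).permute(1, 0, 2)
print(levels.shape)                      # torch.Size([3, 5, 16])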
11 changes: 4 additions & 7 deletions vllm_gaudi/models/utils.py
@@ -29,18 +29,15 @@ def _merge_multimodal_embeddings(
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype

if is_multimodal.dtype == torch.int64:
return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# htcore.mark_step()

# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
# inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
# mm_embeds_flat.to(dtype=input_dtype))

multimodal_positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(dtype=input_dtype)

inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
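The rewritten _merge_multimodal_embeddings keeps two merge paths: index_copy_ when is_multimodal already holds int64 positions, and masked_scatter_ when it is a boolean mask. A standalone sketch of both on toy tensors (sizes are illustrative assumptions):

import torch

mm_embeds_flat = torch.ones(2, 4)

# Path 1: is_multimodal given as int64 positions -> index_copy_ along dim 0.
inputs_embeds = torch.zeros(6, 4)
positions = torch.tensor([1, 4], dtype=torch.int64)
inputs_embeds.index_copy_(0, positions, mm_embeds_flat)

# Path 2: is_multimodal given as a boolean mask -> masked_scatter_, which
# avoids a D2H sync but, as noted above, does not raise if the mask selects
# fewer tokens than mm_embeds_flat provides.
inputs_embeds = torch.zeros(6, 4)
is_multimodal = torch.zeros(6, dtype=torch.bool)
is_multimodal[[1, 4]] = True
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)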