42 commits
9d8e272
Pick model runner change related to PR30475.
libinta Dec 23, 2025
49d7633
add qwen3_vl.py functions
libinta Dec 23, 2025
bdff63f
Merge branch 'main' into libinta/remove_gather_scatter
libinta Dec 23, 2025
c6526de
precommit fix
libinta Dec 24, 2025
7c6329e
precommit fix and fix use_window_sdpa
libinta Dec 25, 2025
bff3cf5
Update qwen3_vl.py
iboiko-habana Dec 29, 2025
625d9c2
Update qwen3_vl.py
iboiko-habana Dec 29, 2025
568b4eb
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Dec 29, 2025
495643a
Merge branch 'main' into libinta/remove_gather_scatter
libinta Dec 30, 2025
327a9cc
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Dec 30, 2025
bb3ac24
Update qwen3_vl.py
iboiko-habana Dec 30, 2025
fe67f98
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 5, 2026
8a9efd1
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 5, 2026
a394b9a
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 8, 2026
6502061
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Jan 15, 2026
48a96db
fix test failure
libinta Jan 15, 2026
0171641
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Jan 15, 2026
db10548
fix precommit issue
libinta Jan 15, 2026
40d7635
Update interfaces.py for precommit fix
libinta Jan 15, 2026
e23e6d2
Update hpu_model_runner.py to match with upstream for MultiModalBudget
libinta Jan 16, 2026
46facad
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 19, 2026
4089adf
Update qwen3_vl.py for precommit fix
libinta Jan 19, 2026
79d90a4
Update qwen3_vl.py for precommit fix
libinta Jan 19, 2026
e370a49
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Jan 21, 2026
0df1f20
Merge branch 'main' into libinta/remove_gather_scatter
iboiko-habana Jan 21, 2026
5fdf237
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 21, 2026
07f40c9
add back warmup with ratio and video warmup
libinta Jan 21, 2026
9db6b78
Update ops.py with removing unnecessary change
libinta Jan 21, 2026
9be0056
Update hpu_model_runner.py for precommit fix
libinta Jan 22, 2026
3ff7e80
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 22, 2026
b4f2e6c
Update hpu_model_runner.py for precommit fix
libinta Jan 22, 2026
02c239b
Update hpu_model_runner.py for precommit fix
libinta Jan 22, 2026
3dd1f5c
Update hpu_model_runner.py for precommit fix
libinta Jan 22, 2026
7757e80
Update hpu_model_runner.py for precommit fix
libinta Jan 22, 2026
9097164
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 22, 2026
913176a
fix qwen2.5vl unified attn test failure
libinta Jan 23, 2026
091c5fe
precommit fix
libinta Jan 23, 2026
f0613fd
precommit fix
libinta Jan 23, 2026
ec827b8
add more mm bucket
libinta Jan 23, 2026
4cf5cb1
precommit fix
libinta Jan 23, 2026
150cf7a
Merge branch 'main' into libinta/remove_gather_scatter
libinta Jan 23, 2026
f46b48d
Update qwen2.5-vl-7b.yaml to revert change
libinta Jan 23, 2026
1 change: 1 addition & 0 deletions vllm_gaudi/extension/ops.py
@@ -366,6 +366,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
padding_side
]

args += [window_size] if window_size else []
attn_weights = fsdpa_op(*args)

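The ops.py hunk above appends the sliding-window size to the fused-SDPA argument list only when one is configured. A minimal standalone sketch of that optional-argument pattern, with toy function names rather than the real HPU kernel:

def fake_fsdpa(query, key, value, window_size=None):
    # Stand-in for the fused SDPA op; only echoes what it received.
    return {"query": query, "window_size": window_size}

def prompt_attention(query, key, value, window_size=None):
    args = [query, key, value]
    # Append the extra argument only when a sliding window is set, so the
    # non-windowed call keeps its original signature.
    args += [window_size] if window_size else []
    return fake_fsdpa(*args)

print(prompt_attention("q", "k", "v"))                    # window_size stays None
print(prompt_attention("q", "k", "v", window_size=1024))  # window_size forwarded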
4 changes: 4 additions & 0 deletions vllm_gaudi/models/__init__.py
@@ -11,3 +11,7 @@ def register_model():
from vllm_gaudi.models.qwen2_5_vl import HpuQwen2_5_VLForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen2_5_VLForConditionalGeneration",
"vllm_gaudi.models.qwen2_5_vl:HpuQwen2_5_VLForConditionalGeneration")

from vllm_gaudi.models.qwen3_vl import HpuQwen3_VLForConditionalGeneration # noqa: F401
ModelRegistry.register_model("Qwen3_VLForConditionalGeneration",
"vllm_gaudi.models.qwen3_vl:HpuQwen3_VLForConditionalGeneration")
3 changes: 1 addition & 2 deletions vllm_gaudi/models/interfaces.py
@@ -1,7 +1,6 @@
from collections.abc import Callable
import torch
from torch import Tensor
from vllm.model_executor.models.interfaces import SupportsMultiModal


def _embed_text_input_ids(
@@ -38,4 +37,4 @@ def _embed_text_input_ids(
return embed_input_ids(input_ids)


SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
#SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
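For context on the line that is now commented out: it attached the local helper to SupportsMultiModal at import time. A toy sketch of that class-level patching pattern, using a stand-in class rather than vLLM's interface:

class SupportsMultiModalStub:
    # Stand-in for the real interface class.
    pass

def _embed_text_input_ids_stub(self, input_ids):
    # Plugin-provided override; becomes a bound method on every instance.
    return [2 * token for token in input_ids]

# Assigning the function onto the class is the patch the diff disables.
SupportsMultiModalStub._embed_text_input_ids = _embed_text_input_ids_stub

assert SupportsMultiModalStub()._embed_text_input_ids([1, 2]) == [2, 4]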
86 changes: 86 additions & 0 deletions vllm_gaudi/models/qwen3_vl.py
@@ -0,0 +1,86 @@
import torch
from .utils import _merge_multimodal_embeddings
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.qwen3_vl import Qwen3VLForConditionalGeneration
from vllm.model_executor.models.interfaces import _require_is_multimodal


class HpuQwen3_VLForConditionalGeneration(Qwen3VLForConditionalGeneration):

def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)

(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)

multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, dim=0)
multimodal_embeddings_multiscale = torch.split(multimodal_embeddings_multiscale, visual_lens, dim=0)

deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))

deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], self.deepstack_num_level,
self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)

return deepstack_input_embeds, multimodal_embeddings

def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

is_multimodal = _require_is_multimodal(is_multimodal)

if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None

inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

if deepstack_input_embeds is not None:
self._set_deepstack_input_embeds(deepstack_input_embeds)

return inputs_embeds
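_compute_deepstack_embeds above splits the concatenated visual features into main and multiscale channels, scatters the multiscale part into a buffer shaped like the text embeddings, and reshapes it into one tensor per deepstack level. A shape-only sketch with made-up sizes (seq_len=6, visual_dim=4, three levels), not the model code:

import torch

seq_len, visual_dim, num_levels = 6, 4, 3
multiscale_dim = visual_dim * num_levels

# Concatenated visual features carry the main and the per-level channels.
mm_cat = torch.randn(5, visual_dim + multiscale_dim)  # 5 visual tokens
main, multiscale = torch.split(mm_cat, [visual_dim, multiscale_dim], dim=-1)

# Buffer shaped like the text embeddings, widened to hold every level.
deepstack = torch.zeros(seq_len, num_levels * visual_dim)
deepstack[:5] = multiscale  # place features at the multimodal positions

# (seq_len, num_levels * dim) -> (num_levels, seq_len, dim): one slice per level.
deepstack = deepstack.view(seq_len, num_levels, visual_dim).permute(1, 0, 2)
print(main.shape, deepstack.shape)  # torch.Size([5, 4]) torch.Size([3, 6, 4])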
11 changes: 4 additions & 7 deletions vllm_gaudi/models/utils.py
@@ -29,18 +29,15 @@ def _merge_multimodal_embeddings(
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype

if is_multimodal.dtype == torch.int64:
return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# htcore.mark_step()

# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
# inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
# mm_embeds_flat.to(dtype=input_dtype))

multimodal_positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(dtype=input_dtype)

inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
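The utils.py hunk adds a fast path: when is_multimodal arrives as an int64 index tensor the rows are written with index_copy_, while a bool mask still goes through masked_scatter_. A toy sketch showing that the two paths agree when the mask and the flat embeddings line up (small tensors, not the vLLM helper):

import torch

inputs_embeds = torch.zeros(6, 4)
mm_flat = torch.arange(12, dtype=torch.float32).view(3, 4)  # 3 mm embeddings

# Path 1: int64 positions -> index_copy_ writes whole rows at those indices.
positions = torch.tensor([1, 2, 4], dtype=torch.int64)
merged_by_index = inputs_embeds.clone().index_copy_(0, positions, mm_flat)

# Path 2: bool mask -> masked_scatter_ fills the masked rows in order.
mask = torch.tensor([False, True, True, False, True, False])
merged_by_mask = inputs_embeds.clone()
merged_by_mask.masked_scatter_(mask.unsqueeze(-1), mm_flat)

assert torch.equal(merged_by_index, merged_by_mask)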
46 changes: 27 additions & 19 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -75,7 +75,7 @@
from vllm.model_executor.models.interfaces_base import (VllmModelForPooling, is_pooling_model, is_text_generation_model)
from vllm.tasks import GenerationTask, PoolingTask, SupportedTask
from vllm.transformers_utils.config import is_interleaved
from vllm.v1.worker.utils import (gather_mm_placeholders, sanity_check_mm_encoder_outputs, scatter_mm_placeholders)
from vllm.v1.worker.utils import sanity_check_mm_encoder_outputs
from vllm.v1.sample.rejection_sampler import RejectionSampler
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -691,7 +691,7 @@ def __init__(
self.head_size = self.model_config.get_head_size()
self.hidden_size = self.model_config.get_hidden_size()
self.is_pooling_model = (model_config.runner_type == 'pooling')
logger.debug("model config: ", self.model_config)
logger.debug("model config: %s", self.model_config)

self.attn_backend = get_attn_backend(
self.head_size,
@@ -1269,11 +1269,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput", req_ids: list
if req_id not in self.encoder_cache:
self.encoder_cache[req_id] = {}

self.encoder_cache[mm_hash] = scatter_mm_placeholders(
output,
is_embed=pos_info.is_embed.to(
device=output.device) if pos_info.is_embed is not None else pos_info.is_embed,
)
self.encoder_cache[mm_hash] = output

# modified from: vllm/v1/worker/gpu_model_runner.py
def _gather_mm_embeddings(
@@ -1291,6 +1287,7 @@ def _gather_mm_embeddings(

req_start_idx = 0
for req_id in req_ids:
mm_embeds_req: list[torch.Tensor] = []
num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_state = self.requests[req_id]
num_computed_tokens = \
@@ -1315,6 +1312,11 @@
start_idx = max(num_computed_tokens - start_pos, 0)
end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens, num_encoder_tokens)
assert start_idx < end_idx
curr_embeds_start, curr_embeds_end = (pos_info.get_embeds_indices_in_range(start_idx, end_idx))
# If there are no embeddings in the current range, we skip
# gathering the embeddings.
if curr_embeds_start == curr_embeds_end:
continue
mm_hash = mm_feature.identifier
encoder_output = self.encoder_cache.get(mm_hash, None)
assert encoder_output is not None,\
@@ -1323,21 +1325,26 @@

if (is_embed := pos_info.is_embed) is not None:
is_embed = is_embed[start_idx:end_idx]
mm_embeds_item = encoder_output[curr_embeds_start:curr_embeds_end]
else:
mm_embeds_item = encoder_output[start_idx:end_idx]

mm_embeds_item = gather_mm_placeholders(
encoder_output[start_idx:end_idx],
is_embed=is_embed,
)
req_start_pos = req_start_idx + start_pos - num_computed_tokens
is_mm_embed[req_start_pos+start_idx:req_start_pos + end_idx] \
= True

# Only whole mm items are processed
mm_embeds.append(mm_embeds_item)
is_mm_embed[req_start_pos + start_idx:req_start_pos +
end_idx] = (True if is_embed is None else is_embed)
mm_embeds_req.append(mm_embeds_item)
mm_embeds.extend(mm_embeds_req)
req_start_idx += num_scheduled_tokens

is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)

# Convert bool tensor to index tensor for merge embedding statically if optimized mm
if self.uses_mrope:
is_mm_embed_index = torch.nonzero(is_mm_embed[:total_num_scheduled_tokens], as_tuple=True)[0]
# Bounds validation on CPU
if len(is_mm_embed_index) > 0 and is_mm_embed_index.max() >= total_num_scheduled_tokens:
raise ValueError(f"Index {is_mm_embed_index.max()} exceeds tensor size {total_num_scheduled_tokens}")
is_mm_embed = is_mm_embed_index.to(self.device)
else:
is_mm_embed = self.is_mm_embed.copy_to_gpu(total_num_scheduled_tokens)
return mm_embeds, is_mm_embed

def _get_model_mm_inputs(
@@ -3755,6 +3762,7 @@ def load_model(self) -> None:
self._maybe_compile(self.model)
self.model_memory_usage = m.consumed_device_memory
logger.info("Compilation took %.4f GB", self.model_memory_usage / float(2**30))
self.is_mm_optimized = is_mm_optimized(self.model)

def _maybe_compile(self, *args, **kwargs):
"""Entrypoint for a torch.compilation of the model"""
@@ -5436,7 +5444,7 @@ def __init__(
if self.interleaved_sliding_window:
self.use_window_sdpa = with_default(get_config().PT_HPU_SDPA_QKV_SLICE_MODE_FWD, False)
#os.getenv("PT_HPU_SDPA_QKV_SLICE_MODE_FWD", "false").strip().lower() in ("1", "true")
self.slice_size = with_default(get_config().PT_HPU_SDPA_BC_FACTOR, False)
self.slice_size = int(with_default(get_config().PT_HPU_SDPA_BC_FACTOR, "1024"))
# int(os.getenv("PT_HPU_SDPA_BC_FACTOR", "1024"))
self.slice_thld = int(os.environ.get('VLLM_FUSEDSDPA_SLIDE_THLD', '8192'))

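The _gather_mm_embeddings hunk converts the bool placeholder mask into an index tensor on the host for mrope models (feeding the index_copy_ path sketched above) and bounds-checks it before the device copy. A toy sketch of that conversion, with made-up values:

import torch

total_num_scheduled_tokens = 8
is_mm_embed = torch.tensor([0, 1, 1, 0, 1, 0, 0, 0], dtype=torch.bool)

# torch.nonzero(..., as_tuple=True)[0] gives the int64 positions of True entries.
mm_index = torch.nonzero(is_mm_embed[:total_num_scheduled_tokens], as_tuple=True)[0]

# Cheap host-side sanity check before the indices are moved to the device.
if len(mm_index) > 0 and mm_index.max() >= total_num_scheduled_tokens:
    raise ValueError(f"Index {mm_index.max()} exceeds tensor size "
                     f"{total_num_scheduled_tokens}")

print(mm_index)  # tensor([1, 2, 4])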