4 changes: 2 additions & 2 deletions vllm_gaudi/extension/bucketing/vision.py
@@ -21,8 +21,8 @@
},
'qwen3_vl': {
'is_batch_based': False,
#coverage for lmarena-ai/VisionArena-Chat
'buckets': [512, 1024, 2048, 3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076]
'buckets':
[256, 512, 1024, 1350, 1602, 2048, 3072, 4096, 5120, 6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076]
}
}

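Note on the bucket list above: assuming the vision bucketing path pads each image's visual token count up to the smallest configured bucket that fits (the usual Gaudi bucketing behaviour), the extra small and mid-range buckets (256, 1350, 1602) cut padding waste for smaller images. A minimal sketch of that rounding; find_vision_bucket is an illustrative helper, not the extension's actual API:

import bisect

# Illustrative only; the extension's real bucketing code may differ.
QWEN3_VL_BUCKETS = [
    256, 512, 1024, 1350, 1602, 2048, 3072, 4096, 5120,
    6144, 7168, 8192, 9216, 10240, 11264, 12288, 131076
]

def find_vision_bucket(num_tokens: int, buckets: list[int] = QWEN3_VL_BUCKETS) -> int:
    # Pick the smallest bucket that can hold num_tokens visual tokens.
    idx = bisect.bisect_left(buckets, num_tokens)
    if idx == len(buckets):
        raise ValueError(f"{num_tokens} tokens exceed the largest bucket {buckets[-1]}")
    return buckets[idx]

# A 1300-token image now pads to 1350 instead of 2048 with the previous bucket list.
assert find_vision_bucket(1300) == 1350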
3 changes: 1 addition & 2 deletions vllm_gaudi/models/interfaces.py
@@ -1,7 +1,6 @@
from collections.abc import Callable
import torch
from torch import Tensor
from vllm.model_executor.models.interfaces import SupportsMultiModal


def _embed_text_input_ids(
@@ -38,4 +37,4 @@ def _embed_text_input_ids(
return embed_input_ids(input_ids)


SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
#SupportsMultiModal._embed_text_input_ids = _embed_text_input_ids
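For context, the line above disables the monkeypatch that attached the local _embed_text_input_ids helper to vLLM's SupportsMultiModal interface. A minimal sketch of the pattern being turned off, with placeholder names (Base and helper are illustrative, not vLLM classes):

# Assigning a module-level function onto a base class exposes it as a method on
# every subclass; commenting the assignment out means models must call the
# helper explicitly instead of inheriting it.
class Base:
    pass

def helper(self, x):
    return x + 1

Base.helper = helper  # the monkeypatch pattern

class Model(Base):
    pass

assert Model().helper(1) == 2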
89 changes: 85 additions & 4 deletions vllm_gaudi/models/qwen3_vl.py
@@ -1,6 +1,9 @@
from torch import nn
from vllm.model_executor.layers.activation import get_act_fn
import torch
from .utils import _merge_multimodal_embeddings
from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.models.interfaces import MultiModalEmbeddings
from vllm.model_executor.models.interfaces import _require_is_multimodal
from vllm.model_executor.models.qwen3_vl import (
Qwen3VLForConditionalGeneration,
Qwen3_VisionTransformer,
@@ -65,9 +68,9 @@ def __init__(
)

depth = vision_config.depth
norm_layer = lambda d: nn.LayerNorm(d, eps=norm_eps)
norm_layer = lambda d: torch.nn.LayerNorm(d, eps=norm_eps)

self.blocks = nn.ModuleList([
self.blocks = torch.nn.ModuleList([
HPUQwen3_VisionBlock(
dim=self.hidden_size,
num_heads=self.num_heads,
@@ -97,3 +100,81 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
multimodal_config=multimodal_config,
prefix=maybe_prefix(prefix, "visual"),
)

def _compute_deepstack_embeds(
self,
inputs_embeds: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings,
is_multimodal: torch.Tensor,
) -> tuple[torch.Tensor, MultiModalEmbeddings]:
visual_lens = [len(x) for x in multimodal_embeddings]
multimodal_embeddings_cat = torch.cat(multimodal_embeddings, dim=0)

(
multimodal_embeddings_main,
multimodal_embeddings_multiscale,
) = torch.split(
multimodal_embeddings_cat,
[self.visual_dim, self.multiscale_dim],
dim=-1,
)

multimodal_embeddings = torch.split(multimodal_embeddings_main, visual_lens, dim=0)
multimodal_embeddings_multiscale = torch.split(multimodal_embeddings_multiscale, visual_lens, dim=0)

deepstack_input_embeds = inputs_embeds.new_zeros(inputs_embeds.size(0),
self.deepstack_num_level * inputs_embeds.size(1))

deepstack_input_embeds = _merge_multimodal_embeddings(
inputs_embeds=deepstack_input_embeds,
multimodal_embeddings=multimodal_embeddings_multiscale,
is_multimodal=is_multimodal,
)
deepstack_input_embeds = deepstack_input_embeds.view(inputs_embeds.shape[0], self.deepstack_num_level,
self.visual_dim)
deepstack_input_embeds = deepstack_input_embeds.permute(1, 0, 2)

return deepstack_input_embeds, multimodal_embeddings

def embed_input_ids(
self,
input_ids: torch.Tensor,
multimodal_embeddings: MultiModalEmbeddings | None = None,
*,
is_multimodal: torch.Tensor | None = None,
handle_oov_mm_token: bool = False,
) -> torch.Tensor:
inputs_embeds = self._embed_text_input_ids(
input_ids,
self.language_model.embed_input_ids,
is_multimodal=is_multimodal,
handle_oov_mm_token=handle_oov_mm_token,
)

if multimodal_embeddings is None or len(multimodal_embeddings) == 0:
return inputs_embeds

is_multimodal = _require_is_multimodal(is_multimodal)

if self.use_deepstack:
(
deepstack_input_embeds,
multimodal_embeddings,
) = self._compute_deepstack_embeds(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)
else:
deepstack_input_embeds = None

inputs_embeds = _merge_multimodal_embeddings(
inputs_embeds=inputs_embeds,
multimodal_embeddings=multimodal_embeddings,
is_multimodal=is_multimodal,
)

if deepstack_input_embeds is not None:
self._set_deepstack_input_embeds(deepstack_input_embeds)

return inputs_embeds
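A small, self-contained sketch of the tensor bookkeeping in _compute_deepstack_embeds above: each visual token packs the main features and the deepstack (multiscale) features along the last dimension, and a single torch.split separates them before they are re-split per image. The dimensions below are made up for illustration, assuming multiscale_dim = visual_dim * deepstack_num_level as the split sizes in the method suggest:

import torch

# Made-up sizes; mirrors the splits in _compute_deepstack_embeds.
visual_dim, deepstack_num_level = 4, 2
multiscale_dim = visual_dim * deepstack_num_level

# Two "images" with 3 and 2 visual tokens; each token carries main + multiscale features.
mm_embeds = [torch.randn(3, visual_dim + multiscale_dim),
             torch.randn(2, visual_dim + multiscale_dim)]
visual_lens = [len(x) for x in mm_embeds]

flat = torch.cat(mm_embeds, dim=0)
main, multiscale = torch.split(flat, [visual_dim, multiscale_dim], dim=-1)

# Re-split per image, as done before merging into the text embeddings.
main_per_image = torch.split(main, visual_lens, dim=0)
assert [tuple(t.shape) for t in main_per_image] == [(3, visual_dim), (2, visual_dim)]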
17 changes: 10 additions & 7 deletions vllm_gaudi/models/utils.py
@@ -29,18 +29,21 @@ def _merge_multimodal_embeddings(
mm_embeds_flat = _flatten_embeddings(multimodal_embeddings)
input_dtype = inputs_embeds.dtype

if is_multimodal.dtype == torch.int64:
if inputs_embeds.ndim == 3 and mm_embeds_flat.ndim == 2:
original_shape = inputs_embeds.shape
inputs_embeds = inputs_embeds.view(-1, inputs_embeds.shape[-1])
inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
return inputs_embeds.view(original_shape)
else:
return inputs_embeds.index_copy_(0, is_multimodal, mm_embeds_flat)
try:
# For debugging
# inputs_embeds[is_multimodal] = mm_embeds_flat.to(dtype=input_dtype)
# htcore.mark_step()

# NOTE: This can avoid D2H sync (#22105), but fails to
# raise an error if is_multimodal.sum() < len(mm_embeds_flat)
# inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1),
# mm_embeds_flat.to(dtype=input_dtype))

multimodal_positions = torch.where(is_multimodal)[0][:mm_embeds_flat.shape[0]]
inputs_embeds[0, multimodal_positions] = mm_embeds_flat.to(dtype=input_dtype)

inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat.to(dtype=input_dtype))
except RuntimeError as e:
num_actual_tokens = len(mm_embeds_flat)
num_expected_tokens = is_multimodal.sum().item()
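For reference, a minimal sketch of the masked_scatter_ merge that appears in the utils.py hunk above: it writes the flattened image embeddings, in order, into the rows selected by the boolean is_multimodal mask, which is equivalent to an explicit index assignment (the other path visible in the hunk). Shapes and values are made up; the real function also handles an int64 index form of is_multimodal and converts dtypes:

import torch

# Made-up shapes; sketch of the boolean-mask path in _merge_multimodal_embeddings.
hidden = 4
inputs_embeds = torch.zeros(6, hidden)
is_multimodal = torch.tensor([False, True, True, False, True, False])
mm_embeds_flat = torch.arange(3 * hidden, dtype=torch.float32).view(3, hidden)

# In-place scatter of the image embeddings into the masked positions.
inputs_embeds.masked_scatter_(is_multimodal.unsqueeze(-1), mm_embeds_flat)

# Equivalent explicit-index assignment: rows 1, 2 and 4 receive the three image embeddings.
reference = torch.zeros(6, hidden)
reference[torch.where(is_multimodal)[0]] = mm_embeds_flat
assert torch.equal(inputs_embeds, reference)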