Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from neural_compressor.torch.quantization import finalize_calibration
else:
finalize_calibration = None
import types

import habana_frameworks.torch as htorch
import habana_frameworks.torch.internal.bridge_config as bc
Expand Down Expand Up @@ -354,6 +355,34 @@ def is_mm_optimized(model):
'Gemma3ForConditionalGeneration' in str(type(model))


def patch_llama4_get_attn_scale(model):
    """Monkey-patch Llama4 attention layers so ``_get_attn_scale`` works on HPU.

    On HPU the ``positions`` tensor presumably arrives batched (2D) rather than
    as the 1D vector the original ``_get_attn_scale`` expects — TODO confirm
    against the HF Llama4 implementation. The patch flattens ``positions``
    before delegating to the original bound method.

    Args:
        model: The loaded model instance. Non-Llama4 models (detected via
            ``config.model_type`` or the class name) are left untouched.

    The patch is idempotent: layers already patched are skipped, so calling
    this function twice does not double-wrap the method.
    """
    config = getattr(model, "config", None)
    is_llama4 = (getattr(config, "model_type", None) == "llama4"
                 or "llama4" in type(model).__name__.lower())
    if not is_llama4:
        return

    # Walk the attribute chain defensively: a llama4-flavored wrapper that
    # does not expose language_model.model.layers should be skipped, not
    # crash with AttributeError.
    language_model = getattr(model, "language_model", None)
    inner_model = getattr(language_model, "model", None)
    layers = getattr(inner_model, "layers", None)
    if layers is None:
        return

    for layer in layers:
        attn = getattr(layer, "self_attn", None)
        if attn is None or "Llama4Attention" not in type(attn).__name__:
            continue
        if getattr(attn, "_hpu_attn_scale_patched", False):
            continue  # already patched — keep the wrapper single-layered

        # Bind the original method as a default argument so each closure
        # captures its own layer's method (avoids late-binding pitfall).
        orig = attn._get_attn_scale

        def _get_attn_scale_for_hpu(self, positions, _orig=orig):
            # Flatten to the 1D positions layout the original expects.
            positions = positions.flatten()
            return _orig(positions)

        attn._get_attn_scale = types.MethodType(_get_attn_scale_for_hpu, attn)
        attn._hpu_attn_scale_patched = True


def apply_model_specific_patches(model):
    """Apply every registered model-specific monkey patch to *model* in place."""
    # Registry of per-model patch hooks; each hook decides internally
    # whether it applies to the given model. Extend as new models need fixups.
    for patch_fn in (patch_llama4_get_attn_scale,):
        patch_fn(model)


class HpuModelAdapter(torch.nn.Module, KVConnectorModelRunnerMixin):

def __init__(self, model, vllm_config):
Expand Down Expand Up @@ -3806,6 +3835,7 @@ def load_model(self) -> None:
self.model = self.model.to("hpu")
htcore.mark_step()

apply_model_specific_patches(self.model)
hidden_layer_markstep_interval = int(os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1'))
model_config = getattr(self.model, "config", None)
modify_model_layers(self.model,
Expand Down
Loading