From 7d63e51fb366b3aefe5d74cfdf4d65192f5fed4d Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Thu, 1 Jan 2026 13:33:21 +0200 Subject: [PATCH 01/15] GPT OSS Code Signed-off-by: Himangshu Lahkar --- vllm_gaudi/attention/backends/hpu_attn.py | 19 ++++ vllm_gaudi/extension/ops.py | 128 +++++++++++++++++----- vllm_gaudi/extension/utils.py | 6 +- vllm_gaudi/ops/hpu_fused_moe.py | 27 ++++- vllm_gaudi/v1/worker/hpu_model_runner.py | 2 +- 5 files changed, 151 insertions(+), 31 deletions(-) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index f619be467..190150f46 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -198,6 +198,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + sinks: Optional[torch.Tensor] = None, **kwargs, ) -> None: torch.nn.Module.__init__(self) @@ -252,6 +253,11 @@ def __init__( "encoder/decoder cross-attention " "are not implemented for " "TritonMLAImpl") + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") def forward( self, @@ -428,6 +434,7 @@ def __init__( attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, use_irope: bool = False, + sinks: Optional[torch.Tensor] = None, ) -> None: super(AttentionImpl, self).__init__() if kv_sharing_target_layer_name is not None: @@ -492,6 +499,11 @@ def __init__( raise NotImplementedError("Encoder self-attention " "is not implemented for " "HPUAttentionImpl") + self.sinks = sinks + if sinks is not None: + assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of " + f"heads in the layer. Sinks shape: {sinks.shape}, " + f"num_heads: {num_heads}.") def _maybe_init_alibi_biases( self, @@ -576,6 +588,12 @@ def forward( # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are # not cached. This happens during the initial memory profiling run. 
+ if key.dtype != key_cache.dtype: + key = key.to(key_cache.dtype) + if value.dtype != value_cache.dtype: + value = value.to(value_cache.dtype) + if query.dtype != key.dtype: + query = query.to(key.dtype) key_cache = self.k_cache(key, key_cache, slot_mapping, k_scales) value_cache = self.v_cache(value, value_cache, slot_mapping, v_scales) @@ -691,6 +709,7 @@ def common_attention_args(self, 'key_cache': key_cache, 'value_cache': value_cache, 'block_size': block_size, + "sinks": self.sinks, 'k_scales': k_scales, 'v_scales': v_scales, } diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index a40388d3e..5817d212f 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -65,7 +65,7 @@ def matmul_shape(lhs, rhs): return result -def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op, +def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op): # When fp32_softmax is enabled attn is left in fp32 after Q@K # We can return to native dtype after we renormalize and calculate the adjustments @@ -80,11 +80,27 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz if block_bias is not None: attn.add_(block_bias) block_max = attn.amax(dim=-1, keepdim=True) + if sink is not None: + block_max = torch.maximum(block_max, sink) attn = attn.sub(block_max) attn = attn.exp() if attn.dtype == torch.float32: attn = attn.to(value.dtype) - block_sums = attn.sum(dim=-1, keepdim=True) + attn_shape = attn.shape + block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True) + attn_shape = list(attn_shape) + attn_shape[-1] = 1 + block_sums = block_sums.view(attn_shape) + if sink is not None: + attn_sink = sink.sub(block_max) + attn_sink = attn_sink.exp() + if attn_sink.dtype == torch.float32: + attn_sink = attn_sink.to(value.dtype) + #TODO: Removing this .sum and using attn_sink directly + #results in wrong output which does not make sense. + #Looks like a Synapse issue, need to investigate further. 
+ block_sums_sink = attn_sink.sum(dim=-1, keepdim=True) + block_sums = block_sums + block_sums_sink attn = matmul_av_op(attn, value) if get_config().fused_block_softmax_adjustment: out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3) @@ -177,7 +193,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, block_groups, block_size, scale, matmul_qk_op, position_bias, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, keys_fetch_func, - values_fetch_func, k_scales, v_scales, **ignored_args): + values_fetch_func, sinks, k_scales, v_scales, **ignored_args): batch_size, _, hidden_size = query.shape _, kv_heads, head_size = key_cache.shape q_heads = hidden_size // head_size @@ -195,6 +211,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)), **get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2) block_bias = block_bias.view(key.size(0), 1, 1, -1) + sink = None + if sinks is not None: + sinks = sinks.reshape(sinks.shape[0], 1) + sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1]) + sink = sink.expand(query.shape[0], -1, query.shape[-2], -1) + if kv_heads != q_heads: + sink = sink.unflatten(1, (kv_heads, -1)) if kv_heads != q_heads: query = query.unflatten(1, (kv_heads, -1)) key = key.unflatten(1, (kv_heads, 1)) @@ -232,6 +255,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias block_bias, block_groups, block_mapping, + sink, batch_size=batch_size, matmul_av_op=matmul_av_op, batch2block_matmul_op=batch2block_matmul_op, @@ -290,6 +314,7 @@ def _naive_prompt_attention(query: torch.Tensor, matmul_qk_op=torch.matmul, softmax_op=torch.softmax, matmul_av_op=torch.matmul, + sinks: Optional[torch.Tensor] = None, **ignored_args) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -321,10 +346,19 @@ def _naive_prompt_attention(query: torch.Tensor, if attn_weights.dtype != attn_bias.dtype: attn_bias = attn_bias.to(dtype=attn_weights.dtype) attn_weights.add_(attn_bias) + if sinks is not None: + sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + if query_heads != kv_heads: + sink = sink.unflatten(1, (kv_heads, -1)) + combined_logits = torch.cat([attn_weights, sink], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + attn_weights = combined_logits if get_config().fp32_softmax: attn_weights = torch.softmax(attn_weights, dim=-1) else: attn_weights = softmax_op(attn_weights, dim=-1) + if sinks is not None: + attn_weights = attn_weights[..., :-1] attn_weights = attn_weights.to(query.dtype) attn_weights = matmul_av_op(attn_weights, value) @@ -343,6 +377,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor, attn_bias: Optional[torch.Tensor] = None, valid_seq_lengths: Optional[torch.Tensor] = None, window_size: Optional[int] = None, + sinks: Optional[torch.Tensor] = None, **ignored_args) -> torch.Tensor: query = query.transpose(1, 2) key = key.transpose(1, 2) @@ -364,10 +399,16 @@ def _fsdpa_prompt_attention(query: torch.Tensor, query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, padding_side ] - args += [window_size] if window_size else [] + args += [window_size] if window_size else [None] + # use sinks in fsdpa + if sinks is not None: + args += [sinks] attn_weights = 
fsdpa_op(*args) attn_weights = attn_weights.transpose(1, 2) + if sinks is not None: + # TODO - check if we can remove this + htcore.mark_step() return attn_weights @@ -487,6 +528,9 @@ def __init__(self): def set_weight(self, w): self.weight = w + def set_bias(self, b): + self.bias = b + def forward(self, state, expert_id, w): raise NotImplementedError() @@ -498,12 +542,14 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, + bias = None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__() self.experts_min = experts_min self.experts_max = experts_max self.global_num_experts = global_num_experts self.num_experts = num_total_experts + self.bias = bias if MAX_EXPERTS_PER_SLICE > 0: max_expert_per_slice = MAX_EXPERTS_PER_SLICE @@ -555,8 +601,9 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, + bias=None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): - super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn) + super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) @@ -569,31 +616,62 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_ w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] if self.moe_n_slice == 1: - return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list, - w3=w2_list, - permuted_weights=permuted_weights, - activation=activation, - experts_min=self.experts_min, - experts_max=self.experts_max, - **kwargs) + if self.bias is not None: + w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] + w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] + return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + w12_bias=w1_bias_list, + w3_bias=w2_bias_list, + permuted_weights=permuted_weights, + experts_min=self.experts_min, + experts_max=self.experts_max) + else: + return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + permuted_weights=permuted_weights, + activation=activation, + experts_min=self.experts_min, + experts_max=self.experts_max, + **kwargs) for i in range(self.moe_n_slice): w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] min_expert = self.experts_min + i * self.num_expert_per_group max_expert = min_expert + self.num_expert_per_group - 1 - slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, - expert_routing_table=expert_routing_table, - router_weights=router_weights, - w12=w1_list_slice, - w3=w2_list_slice, - permuted_weights=permuted_weights, - activation=activation, - experts_min=min_expert, - experts_max=max_expert, - **kwargs) + if self.bias is not None: + w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] + w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] + 
w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + w2_bias_list_slice = w2_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts.bias_fused_weights( + hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list, + w3=w2_list, + w12_bias=w1_bias_list_slice, + w3_bias=w2_bias_list_slice, + permuted_weights=permuted_weights, + experts_min=self.experts_min, + experts_max=self.experts_max) + else: + slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states, + expert_routing_table=expert_routing_table, + router_weights=router_weights, + w12=w1_list_slice, + w3=w2_list_slice, + permuted_weights=permuted_weights, + activation=activation, + experts_min=min_expert, + experts_max=max_expert, + **kwargs) if i == 0: final_hidden_states = slice_final_hidden_states else: diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py index bcdd05b21..5c751ad35 100644 --- a/vllm_gaudi/extension/utils.py +++ b/vllm_gaudi/extension/utils.py @@ -157,14 +157,16 @@ def forward( valid_sequence_lengths, padding_side="left", window_size=None, + sinks=None, ): if window_size is not None: return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, recompute_mode, valid_sequence_lengths, padding_side, False, False, - window_size) + window_size, sinks) else: return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode, - recompute_mode, valid_sequence_lengths, padding_side) + recompute_mode, valid_sequence_lengths, padding_side, False, False, + (-1, -1), sinks) def pad_list(input, target_len, val_generator): diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index a0b55b1a0..8e69f3045 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -3,6 +3,7 @@ import torch import vllm +from vllm.config import get_current_vllm_config from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, UnquantizedFusedMoEMethod) @@ -89,12 +90,15 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() + vllm_config = get_current_vllm_config() + self.model_type = vllm_config.model_config.hf_config.model_type def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) # custom handling for HPU num_experts = layer.local_num_experts ep_shift = layer.ep_rank * num_experts + has_bias = hasattr(layer, 'w13_bias') and hasattr(layer, 'w2_bias') experts_min, experts_max = ep_shift, num_experts + ep_shift - 1 @@ -108,12 +112,16 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts, experts_min, experts_max, + has_bias, dispatch_fn, ) for expert_id in range(layer.local_num_experts): layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id]) layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id]) + if has_bias: + layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id]) + layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id]) def forward_oot( self, @@ 
-128,9 +136,13 @@ def forward_oot( topk_weights, topk_ids = layer.select_experts(hidden_states=x, router_logits=router_logits) else: import torch.nn.functional as F - topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) - topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1) - topk_weights /= topk_weights.sum(dim=-1, keepdim=True) + if self.model_type in ["gpt_oss"]: + topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1) + topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32) + else: + topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32) + topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1) + topk_weights /= topk_weights.sum(dim=-1, keepdim=True) topk_weights = topk_weights.to(x.dtype) if not layer.use_grouped_topk: @@ -151,6 +163,15 @@ def forward_oot( topk_ids = topk_ids.view(-1, topk_ids.shape[-1]) topk_weights = topk_weights.view(-1, topk_weights.shape[-1]) + if self.model_type in ["gpt_oss"]: + return layer.moe_op( + x, + topk_ids.to(torch.int64), + topk_weights.to(x.dtype), + permuted_weights=True, + activation=layer.activation, + ).view(*input_shape) + output = layer.moe_op( x, topk_ids, diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index 7d3d392d3..9b6ba24d7 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -2090,7 +2090,7 @@ def _create_decode_input_data(self, ) if self.interleaved_sliding_window and self.sliding_window is not None and self.sliding_window > 0: - sliding_block_size = (self.sliding_window // self.block_size) + sliding_block_size = (self.sliding_window // self.block_size) + 1 window_block_tables = [block_table[-sliding_block_size:] for block_table in block_tables_list] window_block_list, window_block_groups, window_block_usage = \ self.get_habana_paged_attn_buffers( From 0b49b976bb4d93b67c4a3e42d7b5ca3c379802be Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 2 Jan 2026 04:36:00 +0200 Subject: [PATCH 02/15] Update MOE Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 5817d212f..76be85afd 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -542,7 +542,7 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, - bias = None, + bias = None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__() self.experts_min = experts_min @@ -601,7 +601,7 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, - bias=None, + bias=None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) From 9cdf3f39d6d4608afe34995ead3014cd27de73b4 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 2 Jan 2026 04:43:22 +0200 Subject: [PATCH 03/15] Update Pipelined PA Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 76be85afd..cecc130c5 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -65,8 +65,8 @@ def matmul_shape(lhs, rhs): return result -def 
pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op, batch2block_matmul_op, - block2batch_matmul_op): +def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op, + batch2block_matmul_op, block2batch_matmul_op): # When fp32_softmax is enabled attn is left in fp32 after Q@K # We can return to native dtype after we renormalize and calculate the adjustments if block_bias is not None and attn.dtype != block_bias.dtype: From 7c8e4eb54927c5268c89d35850ac01fb7d9983af Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 2 Jan 2026 04:46:43 +0200 Subject: [PATCH 04/15] Format MOE Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index cecc130c5..e9369ac29 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -542,7 +542,7 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, - bias = None, + bias=None, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): super().__init__() self.experts_min = experts_min From cb9bd94db741ddfb76afb53bd042ce0416d75dc6 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 2 Jan 2026 06:59:54 +0200 Subject: [PATCH 05/15] Update FSDPA Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index e9369ac29..7c17f2183 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -399,7 +399,10 @@ def _fsdpa_prompt_attention(query: torch.Tensor, query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths, padding_side ] - args += [window_size] if window_size else [None] + if sinks is not None: + args += [window_size] if window_size else [None] + else: + args += [window_size] if window_size else [] # use sinks in fsdpa if sinks is not None: args += [sinks] From 4b2f0ffec113b33d06018412b1c4381d838db7bb Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Mon, 19 Jan 2026 10:02:38 +0200 Subject: [PATCH 06/15] Set Model type to None if config is None Signed-off-by: Himangshu Lahkar --- vllm_gaudi/ops/hpu_fused_moe.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 803577cac..d237ee03c 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -92,7 +92,10 @@ def __init__(self, *args, **kwargs): self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() vllm_config = get_current_vllm_config() - self.model_type = vllm_config.model_config.hf_config.model_type + if vllm_config is not None: + self.model_type = vllm_config.model_config.hf_config.model_type + else: + self.model_type = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) @@ -138,7 +141,7 @@ def forward_oot( topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits) else: import torch.nn.functional as F - if self.model_type in ["gpt_oss"]: + if self.model_type is not None and self.model_type in ["gpt_oss"]: topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1) topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32) else: From 
69f417831666d3529087b9bc5cf5458c2b1f0457 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Thu, 22 Jan 2026 11:57:16 +0200 Subject: [PATCH 07/15] qkv dtype cast, needed for GPT-OSS Signed-off-by: Himangshu Lahkar --- vllm_gaudi/attention/backends/hpu_attn.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py index f4cd0fade..02d4899f1 100644 --- a/vllm_gaudi/attention/backends/hpu_attn.py +++ b/vllm_gaudi/attention/backends/hpu_attn.py @@ -586,6 +586,12 @@ def forward( if kv_cache is not None and isinstance(kv_cache, tuple): key_cache, value_cache, k_scales, v_scales = \ HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size) + if key.dtype != key_cache.dtype: + key = key.to(key_cache.dtype) + if value.dtype != value_cache.dtype: + value = value.to(value_cache.dtype) + if query.dtype != key.dtype: + query = query.to(key.dtype) if self.kv_sharing_target_layer_name is None: # Reshape the input keys and values and store them in the cache. # If kv_cache is not provided, the new key and value tensors are From bc26332465d6a00bcd483392c21f685b326d5cc9 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 23 Jan 2026 06:33:21 +0200 Subject: [PATCH 08/15] Put a check for hf_config Signed-off-by: Himangshu Lahkar --- vllm_gaudi/ops/hpu_fused_moe.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index f89fcadff..5f61907d7 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -92,10 +92,10 @@ def __init__(self, *args, **kwargs): self.use_dispatch_fn = get_config().use_dispatch_fn torch.hpu.synchronize() vllm_config = get_current_vllm_config() - if vllm_config is not None: + self.model_type = None + if vllm_config is not None and vllm_config.model_config is not None \ + and vllm_config.model_config.hf_config is not None: self.model_type = vllm_config.model_config.hf_config.model_type - else: - self.model_type = None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: super().process_weights_after_loading(layer) From a63ceb2347a0e25864259992fcc47c9cf87ab1bd Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 23 Jan 2026 12:45:21 +0200 Subject: [PATCH 09/15] Make bias the last arg Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 881f6ba95..cb7c36666 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -556,8 +556,8 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, - bias=None, - dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None, + bias=None): super().__init__() self.experts_min = experts_min self.experts_max = experts_max @@ -615,8 +615,8 @@ def __init__(self, num_total_experts: int, experts_min: int = 0, experts_max: int = 8, - bias=None, - dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None): + dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None, + bias=None): super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) From 
4b6fb2894d6d2b7fb976edf0a764f159362043c7 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 23 Jan 2026 14:58:03 +0200 Subject: [PATCH 10/15] bias ordering made proper Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 2 +- vllm_gaudi/ops/hpu_fused_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index cb7c36666..73dbebf2a 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -617,7 +617,7 @@ def __init__(self, experts_max: int = 8, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None, bias=None): - super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn) + super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn, bias) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 5f61907d7..f8ccc98a5 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -116,8 +116,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: num_experts, experts_min, experts_max, - has_bias, dispatch_fn, + has_bias ) for expert_id in range(layer.local_num_experts): From 0f46ae2e6df8e2228c17d70dd00fe7c6ff8de34c Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: Fri, 23 Jan 2026 17:32:07 +0200 Subject: [PATCH 11/15] check for bias not None Signed-off-by: Himangshu Lahkar --- vllm_gaudi/extension/ops.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py index 73dbebf2a..eaba72db3 100644 --- a/vllm_gaudi/extension/ops.py +++ b/vllm_gaudi/extension/ops.py @@ -538,6 +538,7 @@ class MoeMatmul(torch.nn.Module): def __init__(self): super().__init__() + self.bias = None def set_weight(self, w): self.weight = w @@ -557,7 +558,7 @@ def __init__(self, experts_min: int = 0, experts_max: int = 8, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None, - bias=None): + bias=False): super().__init__() self.experts_min = experts_min self.experts_max = experts_max @@ -616,7 +617,7 @@ def __init__(self, experts_min: int = 0, experts_max: int = 8, dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None, - bias=None): + bias=False): super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn, bias) self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)]) @@ -630,7 +631,8 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_ w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range] if self.moe_n_slice == 1: - if self.bias is not None: + if self.bias is True and self.w13_list[i].bias is not None \ + and self.w2_list[i].bias is not None: w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states, @@ -659,7 +661,8 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_ w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] min_expert = self.experts_min + i * 
self.num_expert_per_group max_expert = min_expert + self.num_expert_per_group - 1 - if self.bias is not None: + if self.bias is not None and self.w13_list[i].bias is not None \ + and self.w2_list is not None: w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range] w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range] w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group] From 8b4bd00837c0d0d7fb5928f58a1d76c3f8bdeb12 Mon Sep 17 00:00:00 2001 From: Artur Fierka Date: Mon, 26 Jan 2026 08:13:14 +0100 Subject: [PATCH 12/15] Fix Llama4 shape mismatch for 32k+ context window (#842) (#855) Llama4 for `max_model_len > 32k` enable temperature adjustment https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L719. Enabled adjustment causes tensor `q` shape modification from 2D to 3D: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama4.py#L307. This tensor is passing to `UnqnatizedFusedMoEMetod -> forward`: https://github.com/vllm-project/vllm-gaudi/blob/main/vllm_gaudi/ops/hpu_fused_moe.py#L163 causing invalid reshaping - we trying to return a 3D `output.view` based on 2D output tensor. Found that following PR introduced the bug: #680 and #684 Cherry-picked from `releases/v0.13.0` --------- Signed-off-by: Artur Fierka --- vllm_gaudi/ops/hpu_fused_moe.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index f8ccc98a5..6612571e4 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -185,7 +185,10 @@ def forward_oot( permuted_weights=True, activation=layer.activation, ) - return output.view(*(output.size(0), *input_shape[1:])) + if layer.dp_size > 1: + return output.view(*(output.size(0), *input_shape[1:])) + else: + return output.view(*input_shape) def reduce_output(self, states: torch.Tensor) -> torch.Tensor: From ba3e55af0f3b2829a24b91119ee9f13040b93564 Mon Sep 17 00:00:00 2001 From: git config -lDudi Lester <160421192+dudilester@users.noreply.github.com> Date: Mon, 26 Jan 2026 10:23:34 +0200 Subject: [PATCH 13/15] Fix HPU model runner profile_run to work with dynamic kv-cache scales (#852) Signed-off-by: Dudi Lester Co-authored-by: Kamil Kaczor --- vllm_gaudi/v1/worker/hpu_worker.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm_gaudi/v1/worker/hpu_worker.py b/vllm_gaudi/v1/worker/hpu_worker.py index ae5705e54..ac63d1b77 100644 --- a/vllm_gaudi/v1/worker/hpu_worker.py +++ b/vllm_gaudi/v1/worker/hpu_worker.py @@ -201,10 +201,11 @@ def determine_available_memory(self) -> int: hpu_k_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu') if create_dynamic_scales else None - if hpu_v_cache is None: - hpu_v_scales = None - elif create_dynamic_scales: - hpu_v_scales = torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu') + if create_dynamic_scales: + hpu_v_scales = (torch.ones(kv_scales_shape, dtype=torch.bfloat16, device='hpu'), + torch.ones([num_blocks, num_kv_heads, head_size], + dtype=torch.bfloat16, + device='hpu')) else: hpu_v_scales = None From 8a68c7d12a4ca2d03580ca1440f9c27f5f3627c6 Mon Sep 17 00:00:00 2001 From: Agata Dobrzyniewicz <160237065+adobrzyn@users.noreply.github.com> Date: Mon, 26 Jan 2026 15:36:38 +0100 Subject: [PATCH 14/15] Revert "skip HPU graphs for long prefills" (#850) Reverts vllm-project/vllm-gaudi#780 --------- Signed-off-by: Agata Dobrzyniewicz Co-authored-by: 
Chendi.Xue --- tests/full_tests/ci_gsm8k_tests.sh | 2 +- tests/full_tests/ci_perf_tests.sh | 2 +- vllm_gaudi/v1/worker/hpu_model_runner.py | 10 ++++------ 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/tests/full_tests/ci_gsm8k_tests.sh b/tests/full_tests/ci_gsm8k_tests.sh index 822592629..0cba7bec7 100644 --- a/tests/full_tests/ci_gsm8k_tests.sh +++ b/tests/full_tests/ci_gsm8k_tests.sh @@ -99,7 +99,7 @@ run_qwen3_compressed_tensor_dynamic_scaling_test() { # QWEN3 FP8 + MOE compressed tensor + dynamic scaling run_qwen3_moe_compressed_tensor_dynamic_scaling_test() { echo "➡️ Testing Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 + moe + compressed-tensor + dynamic scaling..." - HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --trust-remote-code + HABANA_VISIBLE_DEVICES=all VLLM_CONTIGUOUS_PA=False VLLM_SKIP_WARMUP=true PT_HPU_LAZY_MODE=1 python -u "${VLLM_GAUDI_PREFIX}/tests/full_tests/generate.py" --model Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 --trust-remote-code --max-model-len 131072 echo "✅ Test with Qwen/Qwen3-30B-A3B-Instruct-2507-FP8 + moe + compressed-tensor + dynamic scaling successful." } diff --git a/tests/full_tests/ci_perf_tests.sh b/tests/full_tests/ci_perf_tests.sh index fb94a1956..2066572ee 100644 --- a/tests/full_tests/ci_perf_tests.sh +++ b/tests/full_tests/ci_perf_tests.sh @@ -37,4 +37,4 @@ vllm bench throughput \ --dataset_path ShareGPT_V3_unfiltered_cleaned_split.json \ --dataset_name sharegpt \ --num-prompts 1000 \ - --max-model-len 32768 + --max-model-len 16384 diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py index ceb5c06bb..dc10a2895 100644 --- a/vllm_gaudi/v1/worker/hpu_model_runner.py +++ b/vllm_gaudi/v1/worker/hpu_model_runner.py @@ -822,14 +822,14 @@ def __init__( self.use_hpu_graph = not self.model_config.enforce_eager self.max_batch_size = self.scheduler_config.max_num_seqs self.max_num_seqs = self.scheduler_config.max_num_seqs + self.max_cudagraph_capture_size = self.vllm_config.compilation_config.max_cudagraph_capture_size if prompt_profile_cfg: self.max_prefill_batch_size = prompt_profile_cfg[0] else: self.max_prefill_batch_size = with_default(get_config().VLLM_PROMPT_BS_BUCKET_MAX, 1) self.seen_configs: set = set() - self.max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens - self.max_graph_capture_tokens = self.vllm_config.compilation_config.max_cudagraph_capture_size if \ - self.vllm_config.compilation_config.max_cudagraph_capture_size is not None else self.max_num_batched_tokens + self.max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens self.use_prefix_caching = (self.vllm_config.cache_config.enable_prefix_caching) self.bucketing_manager = HPUBucketingManager() max_num_prefill_seqs = self.max_num_seqs if self.use_merged_prefill \ @@ -2687,9 +2687,7 @@ def _execute_model_generic(self, additional_kwargs = {} if htorch.utils.internal.is_lazy(): use_graphs = self._use_graphs() - # skip HPU graphs for long prefills - if seq_len > 1 and \ - batch_size * (seq_len + num_blocks * self.block_size) > self.max_graph_capture_tokens: + if self.max_cudagraph_capture_size is not None and batch_size * seq_len > self.max_cudagraph_capture_size: use_graphs = False additional_kwargs.update({"bypass_hpu_graphs": not use_graphs}) else: From f3a45602978ed8deb2e3809539e1b8add38892e7 Mon Sep 17 00:00:00 2001 From: Himangshu Lahkar Date: 
Tue, 27 Jan 2026 08:09:39 +0200 Subject: [PATCH 15/15] fix pre-commit error Signed-off-by: Himangshu Lahkar --- vllm_gaudi/ops/hpu_fused_moe.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py index 6612571e4..3562f6a33 100644 --- a/vllm_gaudi/ops/hpu_fused_moe.py +++ b/vllm_gaudi/ops/hpu_fused_moe.py @@ -111,14 +111,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: else: dispatch_fn = None - layer.moe_op = VllmMixtureOfExpertsOp( - layer.global_num_experts, - num_experts, - experts_min, - experts_max, - dispatch_fn, - has_bias - ) + layer.moe_op = VllmMixtureOfExpertsOp(layer.global_num_experts, num_experts, experts_min, experts_max, + dispatch_fn, has_bias) for expert_id in range(layer.local_num_experts): layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id])
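
Illustration only, not part of the patch series: the attention-sink handling introduced above (the `sink` argument threaded through `flat_pa`/`pipelined_pa` and the concat/`[..., :-1]` slice in `_naive_prompt_attention`) amounts to appending one learned logit per head to each softmax row and discarding it afterwards, so the sink only enlarges the softmax denominator. The sketch below shows that identity in plain PyTorch; the shapes and the helper name `softmax_with_sink` are illustrative, not code from the patches.

    import torch

    def softmax_with_sink(scores: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
        # scores: [batch, heads, q_len, k_len], sink: [heads] learned per-head logit.
        # Append the sink as an extra "key" column, softmax, then drop that column,
        # mirroring the concat and [..., :-1] slice in _naive_prompt_attention.
        b, h, q, _ = scores.shape
        sink_col = sink.view(1, h, 1, 1).expand(b, h, q, 1)
        combined = torch.cat([scores, sink_col], dim=-1)
        combined = combined - combined.amax(dim=-1, keepdim=True)  # numerical stability
        return torch.softmax(combined, dim=-1)[..., :-1]

    # Equivalent view: the sink contributes a single extra exp() term to the denominator.
    scores, sink = torch.randn(2, 4, 3, 5), torch.randn(4)
    row_max = scores.amax(dim=-1, keepdim=True)
    num = torch.exp(scores - row_max)
    denom = num.sum(dim=-1, keepdim=True) + torch.exp(sink.view(1, 4, 1, 1) - row_max)
    assert torch.allclose(softmax_with_sink(scores, sink), num / denom, atol=1e-6)

In the paged-attention path that same extra term is what gets folded into `block_max` and `block_sums`, with the usual cross-block renormalization applied afterwards.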
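
The GPT-OSS branch added to `forward_oot` applies top-k to the raw router logits and then softmaxes only the selected logits, while the default path softmaxes over all experts, takes top-k, and renormalizes. Because softmax is monotonic, and renormalizing a subset of softmax outputs is the same as softmaxing the underlying logits of that subset, the two orderings select the same experts and yield the same weights up to floating point; the GPT-OSS form mirrors the reference implementation and skips the full-width softmax. A toy check, outside the patch:

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    router_logits, top_k = torch.randn(4, 32), 4

    # GPT-OSS style: select the top-k logits first, then softmax over just those.
    w_a, ids_a = torch.topk(router_logits, top_k, dim=-1)
    w_a = F.softmax(w_a, dim=-1, dtype=torch.float32)

    # Default style: softmax over all experts, take top-k, then renormalize.
    probs = F.softmax(router_logits, dim=1, dtype=torch.float32)
    w_b, ids_b = torch.topk(probs, top_k, dim=-1)
    w_b = w_b / w_b.sum(dim=-1, keepdim=True)

    assert torch.equal(ids_a, ids_b)
    assert torch.allclose(w_a, w_b, atol=1e-6)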
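
The bias path added to `VllmMixtureOfExpertsOp` feeds per-expert `w13`/`w2` biases into the fused `torch.ops.hpu.mixture_of_experts.bias_fused_weights` kernel. As a shape-and-dataflow reference only, the unfused computation it stands in for looks roughly like the loop below; the SiLU gate is an assumption made for illustration (GPT-OSS uses its own gating variant), and the real kernel fuses the loop across experts and handles weight permutation internally.

    import torch
    import torch.nn.functional as F

    def moe_reference(hidden, topk_ids, topk_w, w12, b12, w3, b3):
        # hidden: [T, H]; w12[e]: [2*I, H] fused gate/up proj; w3[e]: [H, I] down proj.
        # For every selected expert: gate/up projection plus bias, gated activation,
        # down projection plus bias, accumulated with the router weight.
        out = torch.zeros_like(hidden)
        num_tokens, top_k = topk_ids.shape
        for t in range(num_tokens):
            for k in range(top_k):
                e = int(topk_ids[t, k])
                gate_up = hidden[t] @ w12[e].t() + b12[e]
                gate, up = gate_up.chunk(2, dim=-1)
                act = F.silu(gate) * up          # assumption: plain SiLU gating
                out[t] += topk_w[t, k] * (act @ w3[e].t() + b3[e])
        return out

    # Toy shapes: 6 tokens, hidden 16, intermediate 32, 8 experts, top-2 routing.
    T, H, I, E, K = 6, 16, 32, 8, 2
    hidden = torch.randn(T, H)
    w12 = [torch.randn(2 * I, H) for _ in range(E)]
    b12 = [torch.randn(2 * I) for _ in range(E)]
    w3 = [torch.randn(H, I) for _ in range(E)]
    b3 = [torch.randn(H) for _ in range(E)]
    topk_ids = torch.randint(0, E, (T, K))
    topk_w = torch.softmax(torch.randn(T, K), dim=-1)
    out = moe_reference(hidden, topk_ids, topk_w, w12, b12, w3, b3)  # [T, H]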
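
The `hpu_model_runner` change keeps `sliding_window // block_size + 1` blocks rather than `sliding_window // block_size`, presumably because a window whose length is an exact multiple of the block size still straddles one extra physical block whenever its start is not block-aligned, so trimming to the unrounded quotient can drop tokens that are still inside the window. A small worked check with illustrative numbers:

    block_size, sliding_window = 128, 1024
    last_token = 1500                               # position of the newest token
    first_token = last_token - sliding_window + 1   # oldest position still in the window
    blocks_touched = last_token // block_size - first_token // block_size + 1
    assert blocks_touched == sliding_window // block_size + 1  # 9 blocks, not 8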