Commits
25 commits
7d63e51
GPT OSS Code
hlahkar Jan 1, 2026
0b49b97
Update MOE
hlahkar Jan 2, 2026
9cdf3f3
Update Pipelined PA
hlahkar Jan 2, 2026
7c8e4eb
Format MOE
hlahkar Jan 2, 2026
cb9bd94
Update FSDPA
hlahkar Jan 2, 2026
14408fe
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 14, 2026
c609cd5
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 16, 2026
e3c8b52
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 19, 2026
4b2f0ff
Set Model type to None if config is None
hlahkar Jan 19, 2026
f494059
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 19, 2026
9a37437
Merge branch 'main' into gpt_oss_latest
wpyszka Jan 19, 2026
87edb67
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 22, 2026
69f4178
qkv dtype cast, needed for GPT-OSS
hlahkar Jan 22, 2026
f2c3d2d
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 23, 2026
bc26332
Put a check for hf_config
hlahkar Jan 23, 2026
a63ceb2
Make bias the last arg
hlahkar Jan 23, 2026
e3373e1
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 23, 2026
22a55d1
Merge branch 'main' into gpt_oss_latest
wpyszka Jan 23, 2026
4b6fb28
bias ordering made proper
hlahkar Jan 23, 2026
0f46ae2
check for bias not None
hlahkar Jan 23, 2026
8b4bd00
Fix Llama4 shape mismatch for 32k+ context window (#842) (#855)
afierka-intel Jan 26, 2026
ba3e55a
Fix HPU model runner profile_run to work with dynamic kv-cache scales…
dudilester Jan 26, 2026
8a68c7d
Revert "skip HPU graphs for long prefills" (#850)
adobrzyn Jan 26, 2026
f3a4560
fix pre-commit error
hlahkar Jan 27, 2026
5634b9b
Merge branch 'main' into gpt_oss_latest
hlahkar Jan 27, 2026
19 changes: 19 additions & 0 deletions vllm_gaudi/attention/backends/hpu_attn.py
@@ -198,6 +198,7 @@ def __init__(
qk_head_dim: int,
v_head_dim: int,
kv_b_proj: ColumnParallelLinear,
sinks: Optional[torch.Tensor] = None,
**kwargs,
) -> None:
torch.nn.Module.__init__(self)
@@ -252,6 +253,11 @@ def __init__(
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have one entry per attention head in the layer. "
f"Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def forward(
self,
@@ -428,6 +434,7 @@ def __init__(
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
use_irope: bool = False,
sinks: Optional[torch.Tensor] = None,
) -> None:
super(AttentionImpl, self).__init__()
if kv_sharing_target_layer_name is not None:
@@ -492,6 +499,11 @@ def __init__(
raise NotImplementedError("Encoder self-attention "
"is not implemented for "
"HPUAttentionImpl")
self.sinks = sinks
if sinks is not None:
assert sinks.shape[0] == num_heads, ("Sinks must have one entry per attention head in the layer. "
f"Sinks shape: {sinks.shape}, "
f"num_heads: {num_heads}.")

def _maybe_init_alibi_biases(
self,
@@ -576,6 +588,12 @@ def forward(
# Reshape the input keys and values and store them in the cache.
# If kv_cache is not provided, the new key and value tensors are
# not cached. This happens during the initial memory profiling run.
if key.dtype != key_cache.dtype:
key = key.to(key_cache.dtype)
if value.dtype != value_cache.dtype:
value = value.to(value_cache.dtype)
if query.dtype != key.dtype:
query = query.to(key.dtype)
key_cache = self.k_cache(key, key_cache, slot_mapping, k_scales)
value_cache = self.v_cache(value, value_cache, slot_mapping, v_scales)
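Note on the casts above: they keep the incoming query/key/value in the dtype of the allocated KV cache before the cache write (per the "qkv dtype cast, needed for GPT-OSS" commit, the checkpoint's activation dtype can differ from the cache dtype). A minimal standalone sketch of the same guard, with illustrative shapes and dtypes:

```python
import torch


def align_qkv_to_cache_dtype(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                             cache_dtype: torch.dtype):
    """Cast activations so the cache write and Q@K^T run in a single dtype."""
    if key.dtype != cache_dtype:
        key = key.to(cache_dtype)
    if value.dtype != cache_dtype:
        value = value.to(cache_dtype)
    if query.dtype != key.dtype:
        query = query.to(key.dtype)
    return query, key, value


q = torch.randn(2, 8, 64)  # fp32 activations standing in for the model's outputs
k = torch.randn(2, 8, 64)
v = torch.randn(2, 8, 64)
q, k, v = align_qkv_to_cache_dtype(q, k, v, torch.bfloat16)
assert q.dtype == k.dtype == v.dtype == torch.bfloat16
```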

@@ -691,6 +709,7 @@ def common_attention_args(self,
'key_cache': key_cache,
'value_cache': value_cache,
'block_size': block_size,
"sinks": self.sinks,
'k_scales': k_scales,
'v_scales': v_scales,
}
128 changes: 103 additions & 25 deletions vllm_gaudi/extension/ops.py
@@ -65,7 +65,7 @@ def matmul_shape(lhs, rhs):
return result


def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op,
def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op, batch2block_matmul_op,
block2batch_matmul_op):
# When fp32_softmax is enabled attn is left in fp32 after Q@K
# We can return to native dtype after we renormalize and calculate the adjustments
@@ -80,11 +80,27 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz
if block_bias is not None:
attn.add_(block_bias)
block_max = attn.amax(dim=-1, keepdim=True)
if sink is not None:
block_max = torch.maximum(block_max, sink)
attn = attn.sub(block_max)
attn = attn.exp()
if attn.dtype == torch.float32:
attn = attn.to(value.dtype)
block_sums = attn.sum(dim=-1, keepdim=True)
attn_shape = attn.shape
block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True)
attn_shape = list(attn_shape)
attn_shape[-1] = 1
block_sums = block_sums.view(attn_shape)
if sink is not None:
attn_sink = sink.sub(block_max)
attn_sink = attn_sink.exp()
if attn_sink.dtype == torch.float32:
attn_sink = attn_sink.to(value.dtype)
# TODO: Removing this .sum and using attn_sink directly
# produces incorrect output, which is unexpected.
# Likely a Synapse issue; needs further investigation.
block_sums_sink = attn_sink.sum(dim=-1, keepdim=True)
block_sums = block_sums + block_sums_sink
attn = matmul_av_op(attn, value)
if get_config().fused_block_softmax_adjustment:
out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3)
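For reviewers of the math above: the sink logit is treated as one extra attention score that raises the block max and adds exp(sink - max) to the normalizer, but contributes no value vector. A single-row sketch of that normalization, independent of the blocked layout and adjustment logic in pipelined_pa:

```python
import torch


def softmax_with_sink(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    """Row-wise softmax where a learned sink logit joins the max and the denominator."""
    row_max = torch.maximum(logits.max(), sink)   # sink participates in the running max
    probs = (logits - row_max).exp()
    denom = probs.sum() + (sink - row_max).exp()  # sink adds mass to the denominator only
    return probs / denom                          # rows sum to < 1; the remainder is the sink's share


logits = torch.tensor([0.5, 1.0, -0.3])
sink = torch.tensor(2.0)
p = softmax_with_sink(logits, sink)
print(p, p.sum())  # probabilities sum to less than 1
```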
@@ -177,7 +193,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_

def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, block_groups, block_size, scale,
matmul_qk_op, position_bias, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, keys_fetch_func,
values_fetch_func, k_scales, v_scales, **ignored_args):
values_fetch_func, sinks, k_scales, v_scales, **ignored_args):
batch_size, _, hidden_size = query.shape
_, kv_heads, head_size = key_cache.shape
q_heads = hidden_size // head_size
@@ -195,6 +211,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)),
**get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2)
block_bias = block_bias.view(key.size(0), 1, 1, -1)
sink = None
if sinks is not None:
sinks = sinks.reshape(sinks.shape[0], 1)
sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1])
sink = sink.expand(query.shape[0], -1, query.shape[-2], -1)
if kv_heads != q_heads:
sink = sink.unflatten(1, (kv_heads, -1))
if kv_heads != q_heads:
query = query.unflatten(1, (kv_heads, -1))
key = key.unflatten(1, (kv_heads, 1))
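The reshapes above only broadcast the per-head sink logits into the same grouped (GQA) layout as the query so they can enter pipelined_pa alongside the scores. A generic illustration with made-up sizes (the actual block and batch dimensions differ):

```python
import torch

num_q_heads, num_kv_heads, batch, q_len = 8, 2, 4, 1
sinks = torch.randn(num_q_heads)            # one learned sink logit per query head

sink = sinks.view(1, num_q_heads, 1, 1)     # [1, H, 1, 1]
sink = sink.expand(batch, -1, q_len, -1)    # broadcast over batch/blocks and query positions
if num_kv_heads != num_q_heads:
    sink = sink.unflatten(1, (num_kv_heads, -1))  # group query heads under their KV head
print(sink.shape)                           # torch.Size([4, 2, 4, 1, 1])
```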
@@ -232,6 +255,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
block_bias,
block_groups,
block_mapping,
sink,
batch_size=batch_size,
matmul_av_op=matmul_av_op,
batch2block_matmul_op=batch2block_matmul_op,
@@ -290,6 +314,7 @@ def _naive_prompt_attention(query: torch.Tensor,
matmul_qk_op=torch.matmul,
softmax_op=torch.softmax,
matmul_av_op=torch.matmul,
sinks: Optional[torch.Tensor] = None,
**ignored_args) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -321,10 +346,19 @@ def _naive_prompt_attention(query: torch.Tensor,
if attn_weights.dtype != attn_bias.dtype:
attn_bias = attn_bias.to(dtype=attn_weights.dtype)
attn_weights.add_(attn_bias)
if sinks is not None:
sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
if query_heads != kv_heads:
sink = sink.unflatten(1, (kv_heads, -1))
combined_logits = torch.cat([attn_weights, sink], dim=-1)
combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
attn_weights = combined_logits
if get_config().fp32_softmax:
attn_weights = torch.softmax(attn_weights, dim=-1)
else:
attn_weights = softmax_op(attn_weights, dim=-1)
if sinks is not None:
attn_weights = attn_weights[..., :-1]
attn_weights = attn_weights.to(query.dtype)
attn_weights = matmul_av_op(attn_weights, value)
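The prompt path above implements the same sink normalization in an equivalent way: concatenate the sink as one extra logit column, run a standard softmax over the widened row, then discard that column. A compact sketch (assuming one sink per head, broadcast over query positions):

```python
import torch


def sink_softmax_concat(logits: torch.Tensor, sink: torch.Tensor) -> torch.Tensor:
    """Append the sink as an extra logit column, softmax, then drop that column."""
    sink_col = sink.view(-1, 1, 1).expand(-1, logits.shape[-2], 1)   # [heads, q_len, 1]
    combined = torch.cat([logits, sink_col], dim=-1)
    combined = combined - combined.max(dim=-1, keepdim=True).values  # numerical stability
    probs = torch.softmax(combined, dim=-1)
    return probs[..., :-1]                                           # rows now sum to < 1


logits = torch.randn(4, 3, 5)   # [heads, q_len, kv_len]
sinks = torch.randn(4)          # one sink per head
out = sink_softmax_concat(logits, sinks)
print(out.sum(dim=-1))          # every row sums to less than 1
```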

@@ -343,6 +377,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
attn_bias: Optional[torch.Tensor] = None,
valid_seq_lengths: Optional[torch.Tensor] = None,
window_size: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
**ignored_args) -> torch.Tensor:
query = query.transpose(1, 2)
key = key.transpose(1, 2)
@@ -364,10 +399,16 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
padding_side
]
args += [window_size] if window_size else []
args += [window_size] if window_size else [None]
# use sinks in fsdpa
if sinks is not None:
args += [sinks]
attn_weights = fsdpa_op(*args)

attn_weights = attn_weights.transpose(1, 2)
if sinks is not None:
# TODO - check if we can remove this
htcore.mark_step()
return attn_weights


@@ -487,6 +528,9 @@ def __init__(self):
def set_weight(self, w):
self.weight = w

def set_bias(self, b):
self.bias = b

def forward(self, state, expert_id, w):
raise NotImplementedError()

@@ -498,12 +542,14 @@ def __init__(self,
num_total_experts: int,
experts_min: int = 0,
experts_max: int = 8,
bias=None,
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
super().__init__()
self.experts_min = experts_min
self.experts_max = experts_max
self.global_num_experts = global_num_experts
self.num_experts = num_total_experts
self.bias = bias

if MAX_EXPERTS_PER_SLICE > 0:
max_expert_per_slice = MAX_EXPERTS_PER_SLICE
@@ -555,8 +601,9 @@ def __init__(self,
num_total_experts: int,
experts_min: int = 0,
experts_max: int = 8,
bias=None,
dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn)
super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, bias, dispatch_fn)
self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])

@@ -569,31 +616,62 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_
w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]

if self.moe_n_slice == 1:
return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max,
**kwargs)
if self.bias is not None:
w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
w12_bias=w1_bias_list,
w3_bias=w2_bias_list,
permuted_weights=permuted_weights,
experts_min=self.experts_min,
experts_max=self.experts_max)
else:
return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list,
w3=w2_list,
permuted_weights=permuted_weights,
activation=activation,
experts_min=self.experts_min,
experts_max=self.experts_max,
**kwargs)
for i in range(self.moe_n_slice):
w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
min_expert = self.experts_min + i * self.num_expert_per_group
max_expert = min_expert + self.num_expert_per_group - 1
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert,
**kwargs)
if self.bias is not None:
w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
w2_bias_list_slice = w2_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts.bias_fused_weights(
hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
w12_bias=w1_bias_list_slice,
w3_bias=w2_bias_list_slice,
permuted_weights=permuted_weights,
experts_min=min_expert,
experts_max=max_expert)
else:
slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
expert_routing_table=expert_routing_table,
router_weights=router_weights,
w12=w1_list_slice,
w3=w2_list_slice,
permuted_weights=permuted_weights,
activation=activation,
experts_min=min_expert,
experts_max=max_expert,
**kwargs)
if i == 0:
final_hidden_states = slice_final_hidden_states
else:
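For context on the bias_fused_weights path: a plain, non-fused per-token reference of what the bias-aware MoE computes — route each token to its top-k experts, apply w13 plus its bias, the gated activation, then w2 plus its bias, and combine with the router weights. This is only a sketch under assumed shapes and a SwiGLU-style activation, not the HPU kernel:

```python
import torch
import torch.nn.functional as F


def moe_reference(hidden, w13, b13, w2, b2, topk_ids, topk_weights):
    # hidden: [tokens, hidden]; w13: [experts, 2*inter, hidden]; b13: [experts, 2*inter]
    # w2: [experts, hidden, inter]; b2: [experts, hidden]
    out = torch.zeros_like(hidden)
    for t in range(hidden.shape[0]):
        for e, w in zip(topk_ids[t].tolist(), topk_weights[t].tolist()):
            gate, up = (hidden[t] @ w13[e].T + b13[e]).chunk(2)
            act = F.silu(gate) * up                   # gated activation (assumed)
            out[t] += w * (act @ w2[e].T + b2[e])
    return out


tokens, hid, inter, experts, topk = 4, 16, 32, 8, 2
h = torch.randn(tokens, hid)
w13, b13 = torch.randn(experts, 2 * inter, hid), torch.randn(experts, 2 * inter)
w2, b2 = torch.randn(experts, hid, inter), torch.randn(experts, hid)
topk_weights, topk_ids = torch.softmax(torch.randn(tokens, experts), dim=-1).topk(topk, dim=-1)
print(moe_reference(h, w13, b13, w2, b2, topk_ids, topk_weights).shape)  # torch.Size([4, 16])
```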
6 changes: 4 additions & 2 deletions vllm_gaudi/extension/utils.py
@@ -157,14 +157,16 @@ def forward(
valid_sequence_lengths,
padding_side="left",
window_size=None,
sinks=None,
):
if window_size is not None:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side, False, False,
window_size)
window_size, sinks)
else:
return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
recompute_mode, valid_sequence_lengths, padding_side)
recompute_mode, valid_sequence_lengths, padding_side, False, False,
(-1, -1), sinks)


def pad_list(input, target_len, val_generator):
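A note on the call convention in this file: when no sliding window is requested, placeholder positionals (False, False, (-1, -1)) are now passed so that sinks reaches the kernel's expected argument slot. A sketch of that pattern against a hypothetical kernel signature (not the real FusedSDPA API):

```python
from typing import Optional, Tuple

import torch


# Hypothetical kernel signature, only to illustrate the positional padding above.
def fused_sdpa_kernel(q, k, v, attn_mask, dropout_p, is_causal, scale, softmax_mode,
                      recompute_mode, valid_seq_lengths, padding_side,
                      flag_a=False, flag_b=False, window_size: Tuple[int, int] = (-1, -1),
                      sinks: Optional[torch.Tensor] = None):
    return q  # stand-in body


def call_kernel(q, k, v, window_size=None, sinks=None):
    args = [q, k, v, None, 0.0, True, None, "fast", False, None, "left"]
    if window_size is not None:
        args += [False, False, window_size, sinks]
    else:
        # (-1, -1) means "no sliding window"; it keeps sinks in the right positional slot.
        args += [False, False, (-1, -1), sinks]
    return fused_sdpa_kernel(*args)


q = k = v = torch.randn(1, 2, 4, 8)
call_kernel(q, k, v, sinks=torch.randn(2))
```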