diff --git a/vllm_gaudi/attention/backends/hpu_attn.py b/vllm_gaudi/attention/backends/hpu_attn.py
index f0b54760a..4fc6f1651 100644
--- a/vllm_gaudi/attention/backends/hpu_attn.py
+++ b/vllm_gaudi/attention/backends/hpu_attn.py
@@ -204,6 +204,7 @@ def __init__(
         qk_head_dim: int,
         v_head_dim: int,
         kv_b_proj: ColumnParallelLinear,
+        sinks: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> None:
         torch.nn.Module.__init__(self)
@@ -258,6 +259,11 @@ def __init__(
                                       "encoder/decoder cross-attention "
                                       "are not implemented for "
                                       "TritonMLAImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")
 
     def forward(
         self,
@@ -434,6 +440,7 @@ def __init__(
         attn_type: str = AttentionType.DECODER,
         kv_sharing_target_layer_name: Optional[str] = None,
         use_irope: bool = False,
+        sinks: Optional[torch.Tensor] = None,
     ) -> None:
         super(AttentionImpl, self).__init__()
         self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
@@ -500,6 +507,11 @@ def __init__(
             raise NotImplementedError("Encoder self-attention "
                                       "is not implemented for "
                                       "HPUAttentionImpl")
+        self.sinks = sinks
+        if sinks is not None:
+            assert sinks.shape[0] == num_heads, ("Sinks must have the same number of heads as the number of "
+                                                 f"heads in the layer. Sinks shape: {sinks.shape}, "
+                                                 f"num_heads: {num_heads}.")
 
         self.is_chunked_attention = False
@@ -582,6 +594,12 @@ def forward(
         if kv_cache is not None and isinstance(kv_cache, tuple):
             key_cache, value_cache, k_scales, v_scales = \
                 HPUPagedAttention.split_kv_cache(kv_cache, self.num_kv_heads, self.head_size)
+            if key.dtype != key_cache.dtype:
+                key = key.to(key_cache.dtype)
+            if value.dtype != value_cache.dtype:
+                value = value.to(value_cache.dtype)
+            if query.dtype != key.dtype:
+                query = query.to(key.dtype)
             if self.kv_sharing_target_layer_name is None:
                 # Reshape the input keys and values and store them in the cache.
                 # If kv_cache is not provided, the new key and value tensors are
@@ -720,6 +738,7 @@ def common_attention_args(self,
             'key_cache': key_cache,
             'value_cache': value_cache,
             'block_size': block_size,
+            "sinks": self.sinks,
             'k_scales': k_scales,
             'v_scales': v_scales,
         }
diff --git a/vllm_gaudi/extension/ops.py b/vllm_gaudi/extension/ops.py
index 96e38f3bc..eaba72db3 100644
--- a/vllm_gaudi/extension/ops.py
+++ b/vllm_gaudi/extension/ops.py
@@ -67,8 +67,8 @@ def matmul_shape(lhs, rhs):
     return result
 
 
-def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_size, matmul_av_op, batch2block_matmul_op,
-                 block2batch_matmul_op):
+def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, sink, batch_size, matmul_av_op,
+                 batch2block_matmul_op, block2batch_matmul_op):
     # When fp32_softmax is enabled attn is left in fp32 after Q@K
     # We can return to native dtype after we renormalize and calculate the adjustments
     if block_bias is not None and attn.dtype != block_bias.dtype:
@@ -82,11 +82,27 @@ def pipelined_pa(attn, value, block_bias, block_groups, block_mapping, batch_siz
     if block_bias is not None:
         attn.add_(block_bias)
     block_max = attn.amax(dim=-1, keepdim=True)
+    if sink is not None:
+        block_max = torch.maximum(block_max, sink)
     attn = attn.sub(block_max)
     attn = attn.exp()
     if attn.dtype == torch.float32:
         attn = attn.to(value.dtype)
-    block_sums = attn.sum(dim=-1, keepdim=True)
+    attn_shape = attn.shape
+    block_sums = attn.view(-1, attn_shape[-1]).sum(dim=-1, keepdim=True)
+    attn_shape = list(attn_shape)
+    attn_shape[-1] = 1
+    block_sums = block_sums.view(attn_shape)
+    if sink is not None:
+        attn_sink = sink.sub(block_max)
+        attn_sink = attn_sink.exp()
+        if attn_sink.dtype == torch.float32:
+            attn_sink = attn_sink.to(value.dtype)
+        # TODO: Removing this .sum and using attn_sink directly
+        # results in wrong output, which does not make sense.
+        # Looks like a Synapse issue, need to investigate further.
+        block_sums_sink = attn_sink.sum(dim=-1, keepdim=True)
+        block_sums = block_sums + block_sums_sink
     attn = matmul_av_op(attn, value)
     if get_config().fused_block_softmax_adjustment:
         out_shape = list(attn.shape[:3]) + [1] * (attn.dim() - 3)
@@ -179,7 +195,7 @@ def flat_pa_mla(query, key_cache, value_cache, block_list, block_mapping, block_
 
 def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias, block_groups, block_size, scale,
             matmul_qk_op, position_bias, matmul_av_op, batch2block_matmul_op, block2batch_matmul_op, keys_fetch_func,
-            values_fetch_func, k_scales, v_scales, **ignored_args):
+            values_fetch_func, sinks, k_scales, v_scales, **ignored_args):
     batch_size, _, hidden_size = query.shape
     _, kv_heads, head_size = key_cache.shape
     q_heads = hidden_size // head_size
@@ -197,6 +213,13 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
     value = values_fetch_func(value_cache.unflatten(0, (-1, block_size)),
                               **get_kv_fetch_extra_args(blocks=block_list, scales=v_scales_uf)).transpose(1, 2)
     block_bias = block_bias.view(key.size(0), 1, 1, -1)
+    sink = None
+    if sinks is not None:
+        sinks = sinks.reshape(sinks.shape[0], 1)
+        sink = sinks.reshape(1, sinks.shape[0], 1, sinks.shape[1])
+        sink = sink.expand(query.shape[0], -1, query.shape[-2], -1)
+        if kv_heads != q_heads:
+            sink = sink.unflatten(1, (kv_heads, -1))
     if kv_heads != q_heads:
         query = query.unflatten(1, (kv_heads, -1))
         key = key.unflatten(1, (kv_heads, 1))
@@ -234,6 +257,7 @@ def flat_pa(query, key_cache, value_cache, block_list, block_mapping, block_bias
                         block_bias,
                         block_groups,
                         block_mapping,
+                        sink,
                         batch_size=batch_size,
                         matmul_av_op=matmul_av_op,
                         batch2block_matmul_op=batch2block_matmul_op,
@@ -292,6 +316,7 @@ def _naive_prompt_attention(query: torch.Tensor,
                             matmul_qk_op=torch.matmul,
                             softmax_op=torch.softmax,
                             matmul_av_op=torch.matmul,
+                            sinks: Optional[torch.Tensor] = None,
                             **ignored_args) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
@@ -323,10 +348,19 @@ def _naive_prompt_attention(query: torch.Tensor,
         if attn_weights.dtype != attn_bias.dtype:
             attn_bias = attn_bias.to(dtype=attn_weights.dtype)
         attn_weights.add_(attn_bias)
+    if sinks is not None:
+        sink = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1)
+        if query_heads != kv_heads:
+            sink = sink.unflatten(1, (kv_heads, -1))
+        combined_logits = torch.cat([attn_weights, sink], dim=-1)
+        combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values
+        attn_weights = combined_logits
     if get_config().fp32_softmax:
         attn_weights = torch.softmax(attn_weights, dim=-1)
     else:
         attn_weights = softmax_op(attn_weights, dim=-1)
+    if sinks is not None:
+        attn_weights = attn_weights[..., :-1]
     attn_weights = attn_weights.to(query.dtype)
     attn_weights = matmul_av_op(attn_weights, value)
 
@@ -345,6 +379,7 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
                             attn_bias: Optional[torch.Tensor] = None,
                             valid_seq_lengths: Optional[torch.Tensor] = None,
                             window_size: Optional[int] = None,
+                            sinks: Optional[torch.Tensor] = None,
                             **ignored_args) -> torch.Tensor:
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
@@ -366,10 +401,19 @@ def _fsdpa_prompt_attention(query: torch.Tensor,
         query, key, value, attn_bias, 0.0, is_causal, scale, softmax_mode, recompute_mode, valid_seq_lengths,
         padding_side
     ]
-    args += [window_size] if window_size else []
+    if sinks is not None:
+        args += [window_size] if window_size else [None]
+    else:
+        args += [window_size] if window_size else []
+    # use sinks in fsdpa
+    if sinks is not None:
+        args += [sinks]
     attn_weights = fsdpa_op(*args)
     attn_weights = attn_weights.transpose(1, 2)
+    if sinks is not None:
+        # TODO - check if we can remove this
+        htcore.mark_step()
     return attn_weights
 
@@ -494,10 +538,14 @@ class MoeMatmul(torch.nn.Module):
 
     def __init__(self):
         super().__init__()
+        self.bias = None
 
     def set_weight(self, w):
         self.weight = w
 
+    def set_bias(self, b):
+        self.bias = b
+
     def forward(self, state, expert_id, w):
         raise NotImplementedError()
 
@@ -509,12 +557,14 @@ def __init__(self,
                  num_total_experts: int,
                  experts_min: int = 0,
                  experts_max: int = 8,
-                 dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
+                 dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+                 bias=False):
        super().__init__()
        self.experts_min = experts_min
        self.experts_max = experts_max
        self.global_num_experts = global_num_experts
        self.num_experts = num_total_experts
+        self.bias = bias
 
        if MAX_EXPERTS_PER_SLICE > 0:
            max_expert_per_slice = MAX_EXPERTS_PER_SLICE
@@ -566,8 +616,9 @@ def __init__(self,
                  num_total_experts: int,
                  experts_min: int = 0,
                  experts_max: int = 8,
-                 dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None):
-        super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn)
+                 dispatch_fn: Callable[[torch.Tensor], torch.Tensor] = None,
+                 bias=False):
+        super().__init__(global_num_experts, num_total_experts, experts_min, experts_max, dispatch_fn, bias)
         self.w13_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
         self.w2_list = torch.nn.ModuleList([MoeMatmul() for _ in range(num_total_experts)])
 
@@ -580,31 +631,64 @@ def forward(self, hidden_states, expert_routing_table, router_weights, permuted_
         w2_list = [self.w2_list[i].weight.squeeze() for i in experts_range]
 
         if self.moe_n_slice == 1:
-            return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
-                                                    expert_routing_table=expert_routing_table,
-                                                    router_weights=router_weights,
-                                                    w12=w1_list,
-                                                    w3=w2_list,
-                                                    permuted_weights=permuted_weights,
-                                                    activation=activation,
-                                                    experts_min=self.experts_min,
-                                                    experts_max=self.experts_max,
-                                                    **kwargs)
+            if self.bias is True and self.w13_list[0].bias is not None \
+                    and self.w2_list[0].bias is not None:
+                w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
+                w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
+                return torch.ops.hpu.mixture_of_experts.bias_fused_weights(hidden_states=hidden_states,
+                                                                           expert_routing_table=expert_routing_table,
+                                                                           router_weights=router_weights,
+                                                                           w12=w1_list,
+                                                                           w3=w2_list,
+                                                                           w12_bias=w1_bias_list,
+                                                                           w3_bias=w2_bias_list,
+                                                                           permuted_weights=permuted_weights,
+                                                                           experts_min=self.experts_min,
+                                                                           experts_max=self.experts_max)
+            else:
+                return torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
+                                                        expert_routing_table=expert_routing_table,
+                                                        router_weights=router_weights,
+                                                        w12=w1_list,
+                                                        w3=w2_list,
+                                                        permuted_weights=permuted_weights,
+                                                        activation=activation,
+                                                        experts_min=self.experts_min,
+                                                        experts_max=self.experts_max,
+                                                        **kwargs)
 
         for i in range(self.moe_n_slice):
             w1_list_slice = w1_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
             w2_list_slice = w2_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
             min_expert = self.experts_min + i * self.num_expert_per_group
             max_expert = min_expert + self.num_expert_per_group - 1
-            slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
-                                                                         expert_routing_table=expert_routing_table,
-                                                                         router_weights=router_weights,
-                                                                         w12=w1_list_slice,
-                                                                         w3=w2_list_slice,
-                                                                         permuted_weights=permuted_weights,
-                                                                         activation=activation,
-                                                                         experts_min=min_expert,
-                                                                         experts_max=max_expert,
-                                                                         **kwargs)
+            if self.bias is True and self.w13_list[0].bias is not None \
+                    and self.w2_list[0].bias is not None:
+                w1_bias_list = [self.w13_list[i].bias.squeeze() for i in experts_range]
+                w2_bias_list = [self.w2_list[i].bias.squeeze() for i in experts_range]
+                w1_bias_list_slice = w1_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
+                w2_bias_list_slice = w2_bias_list[i * self.num_expert_per_group:(i + 1) * self.num_expert_per_group]
+                slice_final_hidden_states = torch.ops.hpu.mixture_of_experts.bias_fused_weights(
+                    hidden_states=hidden_states,
+                    expert_routing_table=expert_routing_table,
+                    router_weights=router_weights,
+                    w12=w1_list_slice,
+                    w3=w2_list_slice,
+                    w12_bias=w1_bias_list_slice,
+                    w3_bias=w2_bias_list_slice,
+                    permuted_weights=permuted_weights,
+                    experts_min=min_expert,
+                    experts_max=max_expert)
+            else:
+                slice_final_hidden_states = torch.ops.hpu.mixture_of_experts(hidden_states=hidden_states,
+                                                                             expert_routing_table=expert_routing_table,
+                                                                             router_weights=router_weights,
+                                                                             w12=w1_list_slice,
+                                                                             w3=w2_list_slice,
+                                                                             permuted_weights=permuted_weights,
+                                                                             activation=activation,
+                                                                             experts_min=min_expert,
+                                                                             experts_max=max_expert,
+                                                                             **kwargs)
             if i == 0:
                 final_hidden_states = slice_final_hidden_states
             else:
diff --git a/vllm_gaudi/extension/utils.py b/vllm_gaudi/extension/utils.py
index ede1c10dd..182cb5654 100644
--- a/vllm_gaudi/extension/utils.py
+++ b/vllm_gaudi/extension/utils.py
@@ -157,14 +157,16 @@ def forward(
         valid_sequence_lengths,
         padding_side="left",
         window_size=None,
+        sinks=None,
     ):
         if window_size is not None:
             return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale,
                                                 softmax_mode, recompute_mode, valid_sequence_lengths, padding_side,
                                                 False, False,
-                                                window_size)
+                                                window_size, sinks)
         else:
             return self._hpu_kernel_fsdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, softmax_mode,
-                                                recompute_mode, valid_sequence_lengths, padding_side)
+                                                recompute_mode, valid_sequence_lengths, padding_side, False, False,
+                                                (-1, -1), sinks)
 
 
 def pad_list(input, target_len, val_generator):
diff --git a/vllm_gaudi/ops/hpu_fused_moe.py b/vllm_gaudi/ops/hpu_fused_moe.py
index df860578a..b489b3747 100644
--- a/vllm_gaudi/ops/hpu_fused_moe.py
+++ b/vllm_gaudi/ops/hpu_fused_moe.py
@@ -3,6 +3,7 @@
 
 import torch
 import vllm
+from vllm.config import get_current_vllm_config
 from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant
 from vllm.model_executor.layers.fused_moe.fused_moe import GroupedTopk
 from vllm.model_executor.layers.fused_moe.fused_moe_router import FusedMoERouter
@@ -90,12 +91,18 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.use_dispatch_fn = get_config().use_dispatch_fn
         torch.hpu.synchronize()
+        vllm_config = get_current_vllm_config()
+        self.model_type = None
+        if vllm_config is not None and vllm_config.model_config is not None \
+                and vllm_config.model_config.hf_config is not None:
+            self.model_type = vllm_config.model_config.hf_config.model_type
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         super().process_weights_after_loading(layer)
         # custom handling for HPU
         num_experts = layer.local_num_experts
         ep_shift = layer.ep_rank * num_experts
+        has_bias = hasattr(layer, 'w13_bias') and hasattr(layer, 'w2_bias')
 
         experts_min, experts_max = ep_shift, num_experts + ep_shift - 1
@@ -104,17 +111,15 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         else:
             dispatch_fn = None
 
-        layer.moe_op = VllmMixtureOfExpertsOp(
-            layer.global_num_experts,
-            num_experts,
-            experts_min,
-            experts_max,
-            dispatch_fn,
-        )
+        layer.moe_op = VllmMixtureOfExpertsOp(layer.global_num_experts, num_experts, experts_min, experts_max,
+                                              dispatch_fn, has_bias)
 
         for expert_id in range(layer.local_num_experts):
             layer.moe_op.w13_list[expert_id].set_weight(layer.w13_weight.data[expert_id])
             layer.moe_op.w2_list[expert_id].set_weight(layer.w2_weight.data[expert_id])
+            if has_bias:
+                layer.moe_op.w13_list[expert_id].set_bias(layer.w13_bias.data[expert_id])
+                layer.moe_op.w2_list[expert_id].set_bias(layer.w2_bias.data[expert_id])
 
     def forward_oot(
         self,
@@ -130,9 +135,13 @@ def forward_oot(
             topk_weights, topk_ids = layer.router.select_experts(hidden_states=x, router_logits=router_logits)
         else:
             import torch.nn.functional as F
-            topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
-            topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
-            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
+            if self.model_type is not None and self.model_type in ["gpt_oss"]:
+                topk_weights, topk_ids = torch.topk(router_logits, layer.top_k, dim=-1)
+                topk_weights = F.softmax(topk_weights, dim=-1, dtype=torch.float32)
+            else:
+                topk_weights = F.softmax(router_logits, dim=1, dtype=torch.float32)
+                topk_weights, topk_ids = torch.topk(topk_weights, layer.top_k, dim=-1)
+                topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
         topk_weights = topk_weights.to(x.dtype)
 
         if not layer.use_grouped_topk:
@@ -154,6 +163,15 @@ def forward_oot(
         topk_ids = topk_ids.view(-1, topk_ids.shape[-1])
         topk_weights = topk_weights.view(-1, topk_weights.shape[-1])
 
+        if self.model_type in ["gpt_oss"]:
+            return layer.moe_op(
+                x,
+                topk_ids.to(torch.int64),
+                topk_weights.to(x.dtype),
+                permuted_weights=True,
+                activation=layer.activation,
+            ).view(*input_shape)
+
         output = layer.moe_op(
             x,
             topk_ids,
diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 7120c0dc3..ceb5c06bb 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -2178,7 +2178,7 @@ def _create_decode_input_data(self,
                                   )
         if self.interleaved_sliding_window and self.sliding_window is not None and self.sliding_window > 0:
-            sliding_block_size = (self.sliding_window // self.block_size)
+            sliding_block_size = (self.sliding_window // self.block_size) + 1
             window_block_tables = [block_table[-sliding_block_size:] for block_table in block_tables_list]
             window_block_list, window_block_groups, window_block_usage = \
                 self.get_habana_paged_attn_buffers(
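
Note on the sink handling in the ops.py changes: _naive_prompt_attention appends the per-head sink as an extra logit and drops that column after softmax, while pipelined_pa keeps only the key logits and folds exp(sink - max) into the softmax denominator. A minimal standalone sketch (illustrative only, not part of the patch; tensor names and shapes are hypothetical) showing that the two formulations yield the same attention weights:

# Illustrative sketch of the two equivalent sink formulations used above.
import torch

torch.manual_seed(0)
scores = torch.randn(1, 4, 3, 8)        # [batch, heads, queries, keys]
sink = torch.randn(4).view(1, 4, 1, 1)  # one learned sink logit per head

# (a) _naive_prompt_attention style: append the sink as an extra logit,
#     softmax over keys + sink, then drop the sink column.
combined = torch.cat([scores, sink.expand(1, 4, 3, 1)], dim=-1)
combined = combined - combined.max(dim=-1, keepdim=True).values
probs_a = torch.softmax(combined, dim=-1)[..., :-1]

# (b) pipelined_pa style: keep only the key logits, but include the sink in the
#     running max and add exp(sink - max) to the softmax denominator.
block_max = torch.maximum(scores.amax(dim=-1, keepdim=True), sink)
exp_scores = (scores - block_max).exp()
denom = exp_scores.sum(dim=-1, keepdim=True) + (sink - block_max).exp()
probs_b = exp_scores / denom

print(torch.allclose(probs_a, probs_b, atol=1e-6))  # True: identical attention weights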