diff --git a/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py b/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py index feeb7c31c69..008196eddce 100644 --- a/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +++ b/python/sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py @@ -162,7 +162,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel( tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h) -@input_guard +# Note: @input_guard is removed because in the optimized path +# (causal_conv1d_update_split_qkv), q/k/v are already contiguous. def fused_sigmoid_gating_delta_rule_update( A_log: torch.Tensor, a: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index db0dd13b109..2ba798f1c1b 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -274,25 +274,36 @@ def forward_decode( query_start_loc = self.forward_metadata.query_start_loc cache_indices = self.forward_metadata.mamba_cache_indices - mixed_qkv = causal_conv1d_update( - mixed_qkv, - conv_states, - conv_weights, - bias, - activation, - conv_state_indices=cache_indices, - ) - - query, key, value = torch.split( - mixed_qkv, - [ - key_dim // attn_tp_size, - key_dim // attn_tp_size, - value_dim // attn_tp_size, - ], - dim=-1, - ) - # Reshape from [l, h*d] to [1, l, h, d] + if _is_hip: + query, key, value = causal_conv1d_update_split_qkv( + mixed_qkv, + conv_states, + conv_weights, + key_dim=key_dim // attn_tp_size, + value_dim=value_dim // attn_tp_size, + bias=bias, + activation=activation, + conv_state_indices=cache_indices, + ) + else: + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + conv_weights, + bias, + activation, + conv_state_indices=cache_indices, + ) + query, key, value = torch.split( + mixed_qkv, + [ + key_dim // attn_tp_size, + key_dim // attn_tp_size, + value_dim // attn_tp_size, + ], + dim=-1, + ) + seq_len = query.shape[0] num_heads = query.shape[1] // head_k_dim query = query.view(1, seq_len, num_heads, head_k_dim) diff --git a/python/sglang/srt/layers/elementwise.py b/python/sglang/srt/layers/elementwise.py index 4cbe17b36e0..35422247bbd 100644 --- a/python/sglang/srt/layers/elementwise.py +++ b/python/sglang/srt/layers/elementwise.py @@ -636,22 +636,6 @@ def fused_sigmoid_mul(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor = None return out -@triton.autotune( - configs=[ - triton.Config({"BLOCK_N": 128, "BLOCK_H": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_N": 64, "BLOCK_H": 256}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_N": 64, "BLOCK_H": 512}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 32, "BLOCK_H": 512}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 32, "BLOCK_H": 1024}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 16, "BLOCK_H": 1024}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 16, "BLOCK_H": 2048}, num_warps=8, num_stages=2), - - triton.Config({"BLOCK_N": 8, "BLOCK_H": 2048}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 4, "BLOCK_H": 2048}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_N": 8, "BLOCK_H": 1024}, num_warps=4, num_stages=2), - ], - key=["N", "H"], -) @triton.jit def _fused_sigmoid_mul_broadcast_kernel( X, Y, OUT, N, H, stride_y, @@ -702,10 +686,16 @@ def fused_sigmoid_mul_broadcast(x: torch.Tensor, y: torch.Tensor, out: torch.Ten if out is None: out = torch.empty_like(y) - def grid(META): - return (triton.cdiv(N, META["BLOCK_N"]), triton.cdiv(H, META["BLOCK_H"])) + # Fixed configuration (faster than autotune in benchmarks) + BLOCK_N = 32 + BLOCK_H = 1024 + grid = (triton.cdiv(N, BLOCK_N), triton.cdiv(H, BLOCK_H)) _fused_sigmoid_mul_broadcast_kernel[grid]( x, y, out, N, H, y.stride(0), + BLOCK_N=BLOCK_N, + BLOCK_H=BLOCK_H, + num_warps=8, + num_stages=2, ) return out diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index 976a08f50fb..c8d02ba4c0a 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -216,7 +216,7 @@ def _forward_shared_experts(self, hidden_states: torch.Tensor): if self.shared_expert_gate is not None: if _is_hip: gate_output = self.shared_expert_gate(hidden_states) - shared_output = fused_sigmoid_mul_broadcast(gate_output, shared_output) + fused_sigmoid_mul_broadcast(gate_output, shared_output, out=shared_output) else: shared_output = ( F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output diff --git a/python/sglang/srt/models/qwen3_next.py b/python/sglang/srt/models/qwen3_next.py index 7cf88bace72..28f2fe5d68d 100755 --- a/python/sglang/srt/models/qwen3_next.py +++ b/python/sglang/srt/models/qwen3_next.py @@ -309,6 +309,7 @@ def _forward_input_proj(self, hidden_states: torch.Tensor): else: projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states) projected_states_ba, _ = self.in_proj_ba(hidden_states) + return projected_states_qkvz, projected_states_ba def forward( @@ -663,7 +664,7 @@ def self_attention( if self.attn_output_gate: if _is_hip: - attn_output = fused_sigmoid_mul(gate, attn_output) + fused_sigmoid_mul(gate, attn_output, out=attn_output) else: attn_output = torch.sigmoid(gate) * attn_output