Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def fused_sigmoid_gating_delta_rule_update_kernel(
tl.store(p_h0, b_h.to(p_h0.dtype.element_ty), mask=mask_h)


@input_guard
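# NOTE: @input_guard was removed because, in the optimized path
# (causal_conv1d_update_split_qkv), q/k/v are already contiguous.
# NOTE(review): any caller that still builds q/k/v with torch.split along the
# last dim passes strided views rather than contiguous tensors — confirm this
# function handles such inputs correctly without the guard.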
def fused_sigmoid_gating_delta_rule_update(
A_log: torch.Tensor,
a: torch.Tensor,
Expand Down
49 changes: 30 additions & 19 deletions python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,25 +274,36 @@ def forward_decode(
query_start_loc = self.forward_metadata.query_start_loc
cache_indices = self.forward_metadata.mamba_cache_indices

mixed_qkv = causal_conv1d_update(
mixed_qkv,
conv_states,
conv_weights,
bias,
activation,
conv_state_indices=cache_indices,
)

query, key, value = torch.split(
mixed_qkv,
[
key_dim // attn_tp_size,
key_dim // attn_tp_size,
value_dim // attn_tp_size,
],
dim=-1,
)
# Reshape from [l, h*d] to [1, l, h, d]
if _is_hip:
query, key, value = causal_conv1d_update_split_qkv(
mixed_qkv,
conv_states,
conv_weights,
key_dim=key_dim // attn_tp_size,
value_dim=value_dim // attn_tp_size,
bias=bias,
activation=activation,
conv_state_indices=cache_indices,
)
else:
mixed_qkv = causal_conv1d_update(
mixed_qkv,
conv_states,
conv_weights,
bias,
activation,
conv_state_indices=cache_indices,
)
query, key, value = torch.split(
mixed_qkv,
[
key_dim // attn_tp_size,
key_dim // attn_tp_size,
value_dim // attn_tp_size,
],
dim=-1,
)

seq_len = query.shape[0]
num_heads = query.shape[1] // head_k_dim
query = query.view(1, seq_len, num_heads, head_k_dim)
Expand Down
26 changes: 8 additions & 18 deletions python/sglang/srt/layers/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,22 +636,6 @@ def fused_sigmoid_mul(x: torch.Tensor, y: torch.Tensor, out: torch.Tensor = None
return out


@triton.autotune(
configs=[
triton.Config({"BLOCK_N": 128, "BLOCK_H": 128}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_H": 256}, num_warps=4, num_stages=2),
triton.Config({"BLOCK_N": 64, "BLOCK_H": 512}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_H": 512}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 32, "BLOCK_H": 1024}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_H": 1024}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 16, "BLOCK_H": 2048}, num_warps=8, num_stages=2),

triton.Config({"BLOCK_N": 8, "BLOCK_H": 2048}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 4, "BLOCK_H": 2048}, num_warps=8, num_stages=2),
triton.Config({"BLOCK_N": 8, "BLOCK_H": 1024}, num_warps=4, num_stages=2),
],
key=["N", "H"],
)
@triton.jit
def _fused_sigmoid_mul_broadcast_kernel(
X, Y, OUT, N, H, stride_y,
Expand Down Expand Up @@ -702,10 +686,16 @@ def fused_sigmoid_mul_broadcast(x: torch.Tensor, y: torch.Tensor, out: torch.Ten
if out is None:
out = torch.empty_like(y)

def grid(META):
return (triton.cdiv(N, META["BLOCK_N"]), triton.cdiv(H, META["BLOCK_H"]))
# Fixed configuration (faster than autotune in benchmarks)
BLOCK_N = 32
BLOCK_H = 1024
grid = (triton.cdiv(N, BLOCK_N), triton.cdiv(H, BLOCK_H))

_fused_sigmoid_mul_broadcast_kernel[grid](
x, y, out, N, H, y.stride(0),
BLOCK_N=BLOCK_N,
BLOCK_H=BLOCK_H,
num_warps=8,
num_stages=2,
)
return out
2 changes: 1 addition & 1 deletion python/sglang/srt/models/qwen2_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def _forward_shared_experts(self, hidden_states: torch.Tensor):
if self.shared_expert_gate is not None:
if _is_hip:
gate_output = self.shared_expert_gate(hidden_states)
shared_output = fused_sigmoid_mul_broadcast(gate_output, shared_output)
fused_sigmoid_mul_broadcast(gate_output, shared_output, out=shared_output)
else:
shared_output = (
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
Expand Down
3 changes: 2 additions & 1 deletion python/sglang/srt/models/qwen3_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def _forward_input_proj(self, hidden_states: torch.Tensor):
else:
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
projected_states_ba, _ = self.in_proj_ba(hidden_states)

return projected_states_qkvz, projected_states_ba

def forward(
Expand Down Expand Up @@ -663,7 +664,7 @@ def self_attention(

if self.attn_output_gate:
if _is_hip:
attn_output = fused_sigmoid_mul(gate, attn_output)
fused_sigmoid_mul(gate, attn_output, out=attn_output)
else:
attn_output = torch.sigmoid(gate) * attn_output

Expand Down