Skip to content

Commit 69a54bb

Browse files
committed
avoid add twice
1 parent b98a8a1 commit 69a54bb

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

python/sglang/srt/layers/attention/triton_ops/prefill_attention.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -112,12 +112,17 @@ def _fwd_kernel(
112112
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
113113
qk += tl.dot(q, k)
114114
qk *= sm_scale
115-
qk += tl.where(
116-
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
117-
)
115+
118116
if IS_CAUSAL:
119117
qk += tl.where(
120-
offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")
118+
(start_n + offs_n[None, :] < cur_batch_seq_len)
119+
& (offs_m[:, None] >= (start_n + offs_n[None, :])),
120+
0,
121+
float("-inf"),
122+
)
123+
else:
124+
qk += tl.where(
125+
(start_n + offs_n[None, :]) < cur_batch_seq_len, 0, float("-inf")
121126
)
122127

123128
# -- compute m_ij, p, l_ij

0 commit comments

Comments
 (0)