diff --git a/models/deepseek/v4/decode_sparse_attn.py b/models/deepseek/v4/decode_sparse_attn.py index 7c9c2ea..7a69805 100644 --- a/models/deepseek/v4/decode_sparse_attn.py +++ b/models/deepseek/v4/decode_sparse_attn.py @@ -659,11 +659,32 @@ def golden_sparse_attn(tensors): kv_b = torch.stack(gathered, dim=0) q_t = q[t] - scores = (q_t @ kv_b.T) * SOFTMAX_SCALE - score_max = scores.max(dim=-1, keepdim=True).values - exp_scores = torch.exp(scores - score_max) - oi_num = exp_scores @ kv_b - li = exp_scores.sum(dim=-1, keepdim=True) + + block_mi = [] + block_li = [] + block_oi = [] + for tile_start in range(0, kv_b.shape[0], SPARSE_ATTN_TILE): + kv_tile = kv_b[tile_start:tile_start + SPARSE_ATTN_TILE] + scores = (q_t @ kv_tile.T) * SOFTMAX_SCALE + mi = scores.max(dim=-1, keepdim=True).values + exp_scores = torch.exp(scores - mi).to(torch.bfloat16).float() + li = exp_scores.sum(dim=-1, keepdim=True) + oi = exp_scores @ kv_tile.to(torch.bfloat16).float() + block_mi.append(mi) + block_li.append(li) + block_oi.append(oi) + + score_max = block_mi[0] + li = block_li[0] + oi_num = block_oi[0] + for mi_cur, li_cur, oi_cur in zip(block_mi[1:], block_li[1:], block_oi[1:]): + score_max_new = torch.maximum(score_max, mi_cur) + alpha = torch.exp(score_max - score_max_new) + beta = torch.exp(mi_cur - score_max_new) + li = alpha * li + beta * li_cur + oi_num = alpha * oi_num + beta * oi_cur + score_max = score_max_new + denom = li + torch.exp(attn_sink.unsqueeze(-1) - score_max) o[t] = oi_num / denom diff --git a/models/deepseek/v4/qkv_proj_rope.py b/models/deepseek/v4/qkv_proj_rope.py index f2cc176..d38d33c 100644 --- a/models/deepseek/v4/qkv_proj_rope.py +++ b/models/deepseek/v4/qkv_proj_rope.py @@ -285,7 +285,7 @@ def qkv_proj_rope( d0 = h0 + db * HEAD_CHUNK q_head_chunk = q_proj_fp32[:, d0 : d0 + HEAD_CHUNK] q_head_sq_sum = pl.add(q_head_sq_sum, pl.reshape(pl.row_sum(pl.mul(q_head_chunk, q_head_chunk)), [1, T])) - q_head_inv_rms = pl.rsqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS)) + q_head_inv_rms = pl.recip(pl.sqrt(pl.add(pl.mul(q_head_sq_sum, 1.0 / HEAD_DIM), EPS))) q_head_inv_rms_t = pl.reshape(q_head_inv_rms, [T, 1]) q_head_inv_rms_all[h : h + 1, :] = q_head_inv_rms