6 changes: 5 additions & 1 deletion models/deepseek/v4/deepseek_v4_decode_hc_post.py
@@ -94,7 +94,11 @@ def golden_deepseek_v4_decode_hc_post(tensors):
for in_h in range(HC_MULT):
y_row = y_row + residual[:, :, in_h, :] * comb[:, :, in_h, out_h:out_h + 1]
y_fp32[:, :, out_h, :] = y_row
y = y_fp32.to(torch.bfloat16)
def _to_device_bf16(value):
rounded = (value.contiguous().view(torch.int32) + 0x8000) & -0x10000
return rounded.view(torch.float32).to(torch.bfloat16)
Comment on lines +97 to +99
Severity: medium

The _to_device_bf16 helper function is duplicated in models/deepseek/v4/deepseek_v4_decode_sparse_attn.py. Consider refactoring this into a shared utility module to adhere to DRY principles and ensure consistent rounding behavior across the project.


y = _to_device_bf16(y_fp32)

tensors["y"][:] = y

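The review comment above asks for _to_device_bf16 to live in one place instead of being copied into deepseek_v4_decode_sparse_attn.py. A minimal sketch of such a shared helper follows; the module path models/deepseek/v4/bf16_rounding.py is an assumption for illustration, not something this PR adds. The bit trick adds 0x8000 (half a BF16 ulp) to the FP32 bit pattern and clears the low 16 bits, so the final cast to bfloat16 is exact; on ties this rounds the magnitude up rather than using torch's default round-to-nearest-even.

# models/deepseek/v4/bf16_rounding.py -- hypothetical shared module (path assumed)
import torch

def to_device_bf16(value: torch.Tensor) -> torch.Tensor:
    # Reinterpret the FP32 bits as int32, add half a BF16 ulp, and truncate the
    # low 16 bits so the subsequent bfloat16 cast is lossless.
    rounded = (value.contiguous().view(torch.int32) + 0x8000) & -0x10000
    return rounded.view(torch.float32).to(torch.bfloat16)

Both golden functions could then import this helper, which also guarantees the two kernels keep identical rounding behavior.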
90 changes: 83 additions & 7 deletions models/deepseek/v4/deepseek_v4_decode_qkv_proj_rope.py
@@ -329,10 +329,63 @@ def rms_norm(x, gamma, eps=EPS):
inv = torch.rsqrt(x.square().mean(-1, keepdim=True) + eps)
return x * inv * gamma

def matmul_bf16_input_fp32(a, b):
def matmul_wqa_tiled(a, b):
a_fp32 = a.to(torch.bfloat16).float()
b_fp32 = b.to(torch.bfloat16).float()
return torch.matmul(a_fp32, b_fp32).float()
out = torch.empty(T, Q_LORA, dtype=torch.float32)
for qb in range(Q_BLOCKS):
q0 = qb * Q_LORA_CHUNK
acc = torch.matmul(a_fp32[:, :D_CHUNK], b_fp32[:D_CHUNK, q0:q0 + Q_LORA_CHUNK])
for db in range(1, D_BLOCKS):
d0 = db * D_CHUNK
acc = acc + torch.matmul(
a_fp32[:, d0:d0 + D_CHUNK],
b_fp32[d0:d0 + D_CHUNK, q0:q0 + Q_LORA_CHUNK],
)
out[:, q0:q0 + Q_LORA_CHUNK] = acc
return out

def matmul_wkv_tiled(a, b):
a_fp32 = a.to(torch.bfloat16).float()
b_fp32 = b.to(torch.bfloat16).float()
out = torch.empty(T, HEAD_DIM, dtype=torch.float32)
for kb in range(KV_BLOCKS):
k0 = kb * KV_CHUNK
acc = torch.matmul(a_fp32[:, :D_CHUNK], b_fp32[:D_CHUNK, k0:k0 + KV_CHUNK])
for db in range(1, D_BLOCKS):
d0 = db * D_CHUNK
acc = acc + torch.matmul(
a_fp32[:, d0:d0 + D_CHUNK],
b_fp32[d0:d0 + D_CHUNK, k0:k0 + KV_CHUNK],
)
out[:, k0:k0 + KV_CHUNK] = acc
return out

def rms_norm_q_tiled(x, gamma):
sq_sum = torch.zeros(T, 1, dtype=torch.float32)
for qb in range(Q_BLOCKS):
q0 = qb * Q_LORA_CHUNK
chunk = x[:, q0:q0 + Q_LORA_CHUNK]
sq_sum = sq_sum + chunk.square().sum(dim=-1, keepdim=True)
inv = torch.rsqrt(sq_sum * (1.0 / Q_LORA) + EPS)
out = torch.empty_like(x)
for qb in range(Q_BLOCKS):
q0 = qb * Q_LORA_CHUNK
out[:, q0:q0 + Q_LORA_CHUNK] = x[:, q0:q0 + Q_LORA_CHUNK] * inv * gamma[q0:q0 + Q_LORA_CHUNK]
return out

def rms_norm_kv_tiled(x, gamma):
sq_sum = torch.zeros(T, 1, dtype=torch.float32)
for kb in range(KV_BLOCKS):
k0 = kb * KV_CHUNK
chunk = x[:, k0:k0 + KV_CHUNK]
sq_sum = sq_sum + chunk.square().sum(dim=-1, keepdim=True)
inv = torch.rsqrt(sq_sum * (1.0 / HEAD_DIM) + EPS)
out = torch.empty_like(x)
for kb in range(KV_BLOCKS):
k0 = kb * KV_CHUNK
out[:, k0:k0 + KV_CHUNK] = x[:, k0:k0 + KV_CHUNK] * inv * gamma[k0:k0 + KV_CHUNK]
return out

def apply_rope(x_rope, cos, sin):
# x_rope: [T, ..., ROPE_DIM] using lo/hi half split.
@@ -356,7 +409,7 @@ def apply_rope(x_rope, cos, sin):
token_x = rms_norm(x.view(T, D), norm_w) # [T, D]

# Q path
qr_out = rms_norm(matmul_bf16_input_fp32(token_x, wq_a), gamma_cq) # [T, Q_LORA]
qr_out = rms_norm_q_tiled(matmul_wqa_tiled(token_x, wq_a), gamma_cq) # [T, Q_LORA]
# W8A8C16: wq_b W8 per-output-channel int8; qr_out A8 per-token int8.
# flash: also quantizes wq_a/wkv to fp8 (default Linear dtype).
qr_out_bf16 = qr_out.to(torch.bfloat16)
Expand All @@ -370,7 +423,7 @@ def apply_rope(x_rope, cos, sin):
q_out = torch.cat([q_nope, q_rope], dim=-1)

# KV path
kv_full = rms_norm(matmul_bf16_input_fp32(token_x, wkv), gamma_ckv) # [T, HEAD_DIM]
kv_full = rms_norm_kv_tiled(matmul_wkv_tiled(token_x, wkv), gamma_ckv) # [T, HEAD_DIM]
kv_nope = kv_full[..., :NOPE_DIM]
kv_rope_in = kv_full[..., NOPE_DIM:].unsqueeze(1) # add a pseudo head dim
kv_rope = apply_rope(kv_rope_in, rope_cos, rope_sin).squeeze(1)
@@ -441,6 +494,7 @@ def init_gamma_ckv():

if __name__ == "__main__":
import argparse
import torch
from golden import RunConfig, run_jit

def int8_lsb_compare(actual, expected, actual_outputs, expected_outputs, inputs, rtol, atol):
@@ -451,22 +505,44 @@ def int8_lsb_compare(actual, expected, actual_outputs, expected_outputs, inputs,
return True, ""
return False, "max INT8 diff > 1"

def bf16_outlier_compare(actual, expected, actual_outputs, expected_outputs, inputs, rtol, atol):
import torch

close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
mismatch = int((~close).sum().item())
max_mismatch = int(actual.numel() * 0.005)
if mismatch <= max_mismatch:
return True, f"mismatch={mismatch}/{actual.numel()} <= {max_mismatch}"

diff = (actual.float() - expected.float()).abs()
max_idx = int(diff.flatten().argmax().item())
return False, (
f" BF16 outlier budget exceeded: mismatch={mismatch}/{actual.numel()} "
f"limit={max_mismatch} rtol={rtol} atol={atol}\n"
f" max_abs={diff.max().item():.8g} idx={max_idx} "
f"actual={actual.flatten()[max_idx].item()} expected={expected.flatten()[max_idx].item()}"
)
Comment on lines +508 to +524
Severity: medium

The bf16_outlier_compare function is duplicated in models/deepseek/v4/deepseek_v4_decode_swa.py. To improve maintainability, this custom comparator should be moved to a shared validation utility (e.g., in golden/validation.py).


parser = argparse.ArgumentParser()
parser.add_argument("-p", "--platform", type=str, default="a2a3",
choices=["a2a3", "a2a3sim", "a5", "a5sim"])
parser.add_argument("-d", "--device", type=int, default=0)
parser.add_argument("--seed", type=int, default=20260508)
parser.add_argument("--runtime-profiling", action="store_true", default=False)
args = parser.parse_args()

torch.manual_seed(args.seed)

result = run_jit(
fn=deepseek_v4_decode_qkv_proj_rope_test,
specs=build_tensor_specs(),
golden_fn=golden_deepseek_v4_decode_qkv_proj_rope,
config=RunConfig(
# W8A8C16 q_proj adds INT8 quant/dequant round-off before per-head RMSNorm.
rtol=5e-3,
atol=5e-3,
compare_fn={"qr": int8_lsb_compare},
# Allow a small BF16 tail: at most 0.5% elements may exceed the tolerance.
rtol=2e-3,
atol=2e-3,
compare_fn={"q": bf16_outlier_compare, "kv": bf16_outlier_compare, "qr": int8_lsb_compare},
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,
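The comment on lines +508 to +524 makes the same duplication point for bf16_outlier_compare, which is repeated verbatim in deepseek_v4_decode_swa.py below. Here is a sketch of the shared comparator it proposes; the golden/validation.py location and the max_mismatch_frac keyword are assumptions for illustration, not part of this PR.

# golden/validation.py -- hypothetical shared comparator (location suggested by the review comment)
import torch

def bf16_outlier_compare(actual, expected, actual_outputs, expected_outputs,
                         inputs, rtol, atol, max_mismatch_frac=0.005):
    # Pass when at most max_mismatch_frac of elements fall outside the rtol/atol envelope.
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
    mismatch = int((~close).sum().item())
    budget = int(actual.numel() * max_mismatch_frac)
    if mismatch <= budget:
        return True, f"mismatch={mismatch}/{actual.numel()} <= {budget}"
    diff = (actual.float() - expected.float()).abs()
    max_idx = int(diff.flatten().argmax().item())
    return False, (
        f"BF16 outlier budget exceeded: mismatch={mismatch}/{actual.numel()} "
        f"limit={budget} rtol={rtol} atol={atol} "
        f"max_abs={diff.max().item():.8g} idx={max_idx}"
    )

Because the extra argument defaults to the 0.5% budget used in this PR, both scripts could import it and keep their compare_fn entries unchanged.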
36 changes: 28 additions & 8 deletions models/deepseek/v4/deepseek_v4_decode_sparse_attn.py
@@ -337,6 +337,10 @@ def golden_deepseek_v4_decode_sparse_attn_with_proj(tensors):
wo_a = tensors["wo_a"].float()
wo_b = tensors["wo_b"].float()

def _to_device_bf16(value):
rounded = (value.contiguous().view(torch.int32) + 0x8000) & -0x10000
return rounded.view(torch.float32).to(torch.bfloat16)

o = torch.zeros(T, H, HEAD_DIM)

for b in range(B):
@@ -375,23 +379,39 @@ def golden_deepseek_v4_decode_sparse_attn_with_proj(tensors):
denom = li + torch.exp(attn_sink.unsqueeze(-1) - score_max)
o[b] = oi_num / denom

rope_lo = o[..., NOPE_DIM : NOPE_DIM + HALF_ROPE]
rope_hi = o[..., NOPE_DIM + HALF_ROPE :]
attn_stage = _to_device_bf16(o)
rope_lo = attn_stage[..., NOPE_DIM : NOPE_DIM + HALF_ROPE].float()
rope_hi = attn_stage[..., NOPE_DIM + HALF_ROPE :].float()
cos_lo = cos[:, :HALF_ROPE].unsqueeze(1)
cos_hi = cos[:, HALF_ROPE:].unsqueeze(1)
sin_lo = sin[:, :HALF_ROPE].unsqueeze(1)
sin_hi = sin[:, HALF_ROPE:].unsqueeze(1)
inv_lo = rope_lo * cos_lo + rope_hi * sin_lo
inv_hi = rope_hi * cos_hi - rope_lo * sin_hi
o = torch.cat([o[..., :NOPE_DIM], inv_lo, inv_hi], dim=-1).to(torch.bfloat16)
o = torch.cat(
[attn_stage[..., :NOPE_DIM], _to_device_bf16(inv_lo), _to_device_bf16(inv_hi)],
dim=-1,
)

seq_per_batch = T // B
o_model = o.float().view(B, seq_per_batch, O_GROUPS, O_GROUP_IN)
o_r = torch.einsum("bsgd,grd->bsgr", o_model, wo_a)
o_r = o_r.to(torch.bfloat16).float()
out = o_r.flatten(2).view(T, O_GROUPS * O_LORA) @ wo_b.T

tensors["attn_out"][:] = out.to(torch.bfloat16)
o_r = torch.zeros(B, seq_per_batch, O_GROUPS, O_LORA, dtype=torch.float32)
for g in range(O_GROUPS):
for n0 in range(0, O_LORA, A_N_CHUNK):
acc = o_model[:, :, g, 0:A_K_CHUNK] @ wo_a[g, n0:n0 + A_N_CHUNK, 0:A_K_CHUNK].T
for k0 in range(A_K_CHUNK, O_GROUP_IN, A_K_CHUNK):
acc += o_model[:, :, g, k0:k0 + A_K_CHUNK] @ wo_a[g, n0:n0 + A_N_CHUNK, k0:k0 + A_K_CHUNK].T
o_r[:, :, g, n0:n0 + A_N_CHUNK] = acc
o_r_flat = _to_device_bf16(o_r).float().flatten(2).view(T, O_GROUPS * O_LORA)

out = torch.zeros(T, D, dtype=torch.float32)
for n0 in range(0, D, B_N_CHUNK):
acc = o_r_flat[:, 0:B_K_CHUNK] @ wo_b[n0:n0 + B_N_CHUNK, 0:B_K_CHUNK].T
for k0 in range(B_K_CHUNK, O_GROUPS * O_LORA, B_K_CHUNK):
acc += o_r_flat[:, k0:k0 + B_K_CHUNK] @ wo_b[n0:n0 + B_N_CHUNK, k0:k0 + B_K_CHUNK].T
out[:, n0:n0 + B_N_CHUNK] = acc

tensors["attn_out"][:] = _to_device_bf16(out)


def build_tensor_specs(compress_ratio: int = DEFAULT_COMPRESS_RATIO):
34 changes: 29 additions & 5 deletions models/deepseek/v4/deepseek_v4_decode_swa.py
@@ -344,11 +344,11 @@ def init_hc_attn_base():
def init_attn_norm_w():
return torch.ones(D)
def init_wq_a():
return torch.randn(D, Q_LORA) / D ** 0.5
return (torch.randn(D, Q_LORA) - 0.5) / D ** 0.5
def init_wq_b():
return torch.randn(Q_LORA, H * HEAD_DIM) / Q_LORA ** 0.5
return (torch.randn(Q_LORA, H * HEAD_DIM) - 0.5) / Q_LORA ** 0.5
def init_wkv():
return torch.randn(D, HEAD_DIM) / D ** 0.5
return (torch.randn(D, HEAD_DIM) - 0.5) / D ** 0.5
def init_gamma_cq():
return torch.ones(Q_LORA)
def init_gamma_ckv():
@@ -407,23 +407,47 @@ def init_wo_b():

if __name__ == "__main__":
import argparse
import torch
from golden import RunConfig, run_jit

def bf16_outlier_compare(actual, expected, actual_outputs, expected_outputs, inputs, rtol, atol):
import torch

close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
mismatch = int((~close).sum().item())
max_mismatch = int(actual.numel() * 0.005)
if mismatch <= max_mismatch:
return True, f"mismatch={mismatch}/{actual.numel()} <= {max_mismatch}"

diff = (actual.float() - expected.float()).abs()
max_idx = int(diff.flatten().argmax().item())
return False, (
f" BF16 outlier budget exceeded: mismatch={mismatch}/{actual.numel()} "
f"limit={max_mismatch} rtol={rtol} atol={atol}\n"
f" max_abs={diff.max().item():.8g} idx={max_idx} "
f"actual={actual.flatten()[max_idx].item()} expected={expected.flatten()[max_idx].item()}"
)

parser = argparse.ArgumentParser()
parser.add_argument("-p", "--platform", type=str, default="a2a3",
choices=["a2a3", "a2a3sim", "a5", "a5sim"])
parser.add_argument("-d", "--device", type=int, default=0)
parser.add_argument("--seed", type=int, default=20260508)
parser.add_argument("--runtime-profiling", action="store_true", default=False)
args = parser.parse_args()

torch.manual_seed(args.seed)

result = run_jit(
fn=deepseek_v4_decode_swa,
specs=build_tensor_specs(),
golden_fn=golden_deepseek_v4_decode_swa,
config=RunConfig(
# qkv_proj_rope now uses W8A8C16 q_proj; SWA carries that BF16 drift through attention/o_proj.
rtol=6e-3,
atol=6e-3,
# Allow a small BF16 tail: at most 0.5% elements may exceed the tolerance.
rtol=3e-3,
atol=3e-3,
compare_fn={"x_out": bf16_outlier_compare},
compile=dict(dump_passes=True),
runtime=dict(
platform=args.platform,