146 changes: 146 additions & 0 deletions examples/deepseek_v32/act_quant.py
@@ -1,3 +1,5 @@
import os

import torch
import torch_npu  # side-effect import: registers the NPU backend ("npu" device, torch.npu)
import tilelang
@@ -78,3 +80,147 @@ def act_quant_kernel_(
            T.copy(y_shared, Y[pid_m * blk_m, pid_n * group_size])

    return act_quant_kernel_


# ---------------------------------------------------------------------------
# Pure-PyTorch reference implementation
# ---------------------------------------------------------------------------

def _fast_round_scale_ref(amax: torch.Tensor, fp8_max_inv: float) -> torch.Tensor:
    """Replicate the bit-manipulation fast_round_scale used in the kernel.

    IEEE 754 float32: bit[31]=sign, bits[30:23]=exponent (biased by 127),
    bits[22:0]=mantissa. Extracts floor(log2(val)) from the exponent field,
    adds 1 when the mantissa is non-zero (i.e. val is not an exact power of
    two) to get ceil(log2(val)), then reconstructs the smallest power-of-two
    float >= val via (exp + 127) << 23. Valid only for positive, normal
    float32 inputs, which the 1e-4 clamp on amax guarantees.
    """
    val = (amax * fp8_max_inv).to(torch.float32)
    bits = val.view(torch.int32)
    exp = ((bits >> 23) & 0xFF) - 127   # floor(log2(val)) for normal val
    man = bits & ((1 << 23) - 1)        # 23-bit mantissa field
    log2_ceil = exp + (man != 0).to(torch.int32)
    result_bits = (log2_ceil + 127) << 23   # rebuild a power-of-two float
    return result_bits.view(torch.float32)
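

# A hedged sanity check for the bit trick above (ours, illustrative; not wired
# into run_test, runnable on CPU): the result should match a straightforward
# 2**ceil(log2(amax * fp8_max_inv)) for positive, normal inputs.
def _fast_round_scale_sanity_check():
    amax = torch.tensor([300.0, 448.0, 2.0, 1e-4], dtype=torch.float32)
    fp8_max_inv = 1.0 / 448.0
    fast = _fast_round_scale_ref(amax, fp8_max_inv)
    # Same scaled value, but ceil(log2(.)) computed directly, in float64 so
    # that the ceil is unambiguous near powers of two.
    val = (amax * fp8_max_inv).double()
    slow = torch.exp2(torch.ceil(torch.log2(val))).float()
    torch.testing.assert_close(fast, slow)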


def act_quant_torch_ref(
    x: torch.Tensor,
    group_size: int = 128,
    round_scale: bool = False,
):
    """Reference activation quantization in plain PyTorch (float32 arithmetic).

    Unlike the kernel, y_ref stays in float32 (clamped to the FP8 range)
    rather than being cast to float8_e4m3fn; the test below compares only the
    scales against this reference and dequantizes the kernel's own output.
    """
    fp8_min = -448.0
    fp8_max = 448.0
    fp8_max_inv = 1.0 / fp8_max

    m, n = x.shape
    num_groups = (n + group_size - 1) // group_size

    x_f32 = x.float()
    y_ref = torch.empty((m, n), dtype=torch.float32, device=x.device)
    s_ref = torch.empty((m, num_groups), dtype=torch.float32, device=x.device)

    for j in range(num_groups):
        x_group = x_f32[:, j * group_size : (j + 1) * group_size]
        # Clamp to 1e-4 to avoid division by zero, matching the kernel.
        amax = x_group.abs().amax(dim=1).clamp(min=1e-4)

        if round_scale:
            s_val = _fast_round_scale_ref(amax, fp8_max_inv)
        else:
            s_val = amax * fp8_max_inv

        s_ref[:, j] = s_val
        y_ref[:, j * group_size : (j + 1) * group_size] = torch.clamp(
            x_group / s_val.unsqueeze(1), fp8_min, fp8_max
        )

    return y_ref, s_ref
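
# Worked example (ours, for intuition): if a group's amax is 3.2, the default
# scale is s = 3.2 / 448 ≈ 7.14e-3, so y = x / s maps the group onto
# [-448, 448]; with round_scale=True, s is instead rounded up to the next
# power of two, 2**-7 ≈ 7.81e-3.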


# ---------------------------------------------------------------------------
# Kernel wrapper
# ---------------------------------------------------------------------------

def act_quant(
    x: torch.Tensor,
    group_size: int = 128,
    round_scale: bool = False,
):
    """Quantize activations to FP8 E4M3 with one float32 scale per group of group_size columns."""
    m, n = x.shape
    num_groups = (n + group_size - 1) // group_size

    kernel = act_quant_kernel(
        N=n,
        in_dtype=BF16,
        out_dtype=FP8,
        scale_dtype=FP32,
        round_scale=round_scale,
    )

    y = torch.empty((m, n), dtype=torch.float8_e4m3fn, device=x.device)
    s = torch.empty((m, num_groups), dtype=torch.float32, device=x.device)
    # Depending on the backend, the kernel call may return the outputs itself;
    # otherwise the preallocated y/s buffers are written in place.
    ret = kernel(x, y, s)
    if ret is not None:
        return ret
    return y, s
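
# Example usage (a sketch, assuming an NPU device and the act_quant_kernel
# factory defined earlier in this file):
#
#     x = torch.randn(32, 256, dtype=torch.bfloat16, device="npu")
#     y, s = act_quant(x)  # y: float8_e4m3fn (32, 256), s: float32 (32, 2)
#     x_hat = y.float() * s.repeat_interleave(128, dim=1)  # ≈ x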


# ---------------------------------------------------------------------------
# Precision comparison
# ---------------------------------------------------------------------------

def run_test_case(m: int, n: int, round_scale: bool = False):
    group_size = 128
    # n must be a multiple of group_size so every tile is fully covered by the kernel.
    assert n % group_size == 0, "n must be divisible by group_size"

    npu_device = torch.device("npu")
    x = torch.randn((m, n), dtype=torch.bfloat16, device=npu_device)

    y_kernel, s_kernel = act_quant(x, group_size=group_size, round_scale=round_scale)
    y_ref, s_ref = act_quant_torch_ref(x, group_size=group_size, round_scale=round_scale)

    # --- compare scales (should be numerically close) ---
    torch.testing.assert_close(
        s_kernel.float().cpu(),
        s_ref.float().cpu(),
        rtol=1e-3,
        atol=1e-4,
        msg=f"Scale mismatch for m={m} n={n} round_scale={round_scale}",
    )

    # --- compare dequantized output: y * s ≈ x ---
    # FP8 E4M3 has 3 mantissa bits, so representable values within a binade
    # are spaced 2^(-3) = 12.5% of the binade's base value apart (a worst-case
    # relative rounding error of roughly 6.25%). The 20% rtol / 0.1 atol
    # budget covers both this per-element quantization error and the scale
    # approximation across the quantize→dequantize round-trip.
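    # Concrete example (ours): in the binade [256, 512) the E4M3 step is
    # 2**(8 - 3) = 32, so an element of 300 rounds to 288 or 320, at most
    # 16 away (~5.3%), before the scale error is added on top.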
    s_expand = s_kernel.float().repeat_interleave(group_size, dim=1)  # (m, n)
    x_dequant = y_kernel.float() * s_expand
    torch.testing.assert_close(
        x_dequant.cpu(),
        x.float().cpu(),
        rtol=0.2,
        atol=1e-1,
        msg=f"Dequant mismatch for m={m} n={n} round_scale={round_scale}",
    )

    scale_max_err = (s_kernel.float() - s_ref.float()).abs().max().item()
    dequant_max_err = (x_dequant - x.float()).abs().max().item()
    print(
        f" m={m:4d} n={n:4d} round_scale={round_scale} "
        f"scale_max_err={scale_max_err:.2e} "
        f"dequant_max_err={dequant_max_err:.2e} \033[92mPASS\033[0m"
    )


def run_test():
    run_test_case(m=32, n=128, round_scale=False)
    run_test_case(m=64, n=256, round_scale=False)
    run_test_case(m=96, n=512, round_scale=True)
    run_test_case(m=128, n=128, round_scale=True)
    print("\033[92mact_quant precision test passed.\033[0m")


if __name__ == "__main__":
os.environ["TILELANG_ASCEND_MODE"] = "Developer"
torch.npu.set_device(0)
tilelang.cache.clear_cache()
run_test()