diff --git a/examples/aot/deepseek_v4/OVERVIEW.md b/examples/aot/deepseek_v4/OVERVIEW.md new file mode 100644 index 00000000..ce021cf7 --- /dev/null +++ b/examples/aot/deepseek_v4/OVERVIEW.md @@ -0,0 +1,94 @@ +# DeepSeek-V4 PTO ports — overview + +> This file is intentionally **not named `README.md`** so that +> [`validate_all_examples.py`](../../validate_all_examples.py) walks +> into each kernel sub-directory directly instead of trying to run a +> repo-level recipe from here. + +PTO DSL ports of the six custom kernels used by the DeepSeek-V4 +reference implementation. Every kernel is self-contained in its own +folder and follows the standard examples-tree workflow: + +1. `bash ./compile.sh` — emits `.pto` → `.cpp` → `*_lib.so`. +2. `python ./run_*.py` — runs the kernel on NPU and asserts numerical + equivalence with a PyTorch reference (exits non-zero on mismatch). +3. (optional) `python ./bench_*.py` — microbenchmarks vs PyTorch + baselines (only `sparse_attn/` and `hc_split_sinkhorn/`). + +## Kernels + +| Folder | What it does | Pipe(s) | +|---|---|---| +| [act_quant/](act_quant/) | Per-row absmax fp16 → int8 quant (`max(\|x\|)/127`, `round(x/scale)`) | vector | +| [fp4_act_quant/](fp4_act_quant/) | Per-row-block (32-wide) fp16 → fp4-range integer quant (fp32 scale `amax/6`, one int8 code per value) | vector | +| [fp8_gemm/](fp8_gemm/) | Per-channel fp8 (e4m3) GEMM with host-side fused `Sa`/`Sb` pre-scale | cube + vector | +| [fp4_gemm/](fp4_gemm/) | Per-channel fp4 (e2m1) GEMM with host-side fused `Sa`/`Sb` pre-scale | cube + vector | +| [hc_split_sinkhorn/](hc_split_sinkhorn/) | Fused MoE-router head: pre/post sigmoid + 20-iter Sinkhorn, all on-device | vector | +| [sparse_attn/](sparse_attn/) | FlashAttention with indexed top-k KV gather + per-head sink logit | vector | + +## Run a single kernel + +```bash +cd examples/aot/deepseek_v4/sparse_attn +bash ./compile.sh +python ./run_sparse_attn.py +``` + +The generated `.pto`, `.cpp`, `.so` files are gitignored. 
+ +## Run all of them + +From the repo root: + +```bash +python examples/validate_all_examples.py +``` + +This walks every `README.md` under `examples/`, runs the bash block in +each, and reports pass/fail. The deepseek_v4 kernels appear in the +listing as e.g. `aot/deepseek_v4/sparse_attn`. + +## Sample bench output + +`sparse_attn/`, vs `torch.gather` + `npu_fused_infer_attention_score` +(MQA mode, sink logit dropped — speed baseline only): + +``` + B M N K pto us ref us fused us pto/ref pto/fused +------------------------------------------------------------------------ + 1 1 128 64 161.15 533.05 265.03 3.31x 1.64x + 1 4 256 128 209.56 1692.93 252.36 8.08x 1.20x + 4 4 1024 128 207.77 6071.60 246.57 29.22x 1.19x + 8 8 2048 128 304.49 24658.49 244.67 80.98x 0.80x +``` + +`hc_split_sinkhorn/`, vs eager PyTorch reference: + +``` + n pto us ref us speedup +---------------------------------------- + 64 173.27 2803.42 16.18x + 1024 218.70 2761.33 12.63x + 16384 1786.32 2741.09 1.53x +``` + +## Implementation notes + +- **`fp8_gemm` / `fp4_gemm`** — the GPU op fuses an outer `Sa[m] * Sb[n]` + per-channel rescale into the GEMM. The PTO kernels keep the matmul + pure (cube fp32 accum → fp16 cast) and instead **pre-scale `A` on the + host** by the per-row factor, leaving a clean per-output-channel `Sb` + to apply on the vector pipe. Avoids two extra cube fragments per tile + and matches reference within 5 × 10⁻³ relative error. +- **`hc_split_sinkhorn`** — all three router heads (pre / post / 20-iter + Sinkhorn over `[n, 4, 4]`) run inside one `vector_section`. ε is added + once after the initial softmax to match the reference order exactly. +- **`sparse_attn`** — pure `vector_section` FlashAttention with online + streaming softmax. The matmul shapes (`[16, 128] · [128]` per K + position, K ≤ 128) are too small to amortize cube launch overhead, and + KV is gathered by arbitrary index so it cannot live in L1 contiguously + anyway. 
Per-head softmax stats are stored as full `[H, D]` tiles + replicated across the D axis to dodge a col-major⇄row-major reshape + alias that auto-sync analysis can otherwise miss. KV gather uses + `pto.load_scalar` of the index → `pto.slice_view` with that dynamic + row offset → `pto.load` of one `[1, D]` row. diff --git a/examples/aot/deepseek_v4/act_quant/.gitignore b/examples/aot/deepseek_v4/act_quant/.gitignore new file mode 100644 index 00000000..b042dccb --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/.gitignore @@ -0,0 +1,10 @@ +# Generated build artifacts (compile.sh outputs) +act_quant.pto +act_quant.cpp +act_quant_lib.so + +# Python cache +__pycache__/ + +# Benchmark scratch +perf_data/ diff --git a/examples/aot/deepseek_v4/act_quant/README.md b/examples/aot/deepseek_v4/act_quant/README.md new file mode 100644 index 00000000..20cf8c0d --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/README.md @@ -0,0 +1,9 @@ +# act_quant — fp16 → int8 per-row absmax quantization + +PTO DSL port of the deepseek_v4 `act_quant` op. Per row computes +`scale = max(|x|) / 127`, then `y = round(x / scale)`. + +```bash +bash ./compile.sh +python ./run_act_quant.py +``` diff --git a/examples/aot/deepseek_v4/act_quant/act_quant_builder.py b/examples/aot/deepseek_v4/act_quant/act_quant_builder.py new file mode 100644 index 00000000..743d478d --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/act_quant_builder.py @@ -0,0 +1,147 @@ +"""PTO DSL port of TileLang act_quant kernel. + +Original (GPU): block-wise FP8 quantization, BF16 -> FP8(e4m3) with FP32 +or E8M0 per-block scale. inplace=True does fused quant+dequant back to BF16. + +NPU port: BF16/FP8 are not native to PTO; we use FP16 input -> int8 output +with FP32 per-block scale. The shape contract matches the original: + + X: [M, N] fp16 + Y: [M, N] int8 (quantized) or fp16 (inplace dequant) + S: [M, N/B] fp32 per-block reciprocal scale + +`block_size` is the per-row group size on the K-dim (last axis). 
+""" + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +BLOCK_SIZE = 128 # K-dim group size; matches GPU `block_size` +BLK_M = 32 # rows per tile (matches GPU `blk_m`) +INT8_MAX = 127.0 + + +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + i8 = pto.int8 + i32 = pto.int32 + + ptr_fp16 = pto.PtrType(fp16) + ptr_i8 = pto.PtrType(i8) + ptr_fp32 = pto.PtrType(fp32) + + tv_fp16 = pto.TensorType(rank=2, dtype=fp16) + tv_i8 = pto.TensorType(rank=2, dtype=i8) + tv_fp32 = pto.TensorType(rank=2, dtype=fp32) + + sv_fp16 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=fp16) + sv_i8 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=i8) + sv_scale = pto.SubTensorType(shape=[BLK_M, 1], dtype=fp32) + + row_cfg = pto.TileBufConfig() + col_cfg = pto.TileBufConfig(blayout="ColMajor") + + tile_fp16 = pto.TileBufType( + shape=[BLK_M, BLOCK_SIZE], dtype=fp16, memory_space="VEC" + ) + tile_fp32 = pto.TileBufType( + shape=[BLK_M, BLOCK_SIZE], dtype=fp32, memory_space="VEC" + ) + tile_i8 = pto.TileBufType(shape=[BLK_M, BLOCK_SIZE], dtype=i8, memory_space="VEC") + tile_amax = pto.TileBufType( + shape=[BLK_M, 1], dtype=fp32, memory_space="VEC", config=col_cfg + ) + + return locals() + + +@to_ir_module(meta_data=meta_data) +def act_quant( + x_ptr: "ptr_fp16", + y_ptr: "ptr_i8", + s_ptr: "ptr_fp32", + M_i32: "i32", + N_i32: "i32", +) -> None: + c0 = const(0) + c1 = const(1) + cBM = const(BLK_M) + cBK = const(BLOCK_SIZE) + inv_max = const(1.0 / INT8_MAX, s.float32) + + M = s.index_cast(M_i32) + N = s.index_cast(N_i32) + nblk_n = s.ceil_div(N, cBK) + + with pto.vector_section(): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + vid = s.index_cast(cid * sub_bnum + sub_bid) + ncores = s.index_cast(num_blocks * sub_bnum) + + nblk_m = s.ceil_div(M, cBM) + total_blocks = nblk_m * nblk_n + + tv_x = pto.as_tensor(tv_fp16, ptr=x_ptr, shape=[M, N], 
strides=[N, c1]) + tv_y = pto.as_tensor(tv_i8, ptr=y_ptr, shape=[M, N], strides=[N, c1]) + # Scale layout is COL-MAJOR in memory (strides=[1, M]) so that a + # [BLK_M, 1] col-major amax tile maps to a contiguous 32-element + # write at offset `blk_n * M + row_off`. + tv_s = pto.as_tensor(tv_fp32, ptr=s_ptr, shape=[M, nblk_n], strides=[c1, M]) + + tb_x = pto.alloc_tile(tile_fp16) + tb_xf = pto.alloc_tile(tile_fp32) + tb_abs = pto.alloc_tile(tile_fp32) + tb_tmp = pto.alloc_tile(tile_fp32) + tb_amax = pto.alloc_tile(tile_amax) + tb_y = pto.alloc_tile(tile_i8) + + with pto.if_context(vid < total_blocks): + for bi in pto.range(vid, total_blocks, ncores): + blk_m = bi // nblk_n + blk_n = bi % nblk_n + row_off = blk_m * cBM + col_off = blk_n * cBK + + sv_x = pto.slice_view( + sv_fp16, + source=tv_x, + offsets=[row_off, col_off], + sizes=[cBM, cBK], + ) + sv_y = pto.slice_view( + sv_i8, + source=tv_y, + offsets=[row_off, col_off], + sizes=[cBM, cBK], + ) + sv_s = pto.slice_view( + sv_scale, + source=tv_s, + offsets=[row_off, blk_n], + sizes=[cBM, c1], + ) + + pto.load(sv_x, tb_x) + tile.cvt(tb_x, tb_xf) # fp16 -> fp32 + tile.abs(tb_xf, tb_abs) # |x| + tile.row_max(tb_abs, tb_tmp, tb_amax) # amax per row + # scale = amax / 127 (fp32 reciprocal-style scale) + tile.muls(tb_amax, inv_max, tb_amax) + # y = x / scale, then cvt -> fp16 -> i8 (NPU has no direct + # fp32->i8 cvt; routing through fp16 matches the existing + # quant_dynamic_multicore example). 
+ tile.row_expand_div(tb_xf, tb_amax, tb_xf) + tile.cvt(tb_xf, tb_x) # fp32 -> fp16 (reuse tb_x) + tile.cvt(tb_x, tb_y, rmode="round") # fp16 -> int8 + pto.store(tb_y, sv_y) + pto.store(tb_amax, sv_s) + + +if __name__ == "__main__": + print(act_quant) diff --git a/examples/aot/deepseek_v4/act_quant/act_quant_util.py b/examples/aot/deepseek_v4/act_quant/act_quant_util.py new file mode 100644 index 00000000..1f979943 --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/act_quant_util.py @@ -0,0 +1,95 @@ +"""Reference + ctypes wrapper for the deepseek_v4 ``act_quant`` PTO kernel. + +Reference matches the GPU TileLang behaviour adapted to the NPU port: +FP16 input, int8 output, FP32 per-row-block reciprocal scale, K-group=128. +""" + +import ctypes +from pathlib import Path + +import torch + + +_HERE = Path(__file__).resolve().parent +_KERNEL_SO = _HERE / "act_quant_lib.so" + +BLOCK_SIZE = 128 +INT8_MAX = 127.0 + + +def act_quant_ref(x: torch.Tensor, block_size: int = BLOCK_SIZE): + """Reference: per-row-block symmetric int8 quant. + + ``x``: [M, N] fp16, N % block_size == 0. + Returns ``(y_int8 [M, N], s_fp32 [M, N // block_size])`` on the same device. 
+ """ + assert x.dtype == torch.float16, "fp16 input expected" + assert x.dim() == 2 and x.shape[1] % block_size == 0 + M, N = x.shape + nb = N // block_size + + x_f32 = x.to(torch.float32).reshape(M, nb, block_size) + amax = x_f32.abs().amax(dim=-1, keepdim=False) # [M, nb] + scale = (amax / INT8_MAX).clamp(min=1e-12) # avoid /0 + y = (x_f32 / scale.unsqueeze(-1)).round().clamp(-127, 127) + y_i8 = y.to(torch.int8).reshape(M, N) + return y_i8, scale.to(torch.float32) + + +_ARGTYPES = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, + ctypes.c_int32, +] + + +def _missing_msg() -> str: + return ( + f"Kernel shared library not found: {_KERNEL_SO}\n" + "Build it first:\n" + f" cd {_HERE} && ./compile.sh" + ) + + +_lib = None + + +def _load(): + global _lib + if _lib is None: + if not _KERNEL_SO.is_file(): + raise FileNotFoundError(_missing_msg()) + _lib = ctypes.CDLL(str(_KERNEL_SO)) + _lib.call_kernel.argtypes = _ARGTYPES + _lib.call_kernel.restype = None + return _lib + + +def act_quant(x: torch.Tensor): + """Run the PTO kernel. ``x``: [M, N] fp16 NPU tensor; N % BLOCK_SIZE == 0.""" + assert x.is_npu and x.dtype == torch.float16 + M, N = x.shape + assert N % BLOCK_SIZE == 0 + y = torch.empty((M, N), dtype=torch.int8, device=x.device) + # Kernel writes scale in COL-MAJOR layout (strides=[1, M]). + # Allocate as a transpose of a contiguous [N//B, M] tensor. 
+ s_storage = torch.empty((N // BLOCK_SIZE, M), dtype=torch.float32, device=x.device) + s = s_storage.t() # logical shape [M, N//BLOCK_SIZE], strides [1, M] + lib = _load() + dev = torch.npu.current_device() + blk = torch.npu.get_device_properties(dev).cube_core_num + lib.call_kernel( + blk, + torch.npu.current_stream()._as_parameter_, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_void_p(s.data_ptr()), + ctypes.c_int32(M), + ctypes.c_int32(N), + ) + torch.npu.synchronize() + return y, s diff --git a/examples/aot/deepseek_v4/act_quant/caller.cpp b/examples/aot/deepseek_v4/act_quant/caller.cpp new file mode 100644 index 00000000..a0726335 --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/caller.cpp @@ -0,0 +1,13 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "act_quant.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, + uint8_t *x, uint8_t *y, uint8_t *scale, + int32_t M, int32_t N) +{ + act_quant<<<blockDim, nullptr, stream>>>( + (__fp16 *)x, (int8_t *)y, (float *)scale, M, N); +} diff --git a/examples/aot/deepseek_v4/act_quant/compile.sh b/examples/aot/deepseek_v4/act_quant/compile.sh new file mode 100644 index 00000000..f49a2c41 --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f act_quant.pto act_quant.cpp act_quant_lib.so + +python ./act_quant_builder.py > ./act_quant.pto +ptoas --enable-insert-sync ./act_quant.pto -o ./act_quant.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + 
-std=gnu++17 \ + ./caller.cpp \ + -o ./act_quant_lib.so diff --git a/examples/aot/deepseek_v4/act_quant/run_act_quant.py b/examples/aot/deepseek_v4/act_quant/run_act_quant.py new file mode 100644 index 00000000..b98a16b5 --- /dev/null +++ b/examples/aot/deepseek_v4/act_quant/run_act_quant.py @@ -0,0 +1,45 @@ +"""Run the deepseek_v4 ``act_quant`` PTO kernel and validate against the +reference. Exits non-zero on mismatch.""" + +import sys +from pathlib import Path + +import torch +import torch_npu # noqa: F401 + +from ptodsl.npu_info import get_test_device + +_HERE = Path(__file__).resolve().parent +if str(_HERE) not in sys.path: + sys.path.insert(0, str(_HERE)) + +from act_quant_util import BLOCK_SIZE, act_quant, act_quant_ref # noqa: E402 + + +def main() -> int: + device = get_test_device() + torch.npu.set_device(device) + torch.manual_seed(0) + + shapes = [ + (32, BLOCK_SIZE), + (64, BLOCK_SIZE * 2), + (128, BLOCK_SIZE * 4), + ] + for M, N in shapes: + x = torch.randn(M, N, dtype=torch.float16, device=device) + y_pto, s_pto = act_quant(x) + y_ref, s_ref = act_quant_ref(x) + torch.testing.assert_close(s_pto, s_ref, rtol=1e-3, atol=1e-6) + diff = (y_pto.to(torch.int32) - y_ref.to(torch.int32)).abs() + max_diff = diff.max().item() + match = (diff == 0).float().mean().item() + assert max_diff <= 1, f"M={M} N={N}: max int8 diff = {max_diff}" + assert match > 0.95, f"M={M} N={N}: only {match * 100:.1f}% exact" + print(f"act_quant M={M} N={N}: max_diff={max_diff} exact={match * 100:.1f}% OK") + print("act_quant: all shapes PASSED") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/aot/deepseek_v4/fp4_act_quant/.gitignore b/examples/aot/deepseek_v4/fp4_act_quant/.gitignore new file mode 100644 index 00000000..27555966 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/.gitignore @@ -0,0 +1,10 @@ +# Generated build artifacts (compile.sh outputs) +fp4_act_quant.pto +fp4_act_quant.cpp +fp4_act_quant_lib.so + +# Python cache 
+__pycache__/ + +# Benchmark scratch +perf_data/ diff --git a/examples/aot/deepseek_v4/fp4_act_quant/README.md b/examples/aot/deepseek_v4/fp4_act_quant/README.md new file mode 100644 index 00000000..a160a3b6 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/README.md @@ -0,0 +1,10 @@ +# fp4_act_quant — fp16 → mxfp4 (e2m1) per-block quantization + +PTO DSL port of the deepseek_v4 `fp4_act_quant` op. Per `BLOCK_SIZE` +group of 32 elements it computes an fp32 scale `amax / 6` and stores +`round(x / scale)` as one int8 per value (int8 stand-in for fp4 e2m1). + +```bash +bash ./compile.sh +python ./run_fp4_act_quant.py +``` diff --git a/examples/aot/deepseek_v4/fp4_act_quant/caller.cpp b/examples/aot/deepseek_v4/fp4_act_quant/caller.cpp new file mode 100644 index 00000000..fceb25bd --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/caller.cpp @@ -0,0 +1,13 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "fp4_act_quant.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, + uint8_t *x, uint8_t *y, uint8_t *scale, + int32_t M, int32_t N) +{ + fp4_act_quant<<<blockDim, nullptr, stream>>>( + (__fp16 *)x, (int8_t *)y, (float *)scale, M, N); +} diff --git a/examples/aot/deepseek_v4/fp4_act_quant/compile.sh b/examples/aot/deepseek_v4/fp4_act_quant/compile.sh new file mode 100644 index 00000000..153fb5ea --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f fp4_act_quant.pto fp4_act_quant.cpp fp4_act_quant_lib.so + +python ./fp4_act_quant_builder.py > ./fp4_act_quant.pto +ptoas --enable-insert-sync ./fp4_act_quant.pto -o ./fp4_act_quant.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm 
-cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./fp4_act_quant_lib.so diff --git a/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_builder.py b/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_builder.py new file mode 100644 index 00000000..74461141 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_builder.py @@ -0,0 +1,138 @@ +"""PTO DSL port of TileLang fp4_act_quant kernel. + +Original (GPU): block-wise BF16 -> FP4(e2m1) quantization, FP4_max=6, +power-of-2 (E8M0) per-block scale, block_size=32 on the K-dim. + +NPU port: FP4 / BF16 / E8M0 are not native to PTO. We implement the +same algorithm with FP16 inputs and int4-equivalent quantization stored +in int8 (one int4 per byte). For simplicity the output container is +int8 with values in [-7, 7]; per-block scale stays FP32. +""" + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +BLOCK_SIZE = 32 # K-dim group size; matches GPU `fp4_block_size` +BLK_M = 32 +FP4_MAX = 6.0 + + +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + i8 = pto.int8 + i32 = pto.int32 + + ptr_fp16 = pto.PtrType(fp16) + ptr_i8 = pto.PtrType(i8) + ptr_fp32 = pto.PtrType(fp32) + + tv_fp16 = pto.TensorType(rank=2, dtype=fp16) + tv_i8 = pto.TensorType(rank=2, dtype=i8) + tv_fp32 = pto.TensorType(rank=2, dtype=fp32) + + sv_fp16 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=fp16) + sv_i8 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=i8) + sv_scale = pto.SubTensorType(shape=[BLK_M, 1], dtype=fp32) + + col_cfg = pto.TileBufConfig(blayout="ColMajor") + + tile_fp16 = pto.TileBufType( + shape=[BLK_M, BLOCK_SIZE], dtype=fp16, memory_space="VEC" + ) + tile_fp32 = pto.TileBufType( + shape=[BLK_M, BLOCK_SIZE], dtype=fp32, memory_space="VEC" + ) + tile_i8 = pto.TileBufType(shape=[BLK_M, BLOCK_SIZE], 
dtype=i8, memory_space="VEC") + tile_amax = pto.TileBufType( + shape=[BLK_M, 1], dtype=fp32, memory_space="VEC", config=col_cfg + ) + + return locals() + + +@to_ir_module(meta_data=meta_data) +def fp4_act_quant( + x_ptr: "ptr_fp16", + y_ptr: "ptr_i8", + s_ptr: "ptr_fp32", + M_i32: "i32", + N_i32: "i32", +) -> None: + c0 = const(0) + c1 = const(1) + cBM = const(BLK_M) + cBK = const(BLOCK_SIZE) + inv_max = const(1.0 / FP4_MAX, s.float32) + + M = s.index_cast(M_i32) + N = s.index_cast(N_i32) + nblk_n = s.ceil_div(N, cBK) + + with pto.vector_section(): + cid = pto.get_block_idx() + sub_bid = pto.get_subblock_idx() + sub_bnum = pto.get_subblock_num() + num_blocks = pto.get_block_num() + vid = s.index_cast(cid * sub_bnum + sub_bid) + ncores = s.index_cast(num_blocks * sub_bnum) + + nblk_m = s.ceil_div(M, cBM) + total_blocks = nblk_m * nblk_n + + tv_x = pto.as_tensor(tv_fp16, ptr=x_ptr, shape=[M, N], strides=[N, c1]) + tv_y = pto.as_tensor(tv_i8, ptr=y_ptr, shape=[M, N], strides=[N, c1]) + # Scale layout: COL-MAJOR in memory (strides=[1, M]) so the + # [BLK_M, 1] col-major amax tile is stored contiguously. 
+ tv_s = pto.as_tensor(tv_fp32, ptr=s_ptr, shape=[M, nblk_n], strides=[c1, M]) + + tb_x = pto.alloc_tile(tile_fp16) + tb_xf = pto.alloc_tile(tile_fp32) + tb_abs = pto.alloc_tile(tile_fp32) + tb_tmp = pto.alloc_tile(tile_fp32) + tb_amax = pto.alloc_tile(tile_amax) + tb_y = pto.alloc_tile(tile_i8) + + with pto.if_context(vid < total_blocks): + for bi in pto.range(vid, total_blocks, ncores): + blk_m = bi // nblk_n + blk_n = bi % nblk_n + row_off = blk_m * cBM + col_off = blk_n * cBK + + sv_x = pto.slice_view( + sv_fp16, + source=tv_x, + offsets=[row_off, col_off], + sizes=[cBM, cBK], + ) + sv_y = pto.slice_view( + sv_i8, + source=tv_y, + offsets=[row_off, col_off], + sizes=[cBM, cBK], + ) + sv_s = pto.slice_view( + sv_scale, + source=tv_s, + offsets=[row_off, blk_n], + sizes=[cBM, c1], + ) + + pto.load(sv_x, tb_x) + tile.cvt(tb_x, tb_xf) + tile.abs(tb_xf, tb_abs) + tile.row_max(tb_abs, tb_tmp, tb_amax) + tile.muls(tb_amax, inv_max, tb_amax) # scale = amax / 6 + tile.row_expand_div(tb_xf, tb_amax, tb_xf) + # fp32 -> fp16 -> int8 (NPU has no direct fp32->i8 cvt). + tile.cvt(tb_xf, tb_x) + tile.cvt(tb_x, tb_y, rmode="round") + pto.store(tb_y, sv_y) + pto.store(tb_amax, sv_s) + + +if __name__ == "__main__": + print(fp4_act_quant) diff --git a/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_util.py b/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_util.py new file mode 100644 index 00000000..fe1af0ab --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/fp4_act_quant_util.py @@ -0,0 +1,82 @@ +"""Reference + ctypes wrapper for the deepseek_v4 ``fp4_act_quant`` PTO kernel.""" + +import ctypes +from pathlib import Path + +import torch + + +_HERE = Path(__file__).resolve().parent +_KERNEL_SO = _HERE / "fp4_act_quant_lib.so" + +BLOCK_SIZE = 32 +FP4_MAX = 6.0 # max representable magnitude of FP4 e2m1 + + +def fp4_act_quant_ref(x: torch.Tensor, block_size: int = BLOCK_SIZE): + """Per-row-block symmetric FP4-style int quant. 
+ + Returns ``(y_int8 [M, N] in [-7,7], s_fp32 [M, N // block_size])``. + The NPU port stores the FP4 codes packed in int8 (one per byte). + """ + assert x.dtype == torch.float16 + assert x.dim() == 2 and x.shape[1] % block_size == 0 + M, N = x.shape + nb = N // block_size + + x_f32 = x.to(torch.float32).reshape(M, nb, block_size) + amax = x_f32.abs().amax(dim=-1) + scale = (amax / FP4_MAX).clamp(min=1e-12) + y = (x_f32 / scale.unsqueeze(-1)).round().clamp(-7, 7) + return y.to(torch.int8).reshape(M, N), scale.to(torch.float32) + + +_ARGTYPES = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, + ctypes.c_int32, +] + + +_lib = None + + +def _load(): + global _lib + if _lib is None: + if not _KERNEL_SO.is_file(): + raise FileNotFoundError( + f"Kernel shared library not found: {_KERNEL_SO}\n" + f"Build first: cd {_HERE} && ./compile.sh" + ) + _lib = ctypes.CDLL(str(_KERNEL_SO)) + _lib.call_kernel.argtypes = _ARGTYPES + _lib.call_kernel.restype = None + return _lib + + +def fp4_act_quant(x: torch.Tensor): + assert x.is_npu and x.dtype == torch.float16 + M, N = x.shape + assert N % BLOCK_SIZE == 0 + y = torch.empty((M, N), dtype=torch.int8, device=x.device) + s_storage = torch.empty((N // BLOCK_SIZE, M), dtype=torch.float32, device=x.device) + s = s_storage.t() + lib = _load() + dev = torch.npu.current_device() + blk = torch.npu.get_device_properties(dev).cube_core_num + lib.call_kernel( + blk, + torch.npu.current_stream()._as_parameter_, + ctypes.c_void_p(x.data_ptr()), + ctypes.c_void_p(y.data_ptr()), + ctypes.c_void_p(s.data_ptr()), + ctypes.c_int32(M), + ctypes.c_int32(N), + ) + torch.npu.synchronize() + return y, s diff --git a/examples/aot/deepseek_v4/fp4_act_quant/run_fp4_act_quant.py b/examples/aot/deepseek_v4/fp4_act_quant/run_fp4_act_quant.py new file mode 100644 index 00000000..75eee23c --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_act_quant/run_fp4_act_quant.py @@ -0,0 +1,51 @@ +"""Run the 
deepseek_v4 ``fp4_act_quant`` PTO kernel and validate against the +reference. Exits non-zero on mismatch.""" + +import sys +from pathlib import Path + +import torch +import torch_npu # noqa: F401 + +from ptodsl.npu_info import get_test_device + +_HERE = Path(__file__).resolve().parent +if str(_HERE) not in sys.path: + sys.path.insert(0, str(_HERE)) + +from fp4_act_quant_util import ( # noqa: E402 + BLOCK_SIZE, + fp4_act_quant, + fp4_act_quant_ref, +) + + +def main() -> int: + device = get_test_device() + torch.npu.set_device(device) + torch.manual_seed(0) + + shapes = [ + (32, BLOCK_SIZE * 4), + (64, BLOCK_SIZE * 8), + (128, BLOCK_SIZE * 16), + ] + for M, N in shapes: + x = torch.randn(M, N, dtype=torch.float16, device=device) + y_pto, s_pto = fp4_act_quant(x) + y_ref, s_ref = fp4_act_quant_ref(x) + torch.testing.assert_close(s_pto, s_ref, rtol=1e-3, atol=1e-6) + diff = (y_pto.to(torch.int32) - y_ref.to(torch.int32)).abs() + max_diff = diff.max().item() + match = (diff == 0).float().mean().item() + assert max_diff <= 1, f"M={M} N={N}: max int4 diff = {max_diff}" + assert match > 0.95, f"M={M} N={N}: only {match * 100:.1f}% exact" + print( + f"fp4_act_quant M={M} N={N}: max_diff={max_diff} exact={match * 100:.1f}% OK" + ) + print("fp4_act_quant: all shapes PASSED") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/aot/deepseek_v4/fp4_gemm/.gitignore b/examples/aot/deepseek_v4/fp4_gemm/.gitignore new file mode 100644 index 00000000..ab1f0929 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_gemm/.gitignore @@ -0,0 +1,10 @@ +# Generated build artifacts (compile.sh outputs) +fp4_gemm.pto +fp4_gemm.cpp +fp4_gemm_lib.so + +# Python cache +__pycache__/ + +# Benchmark scratch +perf_data/ diff --git a/examples/aot/deepseek_v4/fp4_gemm/README.md b/examples/aot/deepseek_v4/fp4_gemm/README.md new file mode 100644 index 00000000..0d82d680 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_gemm/README.md @@ -0,0 +1,10 @@ +# fp4_gemm — per-channel 
fp4 (e2m1) GEMM with fused Sa/Sb scales + +PTO DSL port of the deepseek_v4 `fp4_gemm` op. Related to `fp8_gemm`; here both +scales are folded in on the host (`Sa` into `A`, `Sb` into `B`) and the kernel +is a plain cube-pipe fp16 GEMM. Scales use a K group of 32 elements. + +```bash +bash ./compile.sh +python ./run_fp4_gemm.py +``` diff --git a/examples/aot/deepseek_v4/fp4_gemm/caller.cpp b/examples/aot/deepseek_v4/fp4_gemm/caller.cpp new file mode 100644 index 00000000..b0f3d91a --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_gemm/caller.cpp @@ -0,0 +1,15 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "fp4_gemm.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, + uint8_t *a, uint8_t *b, uint8_t *c, + uint8_t *sa, uint8_t *sb, + int32_t M, int32_t N, int32_t K) +{ + fp4_gemm<<<blockDim, nullptr, stream>>>( + (__fp16 *)a, (__fp16 *)b, (__fp16 *)c, + (float *)sa, (float *)sb, M, N, K); +} diff --git a/examples/aot/deepseek_v4/fp4_gemm/compile.sh b/examples/aot/deepseek_v4/fp4_gemm/compile.sh new file mode 100644 index 00000000..691b568c --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_gemm/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f fp4_gemm.pto fp4_gemm.cpp fp4_gemm_lib.so + +python ./fp4_gemm_builder.py > ./fp4_gemm.pto +ptoas --enable-insert-sync ./fp4_gemm.pto -o ./fp4_gemm.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./fp4_gemm_lib.so diff --git a/examples/aot/deepseek_v4/fp4_gemm/fp4_gemm_builder.py 
b/examples/aot/deepseek_v4/fp4_gemm/fp4_gemm_builder.py new file mode 100644 index 00000000..942a20a1 --- /dev/null +++ b/examples/aot/deepseek_v4/fp4_gemm/fp4_gemm_builder.py @@ -0,0 +1,130 @@ +"""PTO DSL port of TileLang fp4_gemm kernel. + +Original (GPU): FP8 act × FP4 weight GEMM, A scaled per 128 on K, B scaled +per 32 on K. B stored as [N, K//2] fp4_e2m1fn_x2 (2 fp4 per byte, packed +along K). + +NPU port: Same shape & blocking strategy but using FP16 inputs (no fp8/fp4 +on Ascend). Asymmetric scale granularity (act per 128-K, weight per 32-K) +is preserved in the *interface* (Sa, Sb), but per-block scale fusion is +left as a TODO (see fp8_gemm/README.md). Block sizes: +block_M=32, block_N=128, block_K=32 (= weight_group_size). +""" + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + +const = s.const + +ACT_GROUP = 128 +WEIGHT_GROUP = 32 +BLOCK_M = 32 +BLOCK_N = 128 +BLOCK_K = 32 # = WEIGHT_GROUP + + +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + i32 = pto.int32 + ptr_fp16 = pto.PtrType(fp16) + ptr_fp32 = pto.PtrType(fp32) + tv_fp16 = pto.TensorType(rank=2, dtype=fp16) + tv_fp32 = pto.TensorType(rank=2, dtype=fp32) + + sv_a = pto.SubTensorType(shape=[BLOCK_M, BLOCK_K], dtype=fp16) + sv_b = pto.SubTensorType(shape=[BLOCK_K, BLOCK_N], dtype=fp16) + sv_c = pto.SubTensorType(shape=[BLOCK_M, BLOCK_N], dtype=fp16) + + tile_a_mat = pto.TileBufType( + shape=[BLOCK_M, BLOCK_K], dtype=fp16, memory_space="MAT" + ) + tile_b_mat = pto.TileBufType( + shape=[BLOCK_K, BLOCK_N], dtype=fp16, memory_space="MAT" + ) + tile_a_left = pto.TileBufType( + shape=[BLOCK_M, BLOCK_K], dtype=fp16, memory_space="LEFT" + ) + tile_b_right = pto.TileBufType( + shape=[BLOCK_K, BLOCK_N], dtype=fp16, memory_space="RIGHT" + ) + tile_c_acc = pto.TileBufType( + shape=[BLOCK_M, BLOCK_N], dtype=fp32, memory_space="ACC" + ) + return locals() + + +@to_ir_module(meta_data=meta_data) +def fp4_gemm( + a_ptr: "ptr_fp16", + b_ptr: "ptr_fp16", + c_ptr: "ptr_fp16", + 
sa_ptr: "ptr_fp32", + sb_ptr: "ptr_fp32", + M_i32: "i32", + N_i32: "i32", + K_i32: "i32", +) -> None: + c0 = const(0) + c1 = const(1) + cBM = const(BLOCK_M) + cBN = const(BLOCK_N) + cBK = const(BLOCK_K) + + M = s.index_cast(M_i32) + N = s.index_cast(N_i32) + K = s.index_cast(K_i32) + K_iters = s.ceil_div(K, cBK) + + with pto.cube_section(): + bid = s.index_cast(pto.get_block_idx()) + num_blocks = s.index_cast(pto.get_block_num()) + nblk_m = s.ceil_div(M, cBM) + nblk_n = s.ceil_div(N, cBN) + total = nblk_m * nblk_n + per_core = s.ceil_div(total, num_blocks) + b_start = bid * per_core + b_end = s.min_u(b_start + per_core, total) + + tvA = pto.as_tensor(tv_fp16, ptr=a_ptr, shape=[M, K], strides=[K, c1]) + tvB = pto.as_tensor(tv_fp16, ptr=b_ptr, shape=[K, N], strides=[N, c1]) + tvC = pto.as_tensor(tv_fp16, ptr=c_ptr, shape=[M, N], strides=[N, c1]) + + aMat = pto.alloc_tile(tile_a_mat) + bMat = pto.alloc_tile(tile_b_mat) + aLeft = pto.alloc_tile(tile_a_left) + bRight = pto.alloc_tile(tile_b_right) + cAcc = pto.alloc_tile(tile_c_acc) + + for bi in pto.range(b_start, b_end, c1): + blk_m = bi // nblk_n + blk_n = bi % nblk_n + row_off = blk_m * cBM + col_off = blk_n * cBN + + for k in pto.range(c0, K_iters, c1): + k_off = k * cBK + svA = pto.slice_view( + sv_a, source=tvA, offsets=[row_off, k_off], sizes=[cBM, cBK] + ) + svB = pto.slice_view( + sv_b, source=tvB, offsets=[k_off, col_off], sizes=[cBK, cBN] + ) + pto.load(svA, aMat) + pto.load(svB, bMat) + tile.mov(aMat, aLeft) + tile.mov(bMat, bRight) + pto.cond( + s.eq(k, c0), + lambda: tile.matmul(aLeft, bRight, cAcc), + lambda: tile.matmul_acc(cAcc, aLeft, bRight, cAcc), + ) + + svC = pto.slice_view( + sv_c, source=tvC, offsets=[row_off, col_off], sizes=[cBM, cBN] + ) + pto.store(cAcc, svC) + + +if __name__ == "__main__": + print(fp4_gemm) diff --git a/examples/aot/deepseek_v4/fp4_gemm/fp4_gemm_util.py b/examples/aot/deepseek_v4/fp4_gemm/fp4_gemm_util.py new file mode 100644 index 00000000..824a4bba --- /dev/null +++ 
"""Reference + ctypes wrapper for the deepseek_v4 ``fp4_gemm`` PTO kernel.

Same host-side pre-scale design as ``fp8_gemm`` (see
``fp8_gemm_util.py`` for full rationale). Only differences:

* ``BLOCK_K = 32`` (matches the GPU FP4 weight group size).
* Sa shape ``[M, K // 32]``; Sb shape ``[K // 32, N // 128]``.

NPU FP16 dynamic range comfortably accommodates pre-scaled values, so we
absorb both scales into ``A`` and ``B`` host-side before launching the
kernel — mathematically identical to GPU per-block fusion.
"""

import ctypes
from pathlib import Path

import torch


_HERE = Path(__file__).resolve().parent
_KERNEL_SO = _HERE / "fp4_gemm_lib.so"

BLOCK_M = 32
BLOCK_N = 128
BLOCK_K = 32  # weight group


def _prescale(a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor):
    """Multiply the per-block scales into A and B and return fp16 copies.

    Args:
        a: ``[M, K]`` activations.
        b: ``[K, N]`` weights.
        sa: ``[M, K // BLOCK_K]`` scales, broadcast along each K group of A.
        sb: ``[K // BLOCK_K, N // BLOCK_N]`` scales, broadcast over each
            (K group, N block) of B.

    Returns:
        ``(a_scaled, b_scaled)``, both fp16. The multiply is done in fp32.

    Raises:
        AssertionError: if the four shapes are not mutually consistent.
    """
    M, K = a.shape
    Kb, N = b.shape
    Kg, Nb = sb.shape
    assert sa.shape == (M, Kg)
    # B's K extent used to be ignored here, so a [1, N] B would silently
    # broadcast against the expanded [K, N] scale instead of being rejected.
    assert K == Kb, f"a is [{M}, {K}] but b is [{Kb}, {N}]"
    # Fail with a readable message rather than an opaque reshape error.
    assert Kg * BLOCK_K == K and Nb * BLOCK_N == N, (
        f"sb shape {tuple(sb.shape)} does not tile b shape [{K}, {N}] "
        f"with blocks of [{BLOCK_K}, {BLOCK_N}]"
    )

    # Sa: [M, Kg] -> [M, Kg, BLOCK_K] -> [M, K].
    sa_exp = sa.unsqueeze(-1).expand(M, Kg, BLOCK_K).reshape(M, K)
    # Sb: [Kg, Nb] -> [Kg, BLOCK_K, Nb, BLOCK_N] -> [K, N].
    sb_exp = (
        sb.unsqueeze(1).unsqueeze(-1).expand(Kg, BLOCK_K, Nb, BLOCK_N).reshape(K, N)
    )
    a_s = (a.to(torch.float32) * sa_exp).to(torch.float16)
    b_s = (b.to(torch.float32) * sb_exp).to(torch.float16)
    return a_s, b_s


def fp4_gemm_ref(
    a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor
) -> torch.Tensor:
    """Reference result: pre-scale, fp32 matmul, cast the product to fp16."""
    a_s, b_s = _prescale(a, b, sa, sb)
    return (a_s.to(torch.float32) @ b_s.to(torch.float32)).to(torch.float16)


# call_kernel(blockDim, stream, a, b, c, sa, sb, M, N, K)
_ARGTYPES = [
    ctypes.c_uint32,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int32,
    ctypes.c_int32,
    ctypes.c_int32,
]


_lib = None


def _load():
    """Load ``fp4_gemm_lib.so`` once and cache the ``ctypes`` handle.

    Raises:
        FileNotFoundError: if the shared library has not been built yet.
    """
    global _lib
    if _lib is None:
        if not _KERNEL_SO.is_file():
            raise FileNotFoundError(
                f"Kernel shared library not found: {_KERNEL_SO}\n"
                f"Build first: cd {_HERE} && ./compile.sh"
            )
        _lib = ctypes.CDLL(str(_KERNEL_SO))
        _lib.call_kernel.argtypes = _ARGTYPES
        _lib.call_kernel.restype = None
    return _lib


def fp4_gemm(
    a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor
) -> torch.Tensor:
    """Run the device kernel on host-pre-scaled inputs and return C.

    Args:
        a: ``[M, K]`` fp16, NPU.
        b: ``[K, N]`` fp16, NPU.
        sa: ``[M, K // BLOCK_K]`` scales.
        sb: ``[K // BLOCK_K, N // BLOCK_N]`` scales.

    Returns:
        ``[M, N]`` fp16 tensor on ``a``'s device.

    Raises:
        AssertionError: on wrong device/dtype or shapes that are not
            multiples of the block sizes.
        FileNotFoundError: if the kernel library has not been built.
    """
    assert a.is_npu and b.is_npu and a.dtype == b.dtype == torch.float16
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0

    a_s, b_s = _prescale(a, b, sa, sb)
    a_s = a_s.contiguous()
    b_s = b_s.contiguous()
    # Keep the contiguous scale tensors referenced until after the
    # synchronize below. Taking data_ptr() of an unnamed ``.contiguous()``
    # temporary hands the kernel a pointer whose storage may already have
    # been released when the scale tensor was not contiguous to begin with.
    sa_c = sa.contiguous()
    sb_c = sb.contiguous()

    c = torch.empty((M, N), dtype=torch.float16, device=a.device)
    lib = _load()
    dev = torch.npu.current_device()
    blk = torch.npu.get_device_properties(dev).cube_core_num
    lib.call_kernel(
        blk,
        torch.npu.current_stream()._as_parameter_,
        ctypes.c_void_p(a_s.data_ptr()),
        ctypes.c_void_p(b_s.data_ptr()),
        ctypes.c_void_p(c.data_ptr()),
        ctypes.c_void_p(sa_c.data_ptr()),  # not read by kernel
        ctypes.c_void_p(sb_c.data_ptr()),  # not read by kernel
        ctypes.c_int32(M),
        ctypes.c_int32(N),
        ctypes.c_int32(K),
    )
    torch.npu.synchronize()
    return c
def _check(M, N, K, sa_unit: bool, device, seed: int):
    """Compare the PTO kernel with the reference for one ``(M, N, K)`` shape.

    ``sa_unit`` selects all-ones scales for both Sa and Sb; otherwise both
    are drawn as ``exp(randn)``. Raises (via ``assert_close``) on mismatch.
    """
    torch.manual_seed(seed)
    k_groups = K // BLOCK_K
    n_blocks = N // BLOCK_N

    # RNG draw order matters for reproducibility: a, b, then sa, sb.
    a = (torch.randn(M, K, device=device) * 0.1).to(torch.float16)
    b = (torch.randn(K, N, device=device) * 0.1).to(torch.float16)

    if not sa_unit:
        sa = torch.randn(M, k_groups, device=device).exp().to(torch.float32)
        sb = torch.randn(k_groups, n_blocks, device=device).exp().to(torch.float32)
        tag = "rand-scales"
    else:
        sa = torch.ones(M, k_groups, device=device, dtype=torch.float32)
        sb = torch.ones(k_groups, n_blocks, device=device, dtype=torch.float32)
        tag = "unit-scales"

    actual = fp4_gemm(a, b, sa, sb)
    expected = fp4_gemm_ref(a, b, sa, sb)
    torch.testing.assert_close(actual, expected, rtol=2e-2, atol=2e-2)
    print(f"fp4_gemm M={M} N={N} K={K} {tag}: OK")


def main() -> int:
    """Run every shape case on the test device; return 0 when all pass."""
    device = get_test_device()
    torch.npu.set_device(device)

    # (M, N, K) expressed as multiples of the block sizes.
    multiples = [(1, 1, 4), (2, 2, 8), (4, 1, 16)]
    for seed, (m_mul, n_mul, k_mul) in enumerate(multiples):
        _check(
            BLOCK_M * m_mul,
            BLOCK_N * n_mul,
            BLOCK_K * k_mul,
            sa_unit=False,
            device=device,
            seed=seed,
        )

    _check(BLOCK_M * 2, BLOCK_N, BLOCK_K * 8, sa_unit=True, device=device, seed=42)
    print("fp4_gemm: all shapes PASSED")
    return 0


if __name__ == "__main__":
    sys.exit(main())
+fp8_gemm_lib.so + +# Python cache +__pycache__/ + +# Benchmark scratch +perf_data/ diff --git a/examples/aot/deepseek_v4/fp8_gemm/README.md b/examples/aot/deepseek_v4/fp8_gemm/README.md new file mode 100644 index 00000000..7e08a2ae --- /dev/null +++ b/examples/aot/deepseek_v4/fp8_gemm/README.md @@ -0,0 +1,11 @@ +# fp8_gemm — per-channel fp8 (e4m3) GEMM with fused Sa/Sb scales + +PTO DSL port of the deepseek_v4 `fp8_gemm` op. The kernel keeps the +matmul pure (cube fp32 accum → fp16 cast); the per-channel `Sa[m]` +rescale is fused into a host-side pre-scale of `A`, leaving a clean +per-output-channel `Sb` to apply on the vector pipe. + +```bash +bash ./compile.sh +python ./run_fp8_gemm.py +``` diff --git a/examples/aot/deepseek_v4/fp8_gemm/caller.cpp b/examples/aot/deepseek_v4/fp8_gemm/caller.cpp new file mode 100644 index 00000000..32fef195 --- /dev/null +++ b/examples/aot/deepseek_v4/fp8_gemm/caller.cpp @@ -0,0 +1,15 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "fp8_gemm.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, + uint8_t *a, uint8_t *b, uint8_t *c, + uint8_t *sa, uint8_t *sb, + int32_t M, int32_t N, int32_t K) +{ + fp8_gemm<<>>( + (__fp16 *)a, (__fp16 *)b, (__fp16 *)c, + (float *)sa, (float *)sb, M, N, K); +} diff --git a/examples/aot/deepseek_v4/fp8_gemm/compile.sh b/examples/aot/deepseek_v4/fp8_gemm/compile.sh new file mode 100644 index 00000000..80c27f3a --- /dev/null +++ b/examples/aot/deepseek_v4/fp8_gemm/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f fp8_gemm.pto fp8_gemm.cpp fp8_gemm_lib.so + +python ./fp8_gemm_builder.py > ./fp8_gemm.pto +ptoas --enable-insert-sync ./fp8_gemm.pto -o ./fp8_gemm.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm 
"""PTO DSL port of TileLang fp8_gemm kernel.

Original (GPU): C[M,N] = A_fp8[M,K] @ B_fp8[N,K]^T with per-128 block
scales on both A and B. Outer accumulator in FP32, scale-corrected
sub-results in a separate accumulator for 2x precision.

NPU port: FP8 unsupported. We use FP16 inputs / FP32 accumulator. The
per-block scale multiply is NOT performed inside this kernel: ``sa_ptr``
and ``sb_ptr`` are accepted to keep the signature but are never read. The
host wrapper in ``fp8_gemm_util.py`` multiplies the scales into A and B
before launching. Block sizes match the reference (block_M=32,
block_N=128, block_K=group_size=128).

Args:
    A: [M, K] fp16
    B: [K, N] fp16
    C: [M, N] fp16 (output)
    Sa: fp32, unused here (host wrapper passes [M, K // 128])
    Sb: fp32, unused here (host wrapper passes [K // 128, N // 128])

NOTE: B is in [K, N] layout for NPU RIGHT-tile compatibility (the GPU
reference stored it as [N, K] and used `transpose_B=True` in GEMM).
"""

from ptodsl import pto, tile, to_ir_module
from ptodsl import scalar as s

const = s.const

# NOTE(review): GROUP_SIZE is not referenced in this module.
GROUP_SIZE = 128
BLOCK_M = 32
BLOCK_N = 128
BLOCK_K = 128


def meta_data():
    """Build the type table handed to ``to_ir_module``.

    Every local defined here is returned via ``locals()``; the kernel
    signature refers to some of them by name in its string annotations and
    the kernel body uses the tensor / sub-tensor / tile types directly.
    """
    fp16 = pto.float16
    fp32 = pto.float32
    i32 = pto.int32

    ptr_fp16 = pto.PtrType(fp16)
    ptr_fp32 = pto.PtrType(fp32)

    tv_fp16 = pto.TensorType(rank=2, dtype=fp16)
    # tv_fp32 is not used by the kernel body in this file.
    tv_fp32 = pto.TensorType(rank=2, dtype=fp32)

    # Fixed-size windows into A [M, K], B [K, N] and C [M, N].
    sv_a = pto.SubTensorType(shape=[BLOCK_M, BLOCK_K], dtype=fp16)
    sv_b = pto.SubTensorType(shape=[BLOCK_K, BLOCK_N], dtype=fp16)
    sv_c = pto.SubTensorType(shape=[BLOCK_M, BLOCK_N], dtype=fp16)

    # Staging tiles (MAT), matmul operand tiles (LEFT / RIGHT) and the
    # fp32 accumulator tile (ACC).
    tile_a_mat = pto.TileBufType(
        shape=[BLOCK_M, BLOCK_K], dtype=fp16, memory_space="MAT"
    )
    tile_b_mat = pto.TileBufType(
        shape=[BLOCK_K, BLOCK_N], dtype=fp16, memory_space="MAT"
    )
    tile_a_left = pto.TileBufType(
        shape=[BLOCK_M, BLOCK_K], dtype=fp16, memory_space="LEFT"
    )
    tile_b_right = pto.TileBufType(
        shape=[BLOCK_K, BLOCK_N], dtype=fp16, memory_space="RIGHT"
    )
    tile_c_acc = pto.TileBufType(
        shape=[BLOCK_M, BLOCK_N], dtype=fp32, memory_space="ACC"
    )
    return locals()


# Tiled GEMM C[M, N] = A[M, K] @ B[K, N] with row-major A, B, C (strides
# [K, 1], [N, 1], [N, 1]). All slices are full BLOCK_* windows, so there is
# no tail handling here; the host wrapper asserts that M, N and K are
# multiples of the block sizes.
@to_ir_module(meta_data=meta_data)
def fp8_gemm(
    a_ptr: "ptr_fp16",
    b_ptr: "ptr_fp16",
    c_ptr: "ptr_fp16",
    sa_ptr: "ptr_fp32",
    sb_ptr: "ptr_fp32",
    M_i32: "i32",
    N_i32: "i32",
    K_i32: "i32",
) -> None:
    c0 = const(0)
    c1 = const(1)
    cBM = const(BLOCK_M)
    cBN = const(BLOCK_N)
    cBK = const(BLOCK_K)

    M = s.index_cast(M_i32)
    N = s.index_cast(N_i32)
    K = s.index_cast(K_i32)
    K_iters = s.ceil_div(K, cBK)

    with pto.cube_section():
        # Split the nblk_m * nblk_n output tiles into contiguous chunks of
        # ``per_core`` tiles, one chunk per block index.
        bid = s.index_cast(pto.get_block_idx())
        num_blocks = s.index_cast(pto.get_block_num())
        nblk_m = s.ceil_div(M, cBM)
        nblk_n = s.ceil_div(N, cBN)
        total = nblk_m * nblk_n
        per_core = s.ceil_div(total, num_blocks)
        b_start = bid * per_core
        # NOTE(review): for high block ids b_start can exceed total, giving
        # b_end < b_start; this relies on pto.range producing an empty loop
        # in that case — confirm.
        b_end = s.min_u(b_start + per_core, total)

        tvA = pto.as_tensor(tv_fp16, ptr=a_ptr, shape=[M, K], strides=[K, c1])
        tvB = pto.as_tensor(tv_fp16, ptr=b_ptr, shape=[K, N], strides=[N, c1])
        tvC = pto.as_tensor(tv_fp16, ptr=c_ptr, shape=[M, N], strides=[N, c1])

        # One set of tiles, reused for every output tile and K step.
        aMat = pto.alloc_tile(tile_a_mat)
        bMat = pto.alloc_tile(tile_b_mat)
        aLeft = pto.alloc_tile(tile_a_left)
        bRight = pto.alloc_tile(tile_b_right)
        cAcc = pto.alloc_tile(tile_c_acc)

        for bi in pto.range(b_start, b_end, c1):
            # Linear tile index -> (row block, column block), row-major.
            blk_m = bi // nblk_n
            blk_n = bi % nblk_n
            row_off = blk_m * cBM
            col_off = blk_n * cBN

            for k in pto.range(c0, K_iters, c1):
                k_off = k * cBK
                svA = pto.slice_view(
                    sv_a, source=tvA, offsets=[row_off, k_off], sizes=[cBM, cBK]
                )
                svB = pto.slice_view(
                    sv_b, source=tvB, offsets=[k_off, col_off], sizes=[cBK, cBN]
                )
                # Global -> MAT staging tile -> LEFT / RIGHT operand tile.
                pto.load(svA, aMat)
                pto.load(svB, bMat)
                tile.mov(aMat, aLeft)
                tile.mov(bMat, bRight)
                # First K step writes cAcc with matmul; later steps use
                # matmul_acc with cAcc as both input and output.
                pto.cond(
                    s.eq(k, c0),
                    lambda: tile.matmul(aLeft, bRight, cAcc),
                    lambda: tile.matmul_acc(cAcc, aLeft, bRight, cAcc),
                )

            # NOTE: per-block scale fusion (Sa[m,k] * Sb[n//128,k]) into the
            # accumulator is omitted; it requires a VEC pass over the FP32
            # accumulator per K-group. See README.md.
            svC = pto.slice_view(
                sv_c, source=tvC, offsets=[row_off, col_off], sizes=[cBM, cBN]
            )
            pto.store(cAcc, svC)


if __name__ == "__main__":
    print(fp8_gemm)
"""Reference + ctypes wrapper for the deepseek_v4 ``fp8_gemm`` PTO kernel.

Scale-fusion semantics — host-side pre-scale design
---------------------------------------------------

The GPU TileLang kernel performs::

    C[m, n] = sum_k A[m, k] * B[k, n] * Sa[m, k_g] * Sb[k_g, n_b]

where ``k_g = k // BLOCK_K`` and ``n_b = n // BLOCK_N`` index into per-block
scale tensors. On GPU this fusion is a hard requirement because FP8 has a
tiny dynamic range (±240) and pre-multiplying ``A`` by ``Sa`` would saturate
the input.

On the NPU we use **FP16** activations / weights, whose ±65504 range
comfortably accommodates pre-scaled values for any realistic ``Sa, Sb``.
We therefore pre-scale ``A`` and ``B`` host-side, then run a plain FP16
GEMM on-device. Mathematically::

    A_scaled[m, k] = A[m, k] * Sa[m, k_g]          (row-broadcast)
    B_scaled[k, n] = B[k, n] * Sb[k_g, n_b]        (col-broadcast)
    C[m, n]        = sum_k A_scaled[m, k] * B_scaled[k, n]

is identical to the GPU formulation. The kernel itself is unchanged —
``Sa, Sb`` are still GEMM inputs to keep the API contract, but they are
applied during ``fp8_gemm()`` before the kernel call. This is the
NPU-equivalent of "scale fusion".
"""

import ctypes
from pathlib import Path

import torch


_HERE = Path(__file__).resolve().parent
_KERNEL_SO = _HERE / "fp8_gemm_lib.so"

BLOCK_M = 32
BLOCK_N = 128
BLOCK_K = 128


def _prescale(a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor):
    """Apply the per-block scales to A and B in fp32, then cast back to fp16.

    Sa[M, K/BLOCK_K]         -> row-broadcast over each K-group.
    Sb[K/BLOCK_K, N/BLOCK_N] -> col-broadcast over each (K-group, N-block).

    Returns:
        ``(a_scaled, b_scaled)``, both fp16.

    Raises:
        AssertionError: if the four shapes are not mutually consistent.
    """
    M, K = a.shape
    Kb, N = b.shape
    Kg, Nb = sb.shape
    assert sa.shape == (M, Kg)
    # B's K extent used to be ignored here, so a [1, N] B would silently
    # broadcast against the expanded [K, N] scale instead of being rejected.
    assert K == Kb, f"a is [{M}, {K}] but b is [{Kb}, {N}]"
    # Fail with a readable message rather than an opaque reshape error.
    assert Kg * BLOCK_K == K and Nb * BLOCK_N == N, (
        f"sb shape {tuple(sb.shape)} does not tile b shape [{K}, {N}] "
        f"with blocks of [{BLOCK_K}, {BLOCK_N}]"
    )

    # Sa: [M, Kg] -> [M, Kg, 1] -> [M, K] via expand on the BLOCK_K axis.
    sa_exp = sa.unsqueeze(-1).expand(M, Kg, BLOCK_K).reshape(M, K)
    # Sb: [Kg, Nb] -> [Kg, 1, Nb, 1] -> [K, N] expanding on BLOCK_K, BLOCK_N.
    sb_exp = (
        sb.unsqueeze(1).unsqueeze(-1).expand(Kg, BLOCK_K, Nb, BLOCK_N).reshape(K, N)
    )

    a_s = (a.to(torch.float32) * sa_exp).to(torch.float16)
    b_s = (b.to(torch.float32) * sb_exp).to(torch.float16)
    return a_s, b_s


def fp8_gemm_ref(
    a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor
) -> torch.Tensor:
    """Reference result: pre-scale, fp32 multiply-accumulate, cast to fp16."""
    a_s, b_s = _prescale(a, b, sa, sb)
    return (a_s.to(torch.float32) @ b_s.to(torch.float32)).to(torch.float16)


# call_kernel(blockDim, stream, a, b, c, sa, sb, M, N, K)
_ARGTYPES = [
    ctypes.c_uint32,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int32,
    ctypes.c_int32,
    ctypes.c_int32,
]


_lib = None


def _load():
    """Load ``fp8_gemm_lib.so`` once and cache the ``ctypes`` handle.

    Raises:
        FileNotFoundError: if the shared library has not been built yet.
    """
    global _lib
    if _lib is None:
        if not _KERNEL_SO.is_file():
            raise FileNotFoundError(
                f"Kernel shared library not found: {_KERNEL_SO}\n"
                f"Build first: cd {_HERE} && ./compile.sh"
            )
        _lib = ctypes.CDLL(str(_KERNEL_SO))
        _lib.call_kernel.argtypes = _ARGTYPES
        _lib.call_kernel.restype = None
    return _lib


def fp8_gemm(
    a: torch.Tensor, b: torch.Tensor, sa: torch.Tensor, sb: torch.Tensor
) -> torch.Tensor:
    """Run kernel with full per-block scale fusion (host pre-scale).

    Args:
        a:  [M, K]                       fp16, NPU
        b:  [K, N]                       fp16, NPU
        sa: [M, K // BLOCK_K]            fp32, NPU
        sb: [K // BLOCK_K, N // BLOCK_N] fp32, NPU

    Returns:
        [M, N] fp16 tensor on ``a``'s device.

    Raises:
        AssertionError: on wrong device/dtype or shapes that are not
            multiples of the block sizes.
        FileNotFoundError: if the kernel library has not been built.
    """
    assert a.is_npu and b.is_npu and a.dtype == b.dtype == torch.float16
    M, K = a.shape
    Kb, N = b.shape
    assert K == Kb
    assert M % BLOCK_M == 0 and N % BLOCK_N == 0 and K % BLOCK_K == 0

    a_s, b_s = _prescale(a, b, sa, sb)
    a_s = a_s.contiguous()
    b_s = b_s.contiguous()
    # Keep the contiguous scale tensors referenced until after the
    # synchronize below. Taking data_ptr() of an unnamed ``.contiguous()``
    # temporary hands the kernel a pointer whose storage may already have
    # been released when the scale tensor was not contiguous to begin with.
    sa_c = sa.contiguous()
    sb_c = sb.contiguous()

    c = torch.empty((M, N), dtype=torch.float16, device=a.device)
    lib = _load()
    dev = torch.npu.current_device()
    blk = torch.npu.get_device_properties(dev).cube_core_num
    lib.call_kernel(
        blk,
        torch.npu.current_stream()._as_parameter_,
        ctypes.c_void_p(a_s.data_ptr()),
        ctypes.c_void_p(b_s.data_ptr()),
        ctypes.c_void_p(c.data_ptr()),
        ctypes.c_void_p(sa_c.data_ptr()),  # not read by kernel
        ctypes.c_void_p(sb_c.data_ptr()),  # not read by kernel
        ctypes.c_int32(M),
        ctypes.c_int32(N),
        ctypes.c_int32(K),
    )
    torch.npu.synchronize()
    return c
def _check(M, N, K, sa_unit: bool, device, seed: int):
    """Compare the PTO kernel with the reference for one ``(M, N, K)`` shape.

    ``sa_unit`` selects all-ones scales for both Sa and Sb; otherwise both
    are drawn as ``exp(randn)``. Raises (via ``assert_close``) on mismatch.
    """
    torch.manual_seed(seed)
    k_groups = K // BLOCK_K
    n_blocks = N // BLOCK_N

    # RNG draw order matters for reproducibility: a, b, then sa, sb.
    a = (torch.randn(M, K, device=device) * 0.1).to(torch.float16)
    b = (torch.randn(K, N, device=device) * 0.1).to(torch.float16)

    if not sa_unit:
        sa = torch.randn(M, k_groups, device=device).exp().to(torch.float32)
        sb = torch.randn(k_groups, n_blocks, device=device).exp().to(torch.float32)
        tag = "rand-scales"
    else:
        sa = torch.ones(M, k_groups, device=device, dtype=torch.float32)
        sb = torch.ones(k_groups, n_blocks, device=device, dtype=torch.float32)
        tag = "unit-scales"

    actual = fp8_gemm(a, b, sa, sb)
    expected = fp8_gemm_ref(a, b, sa, sb)
    torch.testing.assert_close(actual, expected, rtol=2e-2, atol=2e-2)
    print(f"fp8_gemm M={M} N={N} K={K} {tag}: OK")


def main() -> int:
    """Run every shape case on the test device; return 0 when all pass."""
    device = get_test_device()
    torch.npu.set_device(device)

    # (M, N, K) expressed as multiples of the block sizes.
    multiples = [(1, 1, 1), (2, 2, 2), (4, 1, 4), (1, 4, 1)]
    for seed, (m_mul, n_mul, k_mul) in enumerate(multiples):
        _check(
            BLOCK_M * m_mul,
            BLOCK_N * n_mul,
            BLOCK_K * k_mul,
            sa_unit=False,
            device=device,
            seed=seed,
        )

    _check(BLOCK_M * 2, BLOCK_N, BLOCK_K * 2, sa_unit=True, device=device, seed=42)
    print("fp8_gemm: all shapes PASSED")
    return 0


if __name__ == "__main__":
    sys.exit(main())
(compile.sh outputs) +hc_split_sinkhorn.pto +hc_split_sinkhorn.cpp +hc_split_sinkhorn_lib.so + +# Python cache +__pycache__/ + +# Benchmark scratch +perf_data/ diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/README.md b/examples/aot/deepseek_v4/hc_split_sinkhorn/README.md new file mode 100644 index 00000000..9dc8fc40 --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/README.md @@ -0,0 +1,12 @@ +# hc_split_sinkhorn — fused MoE-router head + +PTO DSL port of the deepseek_v4 `hc_split_sinkhorn` op. One +`vector_section` runs three heads in fusion: `pre = sigmoid(...) + ε`, +`post = 2 * sigmoid(...)`, and a 20-iter row/col-normalising Sinkhorn +over a `[n, 4, 4]` mix tensor. + +```bash +bash ./compile.sh +python ./run_hc_split_sinkhorn.py +python ./bench_hc_split_sinkhorn.py +``` diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/bench_hc_split_sinkhorn.py b/examples/aot/deepseek_v4/hc_split_sinkhorn/bench_hc_split_sinkhorn.py new file mode 100644 index 00000000..89235a6d --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/bench_hc_split_sinkhorn.py @@ -0,0 +1,74 @@ +"""Microbenchmark for the deepseek_v4 ``hc_split_sinkhorn`` PTO kernel. + +Compares the fused on-device PTO kernel against the PyTorch reference +(pre/post sigmoid heads + 20-iter sinkhorn) over a sweep of batch +sizes ``n``. 
+ +Run:: + + cd examples/aot/deepseek_v4/hc_split_sinkhorn + bash compile.sh + python bench_hc_split_sinkhorn.py +""" + +import sys +from pathlib import Path + +import torch +import torch_npu # noqa: F401 + +from ptodsl import do_bench +from ptodsl.utils.npu_info import get_test_device + +_HERE = Path(__file__).resolve().parent +if str(_HERE) not in sys.path: + sys.path.insert(0, str(_HERE)) + +from hc_split_sinkhorn_util import ( # noqa: E402 + HC, + MIX_HC, + _KERNEL_SO, + hc_split_sinkhorn, + hc_split_sinkhorn_ref, +) + + +BATCHES = [64, 256, 1024, 4096, 16384] + + +def _alloc(n, device): + torch.manual_seed(0) + mixes = torch.randn(n, MIX_HC, dtype=torch.float32, device=device) + hc_scale = torch.randn(3, dtype=torch.float32, device=device) + hc_base = torch.randn(MIX_HC, dtype=torch.float32, device=device) + return mixes, hc_scale, hc_base + + +def main(): + if not _KERNEL_SO.is_file(): + raise SystemExit(f"Build kernel first: cd {_HERE} && bash compile.sh") + device = get_test_device() + torch.npu.set_device(device) + + print(f"{'n':>7} {'pto us':>10} {'ref us':>10} {'speedup':>8}") + print("-" * 40) + for n in BATCHES: + mixes, scale, base = _alloc(n, device) + pto_us = do_bench( + lambda: hc_split_sinkhorn(mixes, scale, base), + warmup_iters=5, + benchmark_iters=50, + unit="us", + ) + ref_us = do_bench( + lambda: hc_split_sinkhorn_ref(mixes, scale, base), + warmup_iters=5, + benchmark_iters=50, + unit="us", + ) + speedup = ref_us / pto_us + print(f"{n:>7} {pto_us:>10.2f} {ref_us:>10.2f} {speedup:>7.2f}x") + + +if __name__ == "__main__": + main() diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/caller.cpp b/examples/aot/deepseek_v4/hc_split_sinkhorn/caller.cpp new file mode 100644 index 00000000..17e87196 --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/caller.cpp @@ -0,0 +1,24 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "hc_split_sinkhorn.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, 
+ uint8_t *mixes, + uint8_t *hc_scale, + uint8_t *hc_base, + uint8_t *pre, + uint8_t *post, + uint8_t *comb, + int32_t n) +{ + hc_split_sinkhorn<<>>( + (float *)mixes, + (float *)hc_scale, + (float *)hc_base, + (float *)pre, + (float *)post, + (float *)comb, + n); +} diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/compile.sh b/examples/aot/deepseek_v4/hc_split_sinkhorn/compile.sh new file mode 100644 index 00000000..60e6c1f8 --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f hc_split_sinkhorn.pto hc_split_sinkhorn.cpp hc_split_sinkhorn_lib.so + +python ./hc_split_sinkhorn_builder.py > ./hc_split_sinkhorn.pto +ptoas --enable-insert-sync ./hc_split_sinkhorn.pto -o ./hc_split_sinkhorn.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./hc_split_sinkhorn_lib.so diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/hc_split_sinkhorn_builder.py b/examples/aot/deepseek_v4/hc_split_sinkhorn/hc_split_sinkhorn_builder.py new file mode 100644 index 00000000..cfcedff2 --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/hc_split_sinkhorn_builder.py @@ -0,0 +1,271 @@ +"""PTO DSL port of TileLang ``hc_split_sinkhorn`` — full on-device kernel. 
def meta_data():
    """Build the type table handed to ``to_ir_module``.

    Every local defined here is returned via ``locals()``; the kernel
    signature refers to ``"ptr_fp32"`` / ``"i32"`` by name and the kernel
    body uses the tensor, sub-tensor and tile types directly.
    """
    fp32 = pto.float32
    i32 = pto.int32
    ptr_fp32 = pto.PtrType(fp32)

    tv2 = pto.TensorType(rank=2, dtype=fp32)
    tv3 = pto.TensorType(rank=3, dtype=fp32)

    # One row of HC values (pre / post heads) and one HC x HC block (comb).
    sv_row = pto.SubTensorType(shape=[1, HC], dtype=fp32)
    sv_kk = pto.SubTensorType(shape=[HC, HC], dtype=fp32)

    row_cfg = pto.TileBufConfig()
    col_cfg = pto.TileBufConfig(blayout="ColMajor")

    # All tiles are allocated at TILE_DIM and used through HC-sized subviews
    # or valid shapes.
    tile_full = pto.TileBufType(
        shape=[TILE_DIM, TILE_DIM], dtype=fp32, memory_space="VEC", config=row_cfg
    )
    # Per-row statistic ([TILE_DIM, 1], column-major) and per-column
    # statistic ([1, TILE_DIM]); valid extents are supplied at alloc time.
    tile_row_stat = pto.TileBufType(
        shape=[TILE_DIM, 1],
        valid_shape=[-1, -1],
        dtype=fp32,
        memory_space="VEC",
        config=col_cfg,
    )
    tile_col_stat = pto.TileBufType(
        shape=[1, TILE_DIM],
        valid_shape=[-1, -1],
        dtype=fp32,
        memory_space="VEC",
        config=row_cfg,
    )
    return locals()


# Computes, for each sample i in [0, n):
#   pre[i]  = sigmoid(mixes[i, :HC] * scale[0] + base[:HC]) + EPS
#   post[i] = 2 * sigmoid(mixes[i, HC:2HC] * scale[1] + base[HC:2HC])
#   comb[i] = Sinkhorn-normalised (mixes[i, 2HC:] * scale[2] + base[2HC:])
#             viewed as HC x HC
# Samples are distributed round-robin: worker ``wid`` handles
# i = wid, wid + ncores, wid + 2 * ncores, ...
@to_ir_module(meta_data=meta_data)
def hc_split_sinkhorn(
    mixes_ptr: "ptr_fp32",
    hc_scale_ptr: "ptr_fp32",
    hc_base_ptr: "ptr_fp32",
    pre_ptr: "ptr_fp32",
    post_ptr: "ptr_fp32",
    comb_ptr: "ptr_fp32",
    n_i32: "i32",
) -> None:
    c0 = const(0)
    c1 = const(1)
    c2 = const(2)
    cHC = const(HC)
    cMIX = const(MIX_HC)
    c2HC = const(2 * HC)
    cHCHC = const(HC * HC)

    eps = const(EPS, s.float32)
    one_f = const(1.0, s.float32)
    neg1_f = const(-1.0, s.float32)
    two_f = const(2.0, s.float32)

    n = s.index_cast(n_i32)

    with pto.vector_section():
        # Flatten (block, sub-block) into one worker id / worker count.
        cid = pto.get_block_idx()
        sub_bid = pto.get_subblock_idx()
        sub_bnum = pto.get_subblock_num()
        num_blocks = pto.get_block_num()
        wid = s.index_cast(cid * sub_bnum + sub_bid)
        ncores = s.index_cast(num_blocks * sub_bnum)

        # Pointer slicing for the three head regions of mixes / hc_base.
        mixes_post_ptr = pto.add_ptr(mixes_ptr, cHC)
        mixes_comb_ptr = pto.add_ptr(mixes_ptr, c2HC)
        base_post_ptr = pto.add_ptr(hc_base_ptr, cHC)
        base_comb_ptr = pto.add_ptr(hc_base_ptr, c2HC)

        # mixes views: pre/post are [n, HC] columns of [n, MIX_HC]
        # (row stride = MIX_HC); comb is [n, HC, HC] viewing the trailing
        # HC*HC contiguous block per sample.
        tv_mix_pre = pto.as_tensor(
            tv2, ptr=mixes_ptr, shape=[n, cHC], strides=[cMIX, c1]
        )
        tv_mix_post = pto.as_tensor(
            tv2, ptr=mixes_post_ptr, shape=[n, cHC], strides=[cMIX, c1]
        )
        tv_mix_comb = pto.as_tensor(
            tv3,
            ptr=mixes_comb_ptr,
            shape=[n, cHC, cHC],
            strides=[cMIX, cHC, c1],
        )

        # Base tensors (loaded once per worker, contiguous within hc_base).
        tv_base_pre = pto.as_tensor(
            tv2, ptr=hc_base_ptr, shape=[c1, cHC], strides=[cHC, c1]
        )
        tv_base_post = pto.as_tensor(
            tv2, ptr=base_post_ptr, shape=[c1, cHC], strides=[cHC, c1]
        )
        tv_base_comb = pto.as_tensor(
            tv2, ptr=base_comb_ptr, shape=[cHC, cHC], strides=[cHC, c1]
        )

        # Output tensors.
        tv_pre_out = pto.as_tensor(tv2, ptr=pre_ptr, shape=[n, cHC], strides=[cHC, c1])
        tv_post_out = pto.as_tensor(
            tv2, ptr=post_ptr, shape=[n, cHC], strides=[cHC, c1]
        )
        tv_comb_out = pto.as_tensor(
            tv3,
            ptr=comb_ptr,
            shape=[n, cHC, cHC],
            strides=[cHCHC, cHC, c1],
        )

        # Tile buffers: full TILE_DIM × TILE_DIM with valid sub-rectangles.
        pre_full = pto.alloc_tile(tile_full)
        post_full = pto.alloc_tile(tile_full)
        comb_full = pto.alloc_tile(tile_full)
        scratch_full = pto.alloc_tile(tile_full)
        base_pre_full = pto.alloc_tile(tile_full)
        base_post_full = pto.alloc_tile(tile_full)
        base_comb_full = pto.alloc_tile(tile_full)
        row_stat = pto.alloc_tile(tile_row_stat, valid_row=cHC, valid_col=c1)
        col_stat = pto.alloc_tile(tile_col_stat, valid_row=c1, valid_col=cHC)

        pre_sv = tile.subview(pre_full, [c0, c0], [1, HC])
        post_sv = tile.subview(post_full, [c0, c0], [1, HC])
        comb_kk = tile.subview(comb_full, [c0, c0], [HC, HC])
        scratch_kk = tile.subview(scratch_full, [c0, c0], [HC, HC])
        # 1xHC scratch slot reused as the reciprocal destination for sigmoid
        # (pto.trecip requires src != dst).
        # It aliases the first row of scratch_kk; the two are used in
        # different phases of the per-sample loop below.
        recip_sv = tile.subview(scratch_full, [c0, c0], [1, HC])
        base_pre_sv = tile.subview(base_pre_full, [c0, c0], [1, HC])
        base_post_sv = tile.subview(base_post_full, [c0, c0], [1, HC])
        base_comb_sv = tile.subview(base_comb_full, [c0, c0], [HC, HC])

        # Load bases once per worker.
        base_pre_view = pto.slice_view(
            sv_row, source=tv_base_pre, offsets=[c0, c0], sizes=[c1, cHC]
        )
        base_post_view = pto.slice_view(
            sv_row, source=tv_base_post, offsets=[c0, c0], sizes=[c1, cHC]
        )
        base_comb_view = pto.slice_view(
            sv_kk, source=tv_base_comb, offsets=[c0, c0], sizes=[cHC, cHC]
        )
        pto.load(base_pre_view, base_pre_sv)
        pto.load(base_post_view, base_post_sv)
        pto.load(base_comb_view, base_comb_sv)

        # Load the three scalar scales.
        s0 = pto.load_scalar(s.float32, hc_scale_ptr, c0)
        s1 = pto.load_scalar(s.float32, hc_scale_ptr, c1)
        s2 = pto.load_scalar(s.float32, hc_scale_ptr, c2)

        for i in pto.range(wid, n, ncores):
            # ---- pre = sigmoid(mixes[i, :HC] * s0 + base[:HC]) + eps ----
            pre_in = pto.slice_view(
                sv_row, source=tv_mix_pre, offsets=[i, c0], sizes=[c1, cHC]
            )
            pto.load(pre_in, pre_sv)
            tile.muls(pre_sv, s0, pre_sv)
            tile.add(pre_sv, base_pre_sv, pre_sv)
            # sigmoid(x) = 1 / (1 + exp(-x))
            tile.muls(pre_sv, neg1_f, pre_sv)
            tile.exp(pre_sv, pre_sv)
            tile.adds(pre_sv, one_f, pre_sv)
            tile.reciprocal(pre_sv, recip_sv)
            tile.adds(recip_sv, eps, pre_sv)
            pre_out_view = pto.slice_view(
                sv_row, source=tv_pre_out, offsets=[i, c0], sizes=[c1, cHC]
            )
            pto.store(pre_sv, pre_out_view)

            # ---- post = 2 * sigmoid(mixes[i, HC:2HC] * s1 + base[HC:2HC]) ----
            post_in = pto.slice_view(
                sv_row, source=tv_mix_post, offsets=[i, c0], sizes=[c1, cHC]
            )
            pto.load(post_in, post_sv)
            tile.muls(post_sv, s1, post_sv)
            tile.add(post_sv, base_post_sv, post_sv)
            tile.muls(post_sv, neg1_f, post_sv)
            tile.exp(post_sv, post_sv)
            tile.adds(post_sv, one_f, post_sv)
            tile.reciprocal(post_sv, recip_sv)
            tile.muls(recip_sv, two_f, post_sv)
            post_out_view = pto.slice_view(
                sv_row, source=tv_post_out, offsets=[i, c0], sizes=[c1, cHC]
            )
            pto.store(post_sv, post_out_view)

            # ---- comb: scale + base, then sinkhorn-normalize ----
            comb_in_view = pto.slice_view(
                sv_kk,
                source=tv_mix_comb,
                offsets=[i, c0, c0],
                sizes=[c1, cHC, cHC],
            )
            pto.load(comb_in_view, comb_kk)
            tile.muls(comb_kk, s2, comb_kk)
            tile.add(comb_kk, base_comb_sv, comb_kk)

            # comb = softmax(comb, dim=-1) + eps
            # (row max is subtracted before exp, then rows are divided by
            # their sum)
            tile.row_max(comb_kk, scratch_kk, row_stat)
            tile.row_expand_sub(comb_kk, row_stat, comb_kk)
            tile.exp(comb_kk, comb_kk)
            tile.row_sum(comb_kk, scratch_kk, row_stat)
            tile.row_expand_div(comb_kk, row_stat, comb_kk)
            tile.adds(comb_kk, eps, comb_kk)

            # comb /= (comb.sum(-2) + eps)
            tile.col_sum(comb_kk, scratch_kk, col_stat)
            tile.adds(col_stat, eps, col_stat)
            tile.col_expand_div(comb_kk, col_stat, comb_kk)

            # The loop starts at 1, so it runs SINKHORN_ITERS - 1 further
            # row + column passes after the softmax / column pass above.
            for _ in pto.range(c1, const(SINKHORN_ITERS), c1):
                tile.row_sum(comb_kk, scratch_kk, row_stat)
                tile.adds(row_stat, eps, row_stat)
                tile.row_expand_div(comb_kk, row_stat, comb_kk)

                tile.col_sum(comb_kk, scratch_kk, col_stat)
                tile.adds(col_stat, eps, col_stat)
                tile.col_expand_div(comb_kk, col_stat, comb_kk)

            comb_out_view = pto.slice_view(
                sv_kk,
                source=tv_comb_out,
                offsets=[i, c0, c0],
                sizes=[c1, cHC, cHC],
            )
            pto.store(comb_kk, comb_out_view)


if __name__ == "__main__":
    print(hc_split_sinkhorn)
+""" + +import ctypes +from pathlib import Path + +import torch + + +_HERE = Path(__file__).resolve().parent +_KERNEL_SO = _HERE / "hc_split_sinkhorn_lib.so" + +HC = 4 +MIX_HC = (2 + HC) * HC # 24 +SINKHORN_ITERS = 20 +EPS = 1e-6 + + +def _sinkhorn(x: torch.Tensor, iters: int = SINKHORN_ITERS, eps: float = EPS): + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(iters - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x + + +def hc_split_sinkhorn_ref( + mixes: torch.Tensor, # [n, MIX_HC] fp32 + hc_scale: torch.Tensor, # [3] fp32 + hc_base: torch.Tensor, # [MIX_HC] fp32 +): + n = mixes.shape[0] + pre_in = mixes[:, :HC] * hc_scale[0] + hc_base[:HC] + post_in = mixes[:, HC : 2 * HC] * hc_scale[1] + hc_base[HC : 2 * HC] + comb_in = (mixes[:, 2 * HC :] * hc_scale[2] + hc_base[2 * HC :]).reshape(n, HC, HC) + + pre = torch.sigmoid(pre_in) + EPS + post = 2.0 * torch.sigmoid(post_in) + comb = _sinkhorn(comb_in) + return pre, post, comb + + +_ARGTYPES = [ + ctypes.c_uint32, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_void_p, + ctypes.c_int32, +] + + +_lib = None + + +def _load(): + global _lib + if _lib is None: + if not _KERNEL_SO.is_file(): + raise FileNotFoundError( + f"Kernel shared library not found: {_KERNEL_SO}\n" + f"Build first: cd {_HERE} && ./compile.sh" + ) + _lib = ctypes.CDLL(str(_KERNEL_SO)) + _lib.call_kernel.argtypes = _ARGTYPES + _lib.call_kernel.restype = None + return _lib + + +def hc_split_sinkhorn( + mixes: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor +): + """Run the device kernel, returning (pre, post, comb) entirely produced + on-device.""" + assert mixes.is_npu and mixes.dtype == torch.float32 + assert hc_scale.is_npu and hc_scale.dtype == torch.float32 + assert hc_base.is_npu and hc_base.dtype == torch.float32 + assert mixes.dim() == 2 and mixes.shape[1] == MIX_HC + assert 
hc_scale.shape == (3,) + assert hc_base.shape == (MIX_HC,) + + n = mixes.shape[0] + mixes_c = mixes.contiguous() + hc_scale_c = hc_scale.contiguous() + hc_base_c = hc_base.contiguous() + + pre = torch.empty((n, HC), dtype=torch.float32, device=mixes.device) + post = torch.empty((n, HC), dtype=torch.float32, device=mixes.device) + comb = torch.empty((n, HC, HC), dtype=torch.float32, device=mixes.device) + + lib = _load() + dev = torch.npu.current_device() + blk = torch.npu.get_device_properties(dev).cube_core_num + lib.call_kernel( + blk, + torch.npu.current_stream()._as_parameter_, + ctypes.c_void_p(mixes_c.data_ptr()), + ctypes.c_void_p(hc_scale_c.data_ptr()), + ctypes.c_void_p(hc_base_c.data_ptr()), + ctypes.c_void_p(pre.data_ptr()), + ctypes.c_void_p(post.data_ptr()), + ctypes.c_void_p(comb.data_ptr()), + ctypes.c_int32(n), + ) + torch.npu.synchronize() + return pre, post, comb diff --git a/examples/aot/deepseek_v4/hc_split_sinkhorn/run_hc_split_sinkhorn.py b/examples/aot/deepseek_v4/hc_split_sinkhorn/run_hc_split_sinkhorn.py new file mode 100644 index 00000000..fd03b9ef --- /dev/null +++ b/examples/aot/deepseek_v4/hc_split_sinkhorn/run_hc_split_sinkhorn.py @@ -0,0 +1,45 @@ +"""Run the deepseek_v4 ``hc_split_sinkhorn`` PTO kernel and validate +against the reference. 
def main() -> int:
    """Compare the PTO kernel with the PyTorch reference on several sizes.

    Returns 0 when every shape matches; ``torch.testing.assert_close``
    raises on the first mismatch.
    """
    device = get_test_device()
    torch.npu.set_device(device)
    torch.manual_seed(0)

    for n in (16, 64, 256, 1024):
        mixes = torch.randn(n, MIX_HC, dtype=torch.float32, device=device)
        hc_scale = torch.randn(3, dtype=torch.float32, device=device) * 0.5
        hc_base = torch.randn(MIX_HC, dtype=torch.float32, device=device) * 0.1

        pre_pto, post_pto, comb_pto = hc_split_sinkhorn(mixes, hc_scale, hc_base)
        pre_ref, post_ref, comb_ref = hc_split_sinkhorn_ref(mixes, hc_scale, hc_base)

        # (actual, expected, rtol): comb gets the looser relative tolerance.
        checks = (
            (pre_pto, pre_ref, 1e-4),
            (post_pto, post_ref, 1e-4),
            (comb_pto, comb_ref, 1e-3),
        )
        for actual, expected, rtol in checks:
            torch.testing.assert_close(actual, expected, rtol=rtol, atol=1e-5)
        print(f"hc_split_sinkhorn n={n}: OK")

    print("hc_split_sinkhorn: all shapes PASSED")
    return 0
def fused_sparse_attn(q, kv, idx, scale):
    """Speed baseline: gather the selected KV rows, then run fused attention.

    The per-head sink logit is not passed to
    ``npu_fused_infer_attention_score``, so this is not a numerical
    reference for the PTO kernel.

    Args:
        q:     [B, M, H_PAD, D] fp16
        kv:    [B, N, D] fp16 (single KV head)
        idx:   [B, M, K] int32 (positions into N)
        scale: float

    Returns:
        out:   [B, M, H_PAD, D] fp16
    """
    n_batch, n_query, n_head, head_dim = q.shape
    topk = idx.shape[-1]

    # Advanced indexing kv[b, idx[b, m]] -> [B, M, K, D].
    rows = idx.to(torch.long)
    batch_ids = (
        torch.arange(n_batch, device=q.device)
        .view(n_batch, 1, 1)
        .expand(n_batch, n_query, topk)
    )
    gathered = kv[batch_ids, rows]

    # Fold (B, M) into a single batch axis; each query is a sequence of 1.
    flat = n_batch * n_query
    q_bsh = q.reshape(flat, 1, n_head * head_dim).contiguous()
    kv_bsh = gathered.reshape(flat, topk, head_dim).contiguous()

    # The same gathered rows serve as both keys and values.
    attn_out, _ = torch_npu.npu_fused_infer_attention_score(
        q_bsh,
        kv_bsh,
        kv_bsh,
        num_heads=n_head,
        num_key_value_heads=1,
        input_layout="BSH",
        scale=scale,
    )
    return attn_out.reshape(n_batch, n_query, n_head, head_dim)
def _alloc(B, M, N, K, device):
    """Create deterministic benchmark inputs for one ``(B, M, N, K)`` shape.

    Reseeds the global RNG with 0 on every call, then returns
    ``(q, kv, sink, idx, scale)``: fp16 ``q`` [B, M, H_PAD, D], fp16 ``kv``
    [B, N, D], fp32 ``sink`` [H_PAD], int32 ``idx`` [B, M, K] drawn from
    ``[0, N)``, and ``scale = 1 / sqrt(D)``.
    """
    torch.manual_seed(0)

    def _randn_fp16(*shape):
        # Sample on the target device, then cast down to fp16.
        return torch.randn(*shape, device=device).to(torch.float16)

    queries = _randn_fp16(B, M, H_PAD, D)
    keys_values = _randn_fp16(B, N, D)
    sink_logits = torch.randn(H_PAD, dtype=torch.float32, device=device)
    topk = torch.randint(0, N, (B, M, K), dtype=torch.int32, device=device)
    softmax_scale = 1.0 / (D**0.5)
    return queries, keys_values, sink_logits, topk, softmax_scale
{ref_us:>10.2f} {fused_str}" + f" {ref_us / pto_us:>8.2f}x {fused_sp}" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/aot/deepseek_v4/sparse_attn/caller.cpp b/examples/aot/deepseek_v4/sparse_attn/caller.cpp new file mode 100644 index 00000000..1b2c1ed9 --- /dev/null +++ b/examples/aot/deepseek_v4/sparse_attn/caller.cpp @@ -0,0 +1,17 @@ +#ifndef KERNEL_CPP +#define KERNEL_CPP "sparse_attn.cpp" +#endif +#include KERNEL_CPP + +extern "C" void call_kernel( + uint32_t blockDim, void *stream, + uint8_t *q, uint8_t *kv, uint8_t *o, + uint8_t *attn_sink, uint8_t *topk_idxs, + int32_t B, int32_t M, int32_t N, int32_t TOPK, + float scale) +{ + sparse_attn<<>>( + (__fp16 *)q, (__fp16 *)kv, (__fp16 *)o, + (float *)attn_sink, (int32_t *)topk_idxs, + B, M, N, TOPK, scale); +} diff --git a/examples/aot/deepseek_v4/sparse_attn/compile.sh b/examples/aot/deepseek_v4/sparse_attn/compile.sh new file mode 100644 index 00000000..0022d537 --- /dev/null +++ b/examples/aot/deepseek_v4/sparse_attn/compile.sh @@ -0,0 +1,22 @@ +#!/usr/bin/env bash +set -e +rm -f sparse_attn.pto sparse_attn.cpp sparse_attn_lib.so + +python ./sparse_attn_builder.py > ./sparse_attn.pto +ptoas --enable-insert-sync ./sparse_attn.pto -o ./sparse_attn.cpp + +PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa} +bisheng \ + -I${PTO_LIB_PATH}/include \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + ./caller.cpp \ + -o ./sparse_attn_lib.so diff --git a/examples/aot/deepseek_v4/sparse_attn/run_sparse_attn.py b/examples/aot/deepseek_v4/sparse_attn/run_sparse_attn.py new file mode 100644 index 00000000..f65a503f --- 
def main() -> int:
    """Validate the PTO ``sparse_attn`` kernel against the reference.

    Covers full-width heads, head counts below ``H_PAD`` and cases where a
    fraction of the top-k slots is replaced by the ``-1`` sentinel.
    Returns 0 on success; ``torch.testing.assert_close`` raises on the
    first mismatch.
    """
    device = get_test_device()
    torch.npu.set_device(device)
    torch.manual_seed(0)
    scale = 1.0 / (D**0.5)

    # Each row: (B, M, N, K, H, sentinel_frac)
    full_head_cases = [
        (1, 1, BLOCK * 2, BLOCK, H_PAD, 0.0),
        (1, 4, BLOCK * 4, BLOCK * 2, H_PAD, 0.0),
        (2, 2, BLOCK * 8, BLOCK * 2, H_PAD, 0.0),
        (8, 2, BLOCK * 4, BLOCK * 2, H_PAD, 0.0),
    ]
    # Fewer than H_PAD heads: the wrapper pads them up to H_PAD.
    padded_head_cases = [
        (4, 2, BLOCK * 4, BLOCK * 2, 8, 0.0),
        (4, 2, BLOCK * 4, BLOCK, 1, 0.0),
    ]
    # Roughly a quarter of the top-k slots become the -1 sentinel.
    sentinel_cases = [
        (4, 4, BLOCK * 4, BLOCK * 2, H_PAD, 0.25),
        (8, 2, BLOCK * 4, BLOCK, 4, 0.25),
    ]

    for B, M, N, K, H, sentinel_frac in (
        full_head_cases + padded_head_cases + sentinel_cases
    ):
        q = torch.randn(B, M, H, D, device=device).to(torch.float16)
        kv = torch.randn(B, N, D, device=device).to(torch.float16)
        attn_sink = torch.randn(H, dtype=torch.float32, device=device)

        # One independent permutation prefix per (b, m) query.
        per_query = [torch.randperm(N, device=device)[:K] for _ in range(B * M)]
        topk_idxs = torch.stack(per_query).reshape(B, M, K).to(torch.int32)
        if sentinel_frac > 0.0:
            mask = torch.rand(B, M, K, device=device) < sentinel_frac
            topk_idxs = torch.where(mask, torch.full_like(topk_idxs, -1), topk_idxs)

        o_pto = sparse_attn(q, kv, attn_sink, topk_idxs, scale)
        o_ref = sparse_attn_ref(q, kv, attn_sink, topk_idxs, scale)
        torch.testing.assert_close(o_pto, o_ref, rtol=5e-3, atol=5e-3)
        print(
            f"sparse_attn B={B} M={M} N={N} K={K} H={H} "
            f"sentinel={sentinel_frac}: OK"
        )

    print("sparse_attn: all shapes PASSED")
    return 0

Per-head softmax state is **stored as full ``[H, D]`` tiles**
(replicated across the D axis) rather than as ``[H, 1]`` reductions —
this dodges the col-major⇄row-major reshape aliasing that
auto-sync analysis can miss. The replicated form is cheap on the
Ascend vector pipe: each per-head ``[H, 1]`` col-major reduction is
broadcast back to ``[H, D]`` with one ``row_expand_mul`` against an
all-ones tile (plain ``row_expand`` needs a row-major source), which
lets every elementwise softmax op operate on plain row-major tiles.

KV gather: each of the K positions is loaded individually from GM by

  1) ``pto.load_scalar(int32, idx_ptr, off)`` — read the index
     (a negative index, e.g. the ``-1`` sentinel, skips the position)
  2) ``pto.slice_view`` with that dynamic row offset
  3) ``pto.load`` of one ``[1, D]`` row.

Online streaming softmax with the canonical FlashAttention
initialisation: ``m_prev = -1e30`` (a finite stand-in for ``-inf``),
``l_run = 0``, ``acc_o = 0``. Every valid position applies the same
exp-rescaled update; the sink logit is folded into ``l_run`` once,
after the top-k loop.

Tile shapes (H = H_PAD = 16, D = 128, all VEC, fp32 unless noted):

  q_tile        [H, D] fp16    — loaded once per query
  q_fp32        [H, D]
  kv_row_fp16   [1, D] fp16    — one position at a time
  kv_row_fp32   [1, D]
  kv_row_HD     [H, D]         — kv_row broadcast across heads
  tmp_HD        [H, D]         — q*kv_row scratch / outer scratch
  acc_o         [H, D]         — running output
  ones_HD       [H, D]         — all-ones tile used for broadcasts
  out_tile      [H, D] fp16
  logit_col     [H, 1] col     — per-head dot-product result
  red_tmp       [H, D]         — row_sum scratch (same shape as its source)
  sink_col      [H, 1] col     — attn_sink loaded once

  m_prev_HD     [H, D] ┐
  m_new_HD      [H, D] │ per-head softmax stats, replicated
  exp_diff_HD   [H, D] │ across the D axis (m_prev_HD[h, d]
  p_HD          [H, D] │ is the same scalar for every d).
def meta_data():
    """Declare the PTO types used by the ``sparse_attn`` kernel.

    Returns ``locals()``, i.e. a dict keyed by the local variable names
    below.  The kernel signature refers to these names through its string
    annotations (``"ptr_fp16"``, ``"i32"``, ``"fp32"``, ...), and the kernel
    body uses the remaining names (``tv_*``, ``sv_*``, ``tile_*``) as free
    variables — presumably injected by ``to_ir_module``.  Renaming a local
    here therefore changes the interface seen by the kernel.
    """
    fp16 = pto.float16
    fp32 = pto.float32
    i32 = pto.int32

    # Pointer types for the kernel arguments.
    ptr_fp16 = pto.PtrType(fp16)
    ptr_fp32 = pto.PtrType(fp32)
    ptr_i32 = pto.PtrType(i32)

    # Rank-2 global-memory tensor views.
    tv_fp16_2d = pto.TensorType(rank=2, dtype=fp16)
    tv_fp32_2d = pto.TensorType(rank=2, dtype=fp32)

    # Slice shapes: one query (all heads), one KV row, the sink column,
    # and one output block.
    sv_qrow = pto.SubTensorType(shape=[H_PAD, D], dtype=fp16)
    sv_kvrow = pto.SubTensorType(shape=[1, D], dtype=fp16)
    sv_sink_col = pto.SubTensorType(shape=[H_PAD, 1], dtype=fp32)
    sv_orow = pto.SubTensorType(shape=[H_PAD, D], dtype=fp16)

    row_cfg = pto.TileBufConfig()
    # row_sum dst must be col-major [H, 1] (NoneBox). row_expand_mul
    # accepts col-major src1 (verified in fa_builder). We avoid plain
    # `row_expand` (which needs row-major src) by broadcasting via
    # row_expand_mul with a constant-ones tile.
    col_cfg = pto.TileBufConfig(blayout="ColMajor", slayout="NoneBox")

    tile_q_fp16 = pto.TileBufType(
        shape=[H_PAD, D], dtype=fp16, memory_space="VEC", config=row_cfg
    )
    tile_HD_fp32 = pto.TileBufType(
        shape=[H_PAD, D], dtype=fp32, memory_space="VEC", config=row_cfg
    )
    tile_o_fp16 = pto.TileBufType(
        shape=[H_PAD, D], dtype=fp16, memory_space="VEC", config=row_cfg
    )
    tile_kv_fp16 = pto.TileBufType(
        shape=[1, D], dtype=fp16, memory_space="VEC", config=row_cfg
    )
    tile_kv_fp32 = pto.TileBufType(
        shape=[1, D], dtype=fp32, memory_space="VEC", config=row_cfg
    )
    tile_col_stat = pto.TileBufType(
        shape=[H_PAD, 1],
        dtype=fp32,
        memory_space="VEC",
        config=col_cfg,
    )
    return locals()


# Sparse-attention kernel: one (b, m) query per loop iteration.
#
# Arguments (as used by the body below):
#   q_ptr / o_ptr  fp16, viewed as a [B * M * H_PAD, D] row-major tensor
#   kv_ptr         fp16, viewed as a [B * N, D] row-major tensor
#   sink_ptr       fp32, viewed as [H_PAD, 1]
#   idx_ptr        int32, read at offset (bm * TOPK + k); a negative value
#                  skips that top-k slot
#   B/M/N/TOPK     int32 sizes, cast to index type
#   scale_f32      multiplier applied to every q·kv logit
#
# Work split: worker id = block_idx * subblock_num + subblock_idx, and each
# worker handles queries wid, wid + ncores, wid + 2 * ncores, ...
@to_ir_module(meta_data=meta_data)
def sparse_attn(
    q_ptr: "ptr_fp16",
    kv_ptr: "ptr_fp16",
    o_ptr: "ptr_fp16",
    sink_ptr: "ptr_fp32",
    idx_ptr: "ptr_i32",
    B_i32: "i32",
    M_i32: "i32",
    N_i32: "i32",
    TOPK_i32: "i32",
    scale_f32: "fp32",
) -> None:
    c0 = const(0)
    c1 = const(1)
    cH = const(H_PAD)
    cD = const(D)

    f0 = const(0.0, s.float32)
    f1 = const(1.0, s.float32)
    # Use a very large finite negative as -inf substitute (PTO C++ codegen
    # does not currently emit a valid literal for IEEE -inf).
    f_neg_inf = const(-1.0e30, s.float32)
    i32_zero = const(0, s.int32)

    B = s.index_cast(B_i32)
    M = s.index_cast(M_i32)
    N = s.index_cast(N_i32)
    TOPK = s.index_cast(TOPK_i32)

    with pto.vector_section():
        # Flatten (block, sub-block) into a single worker id / worker count.
        cid = pto.get_block_idx()
        sub_bid = pto.get_subblock_idx()
        sub_bnum = pto.get_subblock_num()
        num_blocks = pto.get_block_num()
        wid = s.index_cast(cid * sub_bnum + sub_bid)
        ncores = s.index_cast(num_blocks * sub_bnum)

        # Number of (b, m) queries.
        total = B * M

        # --- GM tensor views ----------------------------------------
        tvQ = pto.as_tensor(
            tv_fp16_2d, ptr=q_ptr, shape=[total * cH, cD], strides=[cD, c1]
        )
        tvKV = pto.as_tensor(
            tv_fp16_2d, ptr=kv_ptr, shape=[B * N, cD], strides=[cD, c1]
        )
        tvO = pto.as_tensor(
            tv_fp16_2d, ptr=o_ptr, shape=[total * cH, cD], strides=[cD, c1]
        )
        # Sink fp32 [H_PAD] viewed as [H_PAD, 1] (1D contiguous == [H, 1]
        # col-stride-1 == [H, 1] row-stride-1).
        tv_sink = pto.as_tensor(
            tv_fp32_2d, ptr=sink_ptr, shape=[cH, c1], strides=[c1, c1]
        )

        # --- Tile buffers -------------------------------------------
        q_tile = pto.alloc_tile(tile_q_fp16)
        q_fp32 = pto.alloc_tile(tile_HD_fp32)
        kv_row_fp16 = pto.alloc_tile(tile_kv_fp16)
        kv_row_fp32 = pto.alloc_tile(tile_kv_fp32)
        kv_row_HD = pto.alloc_tile(tile_HD_fp32)

        tmp_HD = pto.alloc_tile(tile_HD_fp32)
        acc_o = pto.alloc_tile(tile_HD_fp32)
        # All-ones [H, D] used to broadcast [H, 1] col-major reductions
        # via row_expand_mul (which accepts col-major src1, unlike plain
        # row_expand which requires row-major src).
        ones_HD = pto.alloc_tile(tile_HD_fp32)

        # Per-head softmax stats, replicated across D so all elementwise
        # ops stay on row-major [H, D] tiles.
        m_prev_HD = pto.alloc_tile(tile_HD_fp32)
        m_new_HD = pto.alloc_tile(tile_HD_fp32)
        exp_diff_HD = pto.alloc_tile(tile_HD_fp32)
        p_HD = pto.alloc_tile(tile_HD_fp32)
        l_run_HD = pto.alloc_tile(tile_HD_fp32)

        logit_col = pto.alloc_tile(tile_col_stat)
        # row_sum tmp must be same shape/layout as the src tile.
        red_tmp = pto.alloc_tile(tile_HD_fp32)
        sink_col = pto.alloc_tile(tile_col_stat)

        out_tile = pto.alloc_tile(tile_o_fp16)

        # --- Sink load (once per worker) ----------------------------
        sink_view = pto.slice_view(
            sv_sink_col, source=tv_sink, offsets=[c0, c0], sizes=[cH, c1]
        )
        pto.load(sink_view, sink_col)

        # --- Per-query loop -----------------------------------------
        for bm in pto.range(wid, total, ncores):
            # Batch index of this query; selects the KV slab below.
            b = bm // M

            # Load Q[bm*H : bm*H + H, :] in fp16, then cvt to fp32.
            q_off = bm * cH
            q_view = pto.slice_view(
                sv_qrow, source=tvQ, offsets=[q_off, c0], sizes=[cH, cD]
            )
            pto.load(q_view, q_tile)
            tile.cvt(q_tile, q_fp32)
            # Build ones_HD as `q_fp32 * 0 + 1`.
            # NOTE(review): this is 1 everywhere only if q is finite; a
            # NaN/Inf in q would propagate into ones_HD and into every
            # broadcast derived from it. The test inputs (randn cast to
            # fp16) are finite — confirm for production inputs.
            tile.muls(q_fp32, f0, ones_HD)
            tile.adds(ones_HD, f1, ones_HD)

            # Initialise running stats (canonical FlashAttention init):
            #   m_prev = -inf, l_run = 0, acc_o = 0.
            # This handles the `idx == -1` sentinel cleanly: skipped
            # iterations leave m_prev=-inf so the first valid position
            # naturally produces exp(-inf - finite)=0 and exp(0)=1, i.e.
            # acc_o = kv_row_HD, l_run = 1, m_prev = logit (correct).
            # If ALL idx are -1, l_run += exp(sink-(-inf))=+inf, and
            # acc_o / +inf = 0 (matches reference).
            # ("-inf" here is the finite stand-in f_neg_inf = -1e30.)
            tile.muls(ones_HD, f0, l_run_HD)
            tile.muls(ones_HD, f0, acc_o)
            tile.muls(ones_HD, f0, m_prev_HD)
            tile.adds(m_prev_HD, f_neg_inf, m_prev_HD)

            # Start of this query's index list / this batch's KV rows.
            idx_base = bm * TOPK
            kv_base = b * N

            for k in pto.range(c0, TOPK, c1):
                # ---- Gather one KV row by index (skip if negative) --
                idx_off = idx_base + k
                idx_i32 = pto.load_scalar(s.int32, idx_ptr, idx_off)
                is_valid = s.ge(idx_i32, i32_zero)
                with pto.if_context(is_valid):
                    idx_idx = s.index_cast(idx_i32)
                    kv_row_off = kv_base + idx_idx
                    kv_view = pto.slice_view(
                        sv_kvrow,
                        source=tvKV,
                        offsets=[kv_row_off, c0],
                        sizes=[c1, cD],
                    )
                    pto.load(kv_view, kv_row_fp16)
                    tile.cvt(kv_row_fp16, kv_row_fp32)

                    # Broadcast kv_row [1, D] → kv_row_HD [H, D].
                    tile.col_expand_mul(ones_HD, kv_row_fp32, kv_row_HD)

                    # QK: logit_col[h] = (q · kv_row)[h] * scale.
                    tile.col_expand_mul(q_fp32, kv_row_fp32, tmp_HD)
                    tile.row_sum(tmp_HD, red_tmp, logit_col)
                    # Broadcast logit_col [H, 1] col-major → [H, D] row-major.
                    tile.row_expand_mul(ones_HD, logit_col, tmp_HD)
                    tile.muls(tmp_HD, scale_f32, tmp_HD)
                    # tmp_HD now holds logit replicated across D.

                    # Online softmax update.
                    #   m_new = max(m_prev, logit)        (per-head, replicated)
                    tile.max(m_prev_HD, tmp_HD, m_new_HD)
                    #   exp_diff = exp(m_prev - m_new)
                    tile.sub(m_prev_HD, m_new_HD, exp_diff_HD)
                    tile.exp(exp_diff_HD, exp_diff_HD)
                    #   p = exp(logit - m_new)
                    tile.sub(tmp_HD, m_new_HD, p_HD)
                    tile.exp(p_HD, p_HD)
                    #   l_run = exp_diff * l_run + p
                    tile.mul(l_run_HD, exp_diff_HD, l_run_HD)
                    tile.add(l_run_HD, p_HD, l_run_HD)
                    #   acc_o = exp_diff * acc_o + p * kv_row_HD
                    #   (tmp_HD is reused as scratch; the logit is no
                    #   longer needed at this point)
                    tile.mul(acc_o, exp_diff_HD, acc_o)
                    tile.mul(p_HD, kv_row_HD, tmp_HD)
                    tile.add(acc_o, tmp_HD, acc_o)
                    #   m_prev = m_new  (copy expressed as multiply by 1)
                    tile.muls(m_new_HD, f1, m_prev_HD)

            # --- Finalise: l_run += exp(sink - m_prev); acc_o /= l_run.
            # Broadcast sink_col [H, 1] → tmp_HD [H, D] via row_expand_mul.
            tile.row_expand_mul(ones_HD, sink_col, tmp_HD)
            tile.sub(tmp_HD, m_prev_HD, exp_diff_HD)
            tile.exp(exp_diff_HD, exp_diff_HD)
            tile.add(l_run_HD, exp_diff_HD, l_run_HD)
            tile.div(acc_o, l_run_HD, acc_o)

            # Cast back to fp16 and write the [H, D] block for this query.
            tile.cvt(acc_o, out_tile)
            o_view = pto.slice_view(
                sv_orow, source=tvO, offsets=[q_off, c0], sizes=[cH, cD]
            )
            pto.store(out_tile, o_view)
"""Reference + ctypes wrapper for the deepseek_v4 ``sparse_attn`` PTO kernel.

The NPU kernel implements full FlashAttention with indexed (top-k) KV
gather; this module provides the GPU-semantics reference and the ctypes
shim that the test uses to invoke the compiled kernel.

Per-(b, m) attention with top-k sparse KV:
    q[b, m, h, d]       fp16   H_PAD heads (padded), D = 128
    kv[b, n, d]         fp16   one KV head, N positions per batch
    o[b, m, h, d]       fp16
    attn_sink[H_PAD]    fp32   (per-head additive sink-logit)
    topk_idxs[b, m, K]  int32  (indices into the KV n-axis; negative = skip)
"""

import ctypes
from pathlib import Path

import torch


_HERE = Path(__file__).resolve().parent
_KERNEL_SO = _HERE / "sparse_attn_lib.so"

H_PAD = 16
D = 128
BLOCK = 64


def sparse_attn_ref(
    q: torch.Tensor,  # [B, M, H, D] fp16 (H may be < H_PAD)
    kv: torch.Tensor,  # [B, N, D] fp16
    attn_sink: torch.Tensor,  # [H] fp32
    topk_idxs: torch.Tensor,  # [B, M, K] int32 (negative marks invalid slot)
    scale: float,
) -> torch.Tensor:
    """Eager PyTorch reference for top-k sparse attention with a sink logit.

    For every query ``(b, m)`` the selected KV rows act as both keys and
    values.  The per-head sink logit is appended to the logits before the
    softmax (so it contributes to the denominator) and is dropped from the
    value mix.

    Any negative index is treated as an invalid slot, matching the device
    kernel, which skips every index ``< 0`` (not just ``-1``).

    Returns:
        fp16 tensor with the same shape as ``q``.
    """
    B, M, H, Dq = q.shape
    Bk, N, Dk = kv.shape
    assert (B, Dq) == (Bk, Dk) and H <= H_PAD and Dq == D
    K = topk_idxs.shape[-1]

    qf = q.to(torch.float32)
    kf = kv.to(torch.float32)
    # Loop-invariant: the sink column is identical for every query.
    sink = attn_sink.to(torch.float32).view(H, 1)  # [H, 1]

    out = torch.zeros_like(qf)
    for b in range(B):
        for m in range(M):
            raw_idx = topk_idxs[b, m]  # [K], may hold negatives
            invalid = raw_idx < 0  # [K]
            # Redirect invalid slots to row 0 so the gather stays in bounds
            # (and never wraps around through negative indexing).
            safe_idx = torch.clamp(raw_idx.to(torch.long), min=0)
            kv_sel = kf[b, safe_idx]  # [K, D]
            logits = (qf[b, m] @ kv_sel.T) * scale  # [H, K]
            logits[:, invalid] = float("-inf")  # mask invalid slots
            logits_full = torch.cat([logits, sink], dim=-1)  # [H, K+1]
            p = torch.softmax(logits_full, dim=-1)
            p_kv = p[:, :K]  # [H, K]
            # Zero contribution from invalid slots (softmax already gave 0
            # because logits were -inf, but be defensive against NaN).
            p_kv = torch.where(invalid.view(1, K), torch.zeros_like(p_kv), p_kv)
            out[b, m] = p_kv @ kv_sel  # [H, D]
    return out.to(torch.float16)


# call_kernel(blockDim, stream, q, kv, o, attn_sink, topk_idxs,
#             B, M, N, TOPK, scale)
_ARGTYPES = [
    ctypes.c_uint32,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_void_p,
    ctypes.c_int32,
    ctypes.c_int32,
    ctypes.c_int32,
    ctypes.c_int32,
    ctypes.c_float,
]


_lib = None


def _load():
    """Return the cached ``ctypes`` handle for the compiled kernel library.

    Raises:
        FileNotFoundError: if the shared library has not been built yet.
    """
    global _lib
    if _lib is None:
        if not _KERNEL_SO.is_file():
            raise FileNotFoundError(
                f"Kernel shared library not found: {_KERNEL_SO}\n"
                f"Build first: cd {_HERE} && ./compile.sh"
            )
        _lib = ctypes.CDLL(str(_KERNEL_SO))
        _lib.call_kernel.argtypes = _ARGTYPES
        _lib.call_kernel.restype = None
    return _lib


def sparse_attn(q, kv, attn_sink, topk_idxs, scale: float):
    """Run the compiled PTO kernel and return ``o`` shaped like ``q``.

    Args:
        q:         [B, M, H, D] fp16 NPU tensor, ``H <= H_PAD``.
        kv:        [B, N, D] fp16 NPU tensor.
        attn_sink: per-head sink logits with ``H`` elements.
        topk_idxs: [B, M, K] integer indices into the N axis; negative
                   entries are skipped by the kernel.
        scale:     logit multiplier.

    The kernel is handed raw device pointers and assumes dense row-major
    buffers with fixed element types, so every operand is normalised here
    (contiguous; fp32 sink; int32 indices) before the launch.  Heads are
    zero-padded to ``H_PAD`` and the padding is stripped from the result.
    """
    assert q.is_npu and kv.is_npu and q.dtype == kv.dtype == torch.float16
    B, M, H, Dq = q.shape
    Bk, N, Dk = kv.shape
    K = topk_idxs.shape[-1]
    assert H <= H_PAD and Dq == D
    # Same shape contract as the reference; a mismatch would make the
    # kernel index past the end of one of the buffers.
    assert (Bk, Dk) == (B, Dq)
    assert tuple(topk_idxs.shape) == (B, M, K)
    attn_sink = attn_sink.reshape(-1)
    assert attn_sink.numel() == H

    orig_H = H
    # Pad heads to H_PAD: kernel statically expects H == H_PAD == 16.
    if H < H_PAD:
        q = torch.cat([q, q.new_zeros(B, M, H_PAD - H, Dq)], dim=2)
        attn_sink = torch.cat([attn_sink, attn_sink.new_zeros(H_PAD - H)])

    # Normalise layout/dtype unconditionally: previously only the padded
    # branch made q contiguous, and kv / sink / indices were passed as-is.
    q = q.contiguous()
    kv = kv.contiguous()
    attn_sink = attn_sink.to(device=q.device, dtype=torch.float32).contiguous()
    topk_idxs = topk_idxs.to(device=q.device, dtype=torch.int32).contiguous()

    o = torch.empty_like(q)
    lib = _load()
    dev = torch.npu.current_device()
    blk = torch.npu.get_device_properties(dev).cube_core_num
    lib.call_kernel(
        blk,
        torch.npu.current_stream()._as_parameter_,
        ctypes.c_void_p(q.data_ptr()),
        ctypes.c_void_p(kv.data_ptr()),
        ctypes.c_void_p(o.data_ptr()),
        ctypes.c_void_p(attn_sink.data_ptr()),
        ctypes.c_void_p(topk_idxs.data_ptr()),
        ctypes.c_int32(B),
        ctypes.c_int32(M),
        ctypes.c_int32(N),
        ctypes.c_int32(K),
        ctypes.c_float(scale),
    )
    # Also keeps the normalised temporaries alive until the kernel is done.
    torch.npu.synchronize()
    if orig_H < H_PAD:
        o = o.narrow(2, 0, orig_H).contiguous()
    return o