94 changes: 94 additions & 0 deletions examples/aot/deepseek_v4/OVERVIEW.md
@@ -0,0 +1,94 @@
# DeepSeek-V4 PTO ports — overview

> This file is intentionally **not named `README.md`** so that
> [`validate_all_examples.py`](../../validate_all_examples.py) walks
> into each kernel sub-directory directly instead of trying to run a
> repo-level recipe from here.

PTO DSL ports of the six custom kernels used by the DeepSeek-V4
reference implementation. Every kernel is self-contained in its own
folder and follows the standard examples-tree workflow:

1. `bash ./compile.sh` — emits `.pto` → `.cpp` → `*_lib.so`.
2. `python ./run_*.py` — runs the kernel on NPU and asserts numerical
equivalence with a PyTorch reference (exits non-zero on mismatch).
3. (optional) `python ./bench_*.py` — microbenchmarks vs PyTorch
baselines (only `sparse_attn/` and `hc_split_sinkhorn/`).

## Kernels

| Folder | What it does | Pipe(s) |
|---|---|---|
| [act_quant/](act_quant/) | Per-row-block (128-wide) absmax fp16 → int8 quant (`max(\|x\|)/127`, `round(x/scale)`) | vector |
| [fp4_act_quant/](fp4_act_quant/) | Per-row fp16 → mxfp4 (e2m1) quant with shared exponent + lookup-table cast | vector |
| [fp8_gemm/](fp8_gemm/) | Per-channel fp8 (e4m3) GEMM with host-side fused `Sa`/`Sb` pre-scale | cube + vector |
| [fp4_gemm/](fp4_gemm/) | Per-channel fp4 (e2m1) GEMM with host-side fused `Sa`/`Sb` pre-scale | cube + vector |
| [hc_split_sinkhorn/](hc_split_sinkhorn/) | Fused MoE-router head: pre/post sigmoid + 20-iter Sinkhorn, all on-device | vector |
| [sparse_attn/](sparse_attn/) | FlashAttention with indexed top-k KV gather + per-head sink logit | vector |

## Run a single kernel

```bash
cd examples/aot/deepseek_v4/sparse_attn
bash ./compile.sh
python ./run_sparse_attn.py
```

The generated `.pto`, `.cpp`, and `.so` files are gitignored.

## Run all of them

From the repo root:

```bash
python examples/validate_all_examples.py
```

This walks every `README.md` under `examples/`, runs the bash block in
each, and reports pass/fail. The deepseek_v4 kernels appear in the
listing as e.g. `aot/deepseek_v4/sparse_attn`.
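
As a rough illustration of that walk, the sketch below collects the first `bash` fence of every `README.md` under a root directory. The real `validate_all_examples.py` is not reproduced in this overview, so the helper names `first_bash_block` and `find_recipes`, the regex, and the "first fence wins" rule are all assumptions. It also shows why a file named `OVERVIEW.md` is never picked up.

```python
import re
import tempfile
from pathlib import Path

FENCE = "`" * 3  # built indirectly so the sketch can itself sit inside a fence

# Assumption: a recipe is the first fence tagged `bash` in a README.md.
_BASH_BLOCK = re.compile(
    r"^" + FENCE + r"bash[ \t]*\n(.*?)^" + FENCE + r"[ \t]*$",
    re.DOTALL | re.MULTILINE,
)


def first_bash_block(markdown_text):
    """Return the body of the first bash fence, or None if there is none."""
    match = _BASH_BLOCK.search(markdown_text)
    return match.group(1) if match else None


def find_recipes(examples_root):
    """Map each directory holding a README.md to that README's bash recipe."""
    root = Path(examples_root)
    recipes = {}
    for readme in sorted(root.rglob("README.md")):
        block = first_bash_block(readme.read_text(encoding="utf-8"))
        if block is not None:
            recipes[readme.parent.relative_to(root).as_posix()] = block
    return recipes


# Tiny throwaway tree mirroring the layout described above.
recipe_md = FENCE + "bash\nbash ./compile.sh\npython ./run_act_quant.py\n" + FENCE + "\n"
with tempfile.TemporaryDirectory() as tmp:
    kernel_dir = Path(tmp) / "aot" / "deepseek_v4" / "act_quant"
    kernel_dir.mkdir(parents=True)
    (kernel_dir / "README.md").write_text(recipe_md, encoding="utf-8")
    (kernel_dir.parent / "OVERVIEW.md").write_text(recipe_md, encoding="utf-8")
    found = find_recipes(tmp)

print(sorted(found))  # only the kernel folder; OVERVIEW.md is never opened
```

Keying on the file name keeps the walker free of per-directory special cases: any folder that should not be run as a recipe simply avoids shipping a `README.md`.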

## Sample bench output

`sparse_attn/`, vs `torch.gather` + `npu_fused_infer_attention_score`
(MQA mode, sink logit dropped — speed baseline only):

```
B  M     N    K  pto us    ref us  fused us  pto/ref  pto/fused
---------------------------------------------------------------
1  1   128   64  161.15    533.05    265.03    3.31x      1.64x
1  4   256  128  209.56   1692.93    252.36    8.08x      1.20x
4  4  1024  128  207.77   6071.60    246.57   29.22x      1.19x
8  8  2048  128  304.49  24658.49    244.67   80.98x      0.80x
```
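
Each ratio column is a baseline time divided by the PTO time, so a value above 1 means the PTO kernel is faster and a value below 1 means the baseline wins. A small check, with the rows copied from the table above (the helper name `ratios` is only illustrative):

```python
# (B, M, N, K, pto_us, ref_us, fused_us), copied from the sparse_attn table.
ROWS = [
    (1, 1, 128, 64, 161.15, 533.05, 265.03),
    (1, 4, 256, 128, 209.56, 1692.93, 252.36),
    (4, 4, 1024, 128, 207.77, 6071.60, 246.57),
    (8, 8, 2048, 128, 304.49, 24658.49, 244.67),
]


def ratios(pto_us, ref_us, fused_us):
    """Return (pto/ref, pto/fused) as printed: baseline time over PTO time."""
    return round(ref_us / pto_us, 2), round(fused_us / pto_us, 2)


computed = [ratios(pto, ref, fused) for (_, _, _, _, pto, ref, fused) in ROWS]
for row, (vs_ref, vs_fused) in zip(ROWS, computed):
    print(row[:4], f"{vs_ref:.2f}x", f"{vs_fused:.2f}x")
```

Read this way, the last row (`0.80x` against the fused op) is the one configuration in the table where the PTO kernel is slower than that baseline.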

`hc_split_sinkhorn/`, vs eager PyTorch reference:

```
    n   pto us   ref us  speedup
--------------------------------
   64   173.27  2803.42   16.18x
 1024   218.70  2761.33   12.63x
16384  1786.32  2741.09    1.53x
```

## Implementation notes

- **`fp8_gemm` / `fp4_gemm`** — the GPU op fuses an outer `Sa[m] * Sb[n]`
per-channel rescale into the GEMM. The PTO kernels keep the matmul
pure (cube fp32 accum → fp16 cast) and instead **pre-scale `A` on the
host** by the per-row factor, leaving a clean per-output-channel `Sb`
to apply on the vector pipe. This avoids two extra cube fragments per tile
and matches the reference to within a relative error of 5 × 10⁻³.
- **`hc_split_sinkhorn`** — all three router heads (pre / post / 20-iter
Sinkhorn over `[n, 4, 4]`) run inside one `vector_section`. ε is added
once after the initial softmax to match the reference order exactly.
- **`sparse_attn`** — pure `vector_section` FlashAttention with online
streaming softmax. The matmul shapes (`[16, 128] · [128]` per K
position, K ≤ 128) are too small to amortize cube launch overhead, and
KV is gathered by arbitrary index so it cannot live in L1 contiguously
anyway. Per-head softmax stats are stored as full `[H, D]` tiles
replicated across the D axis to dodge a col-major⇄row-major reshape
alias that auto-sync analysis can otherwise miss. KV gather uses
`pto.load_scalar` of the index → `pto.slice_view` with that dynamic
row offset → `pto.load` of one `[1, D]` row.
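
The `fp8_gemm` / `fp4_gemm` note relies on the identity `Sa[m] * Sb[n] * (A @ B)[m, n] == ((Sa[:, None] * A) @ B)[m, n] * Sb[n]`. The sketch below checks it on tiny plain-Python matrices. The `[M, K] x [K, N]` layout, the function names, and the sample values are assumptions for illustration only; it does not model fp8/fp4 rounding or the actual kernels.

```python
import math


def matmul(a, b):
    """Plain [M, K] x [K, N] matrix product on nested lists."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][n] for k in range(inner)) for n in range(cols)]
            for row in a]


def fused_rescale(a, b, sa, sb):
    """GPU-style formulation: rescale every output by Sa[m] * Sb[n]."""
    c = matmul(a, b)
    return [[sa[m] * sb[n] * c[m][n] for n in range(len(sb))]
            for m in range(len(sa))]


def prescaled(a, b, sa, sb):
    """Split formulation: scale rows of A first, apply Sb after the matmul."""
    a_scaled = [[sa[m] * value for value in row] for m, row in enumerate(a)]
    c = matmul(a_scaled, b)  # the matmul itself stays scale-free
    return [[c[m][n] * sb[n] for n in range(len(sb))] for m in range(len(sa))]


a = [[1.0, -2.0, 0.5], [3.0, 0.25, -1.0]]   # [M=2, K=3]
b = [[2.0, 1.0], [0.5, -1.0], [4.0, 3.0]]   # [K=3, N=2]
sa = [0.5, 0.125]                           # per-row scale of A
sb = [0.25, 2.0]                            # per-output-channel scale

fused = fused_rescale(a, b, sa, sb)
split = prescaled(a, b, sa, sb)
same = all(math.isclose(x, y) for fr, sr in zip(fused, split) for x, y in zip(fr, sr))
print(same)
```

In exact arithmetic the two forms are identical; with low-precision inputs they can differ by rounding, which is presumably what the stated tolerance absorbs.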
10 changes: 10 additions & 0 deletions examples/aot/deepseek_v4/act_quant/.gitignore
@@ -0,0 +1,10 @@
# Generated build artifacts (compile.sh outputs)
act_quant.pto
act_quant.cpp
act_quant_lib.so

# Python cache
__pycache__/

# Benchmark scratch
perf_data/
9 changes: 9 additions & 0 deletions examples/aot/deepseek_v4/act_quant/README.md
@@ -0,0 +1,9 @@
# act_quant — fp16 → int8 per-row absmax quantization

PTO DSL port of the deepseek_v4 `act_quant` op. For each 128-element block
of a row (`BLOCK_SIZE = 128`) it computes `scale = max(|x|) / 127`, then
`y = round(x / scale)`.

```bash
bash ./compile.sh
python ./run_act_quant.py
```
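
The per-block formula can be tried without an NPU. The sketch below is plain Python rather than the kernel itself: `quantize_row` is a hypothetical helper, the block size of 4 is chosen only to keep the example short, and the `1e-12` floor mirrors the clamp in the PyTorch reference (`act_quant_ref`).

```python
INT8_MAX = 127.0


def quantize_row(row, block_size):
    """Symmetric int8 quantization of one row, one scale per block."""
    if len(row) % block_size != 0:
        raise ValueError("row length must be a multiple of block_size")
    quantized, scales = [], []
    for start in range(0, len(row), block_size):
        block = row[start:start + block_size]
        amax = max(abs(value) for value in block)
        scale = max(amax / INT8_MAX, 1e-12)  # floor avoids 0/0 on all-zero blocks
        scales.append(scale)
        for value in block:
            # Python's round() is round-half-to-even, like torch.round.
            quantized.append(max(-127, min(127, round(value / scale))))
    return quantized, scales


row = [1.0, -4.0, 0.3, 3.0, 0.0, 0.0, 0.0, 0.0]
q, scales = quantize_row(row, block_size=4)
print(q)       # -> [32, -127, 10, 95, 0, 0, 0, 0]
print(scales)  # first entry is 4 / 127, second is the 1e-12 floor

# Dequantize with q * scale; the error is at most half a quantization step.
errors = [abs(x - qi * scales[i // 4]) for i, (x, qi) in enumerate(zip(row, q))]
print(max(errors) <= max(scales) / 2)
```

The element with the largest magnitude in a block always maps to ±127, so each block uses the full int8 range regardless of how small its values are.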
147 changes: 147 additions & 0 deletions examples/aot/deepseek_v4/act_quant/act_quant_builder.py
@@ -0,0 +1,147 @@
"""PTO DSL port of TileLang act_quant kernel.

Original (GPU): block-wise FP8 quantization, BF16 -> FP8(e4m3) with FP32
or E8M0 per-block scale. inplace=True does fused quant+dequant back to BF16.

NPU port: BF16/FP8 are not native to PTO; we use FP16 input -> int8 output
with FP32 per-block scale. The shape contract matches the original:

    X: [M, N]    fp16
    Y: [M, N]    int8 (quantized; the fp16 in-place dequant mode of the
                 original is not implemented in this kernel)
    S: [M, N/B]  fp32 per-block scale (amax / 127), stored column-major

`block_size` is the per-row group size on the K-dim (last axis).
"""

from ptodsl import pto, tile, to_ir_module
from ptodsl import scalar as s

const = s.const

BLOCK_SIZE = 128 # K-dim group size; matches GPU `block_size`
BLK_M = 32 # rows per tile (matches GPU `blk_m`)
INT8_MAX = 127.0


def meta_data():
    fp16 = pto.float16
    fp32 = pto.float32
    i8 = pto.int8
    i32 = pto.int32

    ptr_fp16 = pto.PtrType(fp16)
    ptr_i8 = pto.PtrType(i8)
    ptr_fp32 = pto.PtrType(fp32)

    tv_fp16 = pto.TensorType(rank=2, dtype=fp16)
    tv_i8 = pto.TensorType(rank=2, dtype=i8)
    tv_fp32 = pto.TensorType(rank=2, dtype=fp32)

    sv_fp16 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=fp16)
    sv_i8 = pto.SubTensorType(shape=[BLK_M, BLOCK_SIZE], dtype=i8)
    sv_scale = pto.SubTensorType(shape=[BLK_M, 1], dtype=fp32)

    row_cfg = pto.TileBufConfig()
    col_cfg = pto.TileBufConfig(blayout="ColMajor")

    tile_fp16 = pto.TileBufType(
        shape=[BLK_M, BLOCK_SIZE], dtype=fp16, memory_space="VEC"
    )
    tile_fp32 = pto.TileBufType(
        shape=[BLK_M, BLOCK_SIZE], dtype=fp32, memory_space="VEC"
    )
    tile_i8 = pto.TileBufType(shape=[BLK_M, BLOCK_SIZE], dtype=i8, memory_space="VEC")
    tile_amax = pto.TileBufType(
        shape=[BLK_M, 1], dtype=fp32, memory_space="VEC", config=col_cfg
    )

    return locals()


@to_ir_module(meta_data=meta_data)
def act_quant(
    x_ptr: "ptr_fp16",
    y_ptr: "ptr_i8",
    s_ptr: "ptr_fp32",
    M_i32: "i32",
    N_i32: "i32",
) -> None:
    c0 = const(0)
    c1 = const(1)
    cBM = const(BLK_M)
    cBK = const(BLOCK_SIZE)
    inv_max = const(1.0 / INT8_MAX, s.float32)

    M = s.index_cast(M_i32)
    N = s.index_cast(N_i32)
    nblk_n = s.ceil_div(N, cBK)

    with pto.vector_section():
        cid = pto.get_block_idx()
        sub_bid = pto.get_subblock_idx()
        sub_bnum = pto.get_subblock_num()
        num_blocks = pto.get_block_num()
        vid = s.index_cast(cid * sub_bnum + sub_bid)
        ncores = s.index_cast(num_blocks * sub_bnum)

        nblk_m = s.ceil_div(M, cBM)
        total_blocks = nblk_m * nblk_n

        tv_x = pto.as_tensor(tv_fp16, ptr=x_ptr, shape=[M, N], strides=[N, c1])
        tv_y = pto.as_tensor(tv_i8, ptr=y_ptr, shape=[M, N], strides=[N, c1])
        # Scale layout is COL-MAJOR in memory (strides=[1, M]) so that a
        # [BLK_M, 1] col-major amax tile maps to a contiguous BLK_M-element
        # (32-element) write at offset `blk_n * M + row_off`.
        tv_s = pto.as_tensor(tv_fp32, ptr=s_ptr, shape=[M, nblk_n], strides=[c1, M])

        tb_x = pto.alloc_tile(tile_fp16)
        tb_xf = pto.alloc_tile(tile_fp32)
        tb_abs = pto.alloc_tile(tile_fp32)
        tb_tmp = pto.alloc_tile(tile_fp32)
        tb_amax = pto.alloc_tile(tile_amax)
        tb_y = pto.alloc_tile(tile_i8)

        with pto.if_context(vid < total_blocks):
            # Each vector core walks the flattened (blk_m, blk_n) grid with
            # stride `ncores`, starting at its own index `vid`.
            for bi in pto.range(vid, total_blocks, ncores):
                blk_m = bi // nblk_n
                blk_n = bi % nblk_n
                row_off = blk_m * cBM
                col_off = blk_n * cBK

                sv_x = pto.slice_view(
                    sv_fp16,
                    source=tv_x,
                    offsets=[row_off, col_off],
                    sizes=[cBM, cBK],
                )
                sv_y = pto.slice_view(
                    sv_i8,
                    source=tv_y,
                    offsets=[row_off, col_off],
                    sizes=[cBM, cBK],
                )
                sv_s = pto.slice_view(
                    sv_scale,
                    source=tv_s,
                    offsets=[row_off, blk_n],
                    sizes=[cBM, c1],
                )

                pto.load(sv_x, tb_x)
                tile.cvt(tb_x, tb_xf)  # fp16 -> fp32
                tile.abs(tb_xf, tb_abs)  # |x|
                tile.row_max(tb_abs, tb_tmp, tb_amax)  # amax per row
                # scale = amax / 127. This is the value stored in S, so
                # dequantization is y * scale.
                # NOTE: act_quant_ref clamps scale to >= 1e-12; no equivalent
                # guard appears here, so an all-zero block divides by zero.
                tile.muls(tb_amax, inv_max, tb_amax)
                # y = x / scale, then cvt -> fp16 -> i8 (NPU has no direct
                # fp32->i8 cvt; routing through fp16 matches the existing
                # quant_dynamic_multicore example).
                tile.row_expand_div(tb_xf, tb_amax, tb_xf)
                tile.cvt(tb_xf, tb_x)  # fp32 -> fp16 (reuse tb_x)
                tile.cvt(tb_x, tb_y, rmode="round")  # fp16 -> int8
                pto.store(tb_y, sv_y)
                pto.store(tb_amax, sv_s)


if __name__ == "__main__":
    print(act_quant)
95 changes: 95 additions & 0 deletions examples/aot/deepseek_v4/act_quant/act_quant_util.py
@@ -0,0 +1,95 @@
"""Reference + ctypes wrapper for the deepseek_v4 ``act_quant`` PTO kernel.

Reference matches the GPU TileLang behaviour adapted to the NPU port:
FP16 input, int8 output, FP32 per-row-block scale (amax / 127), K-group=128.
"""

import ctypes
from pathlib import Path

import torch


_HERE = Path(__file__).resolve().parent
_KERNEL_SO = _HERE / "act_quant_lib.so"

BLOCK_SIZE = 128
INT8_MAX = 127.0


def act_quant_ref(x: torch.Tensor, block_size: int = BLOCK_SIZE):
    """Reference: per-row-block symmetric int8 quant.

    ``x``: [M, N] fp16, N % block_size == 0.
    Returns ``(y_int8 [M, N], s_fp32 [M, N // block_size])`` on the same device.
    """
    assert x.dtype == torch.float16, "fp16 input expected"
    assert x.dim() == 2 and x.shape[1] % block_size == 0
    M, N = x.shape
    nb = N // block_size

    x_f32 = x.to(torch.float32).reshape(M, nb, block_size)
    amax = x_f32.abs().amax(dim=-1, keepdim=False)  # [M, nb]
    scale = (amax / INT8_MAX).clamp(min=1e-12)  # avoid /0
    y = (x_f32 / scale.unsqueeze(-1)).round().clamp(-127, 127)
    y_i8 = y.to(torch.int8).reshape(M, N)
    return y_i8, scale.to(torch.float32)


# Argument order mirrors `call_kernel` in caller.cpp.
_ARGTYPES = [
    ctypes.c_uint32,  # blockDim
    ctypes.c_void_p,  # stream
    ctypes.c_void_p,  # x
    ctypes.c_void_p,  # y
    ctypes.c_void_p,  # scale
    ctypes.c_int32,  # M
    ctypes.c_int32,  # N
]


def _missing_msg() -> str:
    return (
        f"Kernel shared library not found: {_KERNEL_SO}\n"
        "Build it first:\n"
        f"  cd {_HERE} && bash ./compile.sh"
    )


_lib = None


def _load():
    global _lib
    if _lib is None:
        if not _KERNEL_SO.is_file():
            raise FileNotFoundError(_missing_msg())
        _lib = ctypes.CDLL(str(_KERNEL_SO))
        _lib.call_kernel.argtypes = _ARGTYPES
        _lib.call_kernel.restype = None
    return _lib


def act_quant(x: torch.Tensor):
    """Run the PTO kernel. ``x``: [M, N] fp16 NPU tensor; N % BLOCK_SIZE == 0."""
    assert x.is_npu and x.dtype == torch.float16
    M, N = x.shape
    assert N % BLOCK_SIZE == 0
    y = torch.empty((M, N), dtype=torch.int8, device=x.device)
    # Kernel writes scale in COL-MAJOR layout (strides=[1, M]).
    # Allocate as a transpose of a contiguous [N//B, M] tensor.
    s_storage = torch.empty((N // BLOCK_SIZE, M), dtype=torch.float32, device=x.device)
    s = s_storage.t()  # logical shape [M, N//BLOCK_SIZE], strides [1, M]
    lib = _load()
    dev = torch.npu.current_device()
    blk = torch.npu.get_device_properties(dev).cube_core_num
    lib.call_kernel(
        blk,
        torch.npu.current_stream()._as_parameter_,
        ctypes.c_void_p(x.data_ptr()),
        ctypes.c_void_p(y.data_ptr()),
        ctypes.c_void_p(s.data_ptr()),
        ctypes.c_int32(M),
        ctypes.c_int32(N),
    )
    torch.npu.synchronize()
    return y, s
13 changes: 13 additions & 0 deletions examples/aot/deepseek_v4/act_quant/caller.cpp
@@ -0,0 +1,13 @@
#ifndef KERNEL_CPP
#define KERNEL_CPP "act_quant.cpp"
#endif
#include KERNEL_CPP

extern "C" void call_kernel(
    uint32_t blockDim, void *stream,
    uint8_t *x, uint8_t *y, uint8_t *scale,
    int32_t M, int32_t N)
{
    act_quant<<<blockDim, nullptr, stream>>>(
        (__fp16 *)x, (int8_t *)y, (float *)scale, M, N);
}
22 changes: 22 additions & 0 deletions examples/aot/deepseek_v4/act_quant/compile.sh
@@ -0,0 +1,22 @@
#!/usr/bin/env bash
set -e
rm -f act_quant.pto act_quant.cpp act_quant_lib.so

python ./act_quant_builder.py > ./act_quant.pto
ptoas --enable-insert-sync ./act_quant.pto -o ./act_quant.cpp

PTO_LIB_PATH=${PTO_LIB_PATH:-/sources/pto-isa}
bisheng \
    -I${PTO_LIB_PATH}/include \
    -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 -std=c++17 \
    -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \
    -xcce -Xhost-start -Xhost-end \
    -mllvm -cce-aicore-stack-size=0x8000 \
    -mllvm -cce-aicore-function-stack-size=0x8000 \
    -mllvm -cce-aicore-record-overflow=true \
    -mllvm -cce-aicore-addr-transform \
    -mllvm -cce-aicore-dcci-insert-for-scalar=false \
    --npu-arch=dav-2201 -DMEMORY_BASE \
    -std=gnu++17 \
    ./caller.cpp \
    -o ./act_quant_lib.so