[WIP] feat: Add PTO-DSL Flash Attention performance kernel with Python validation runner#117
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a pto-dsl version of Flash Attention, including the kernel builder, host shim, and benchmarking infrastructure. The review identified several high-priority issues: missing error handling for rtGetC2cCtrlAddr which could lead to null pointer dereferences, and buffer aliasing that undermines the software pipeline's performance and safety. Additionally, there is a discrepancy between the implemented tile size and the documentation, and several instances where tile.muls should be replaced with tile.mov for better efficiency.
| (void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen); | ||
| (void)fftsLen; |
There was a problem hiding this comment.
| k_mat = [k_mat_a, k_mat_a] | ||
| k_right = [k_right_a, k_right_a] | ||
| qk_acc = [qk_acc_a, qk_acc_a] | ||
| p_recv = [p_recv_a, p_recv_a] | ||
| p_left = [p_left_a, p_left_a] | ||
| v_mat = [v_mat_a, v_mat_a] | ||
| v_right = [v_right_a, v_right_a] | ||
| pv_acc = [pv_acc_a, pv_acc_a] |
There was a problem hiding this comment.
Aliasing multiple logical buffers (e.g., k_mat[0] and k_mat[1]) to the same physical tile (k_mat_a) effectively disables software pipelining and double-buffering. In a kernel designed for performance with QK_PRELOAD=2, this creates a race condition where subsequent loads overwrite data still in flight, or forces serialization that defeats the purpose of the pipeline. Given that S1_TILE=32 uses very little L1 memory (~8KB per tile), you should allocate separate tiles to enable true ping-pong buffering.
| S0 = 128 | ||
| S0_HALF = S0 // 2 | ||
| HEAD = 128 | ||
| S1_TILE = 32 |
There was a problem hiding this comment.
There is a significant inconsistency between the code (S1_TILE = 32) and the PR description/comments (e.g., lines 13, 28, 32, 58) which claim TILE_S1=256. While a comment at line 97 suggests 32 was chosen to avoid VEC UB overflow, this deviation means the kernel is not achieving the stated goal of parity with the manual performance kernel. If 256 is the target, this constant should be updated; otherwise, the documentation and PR summary should be corrected to reflect the actual implementation.
|
|
||
| if is_init: | ||
| tile.row_expand_sub(qk, local_max, p_fp32) | ||
| tile.muls(local_max_r, f32_one, running_max_r) |
There was a problem hiding this comment.
Using tile.muls with a multiplier of 1.0 to perform a tile copy is less efficient and less readable than a direct move. Since tile.mov is used elsewhere in this file (e.g., line 503), it should be used here as well.
| tile.muls(local_max_r, f32_one, running_max_r) | |
| tile.mov(local_max_r, running_max_r) |
| # exp_max = m_prev - m_new (rescale factor for running_sum / O) | ||
| tile.sub(running_max_r, local_max_r, exp_max_r) | ||
| # running_max <- m_new | ||
| tile.muls(local_max_r, f32_one, running_max_r) |
b4545d9 to
f42198c
Compare
f42198c to
c461b41
Compare
|
Triage review (2026-05-08): this PR is mergeable from a repository-state perspective: GitHub reports it as clean against Before merging, please resolve the readiness signal: the title still says No blocking conflict found in this triage pass; this still deserves normal owner review of the PTO-DSL kernel semantics and the NPU benchmark claims before merge. |
|
Added |
35b35de to
871cceb
Compare
Upstream sync — issues / PRs filed for the long-term designThe local fixes that make
|
| Workaround | Removable when |
|---|---|
compile.sh regex s/TPipe<..., 8, 8, false>/TPipe<..., 8, 2, false>/ |
hw-native-sys/PTOAS#650 lands and fa_builder.py writes local_slot_num=2 on the IR. |
os.environ.setdefault("PTODSL_SKIP_VERIFY", "1") in fa_builder.py |
The MLIR Python binding/verifier shipped with pto-dsl catches up to the address-based init shape (the additive API in PTO-ISA/pto-dsl#8 is the prerequisite, but the binding refresh is a follow-up). |
canonicalize_ptoas_ir(text) text rewrite |
pto-dsl's printer emits the custom assembly form for the new ops natively (also a binding follow-up after PTO-ISA/pto-dsl#8). |
The kernel logic itself — entry-carrying tfree_from_aic/aiv, address-based pipes, SlotNum=8 / LocalSlotNum=2, TILE_S1=256 / QK_PRELOAD=4 — is the long-term design and stays.
6980ecf to
882d1b2
Compare
…ith TILE_S1=256/QK_PRELOAD=4 Replace the earlier flash_atten-v1/ scaffold and the parity flash_atten-v2/ experiment with a single PTO-DSL flash-attention kernel that reaches manual parity end-to-end, drops the bench launcher cleanly through ptoas + bisheng, and runs the full 1024/2048/8192 sweep on 910B2 / 24 cubes. Kernel (kernels/fa_builder.py): - TILE_S1=256, CUBE_S1=128, QK_PRELOAD=4 mirroring kernels/manual/common/flash_atten/fa_performance_kernel.cpp. - Per-row_slice softmax/GU loops keep the working tile at 32 KiB, so three fp32 working tiles co-exist with pv/o tiles in 192 KiB UB at TILE_S1=256. - 4-slot QK preload with a paired exp_max ring; steady-state group of 4 matches the manual schedule. - Address-based pipes (PR #606): cube/vec init use gm_slot_tensor and pop/free carry the tensor_view entry, restoring real free-notification semantics for the GM FIFO at long S1 (was the S1>=4096 hang root cause). compile.sh: - Build N=4/8/32 variants by default (S1=1024/2048/8192). - Post-process LocalSlotNum=8 -> 2 in the ptoas-generated C++; required while ptoas's address-based pipe lowering ignores local_slot_num on the IR. This is a temporary patch and will be removed once PTOAS lowers gm_slot_tensor + local_slot_num correctly (filed upstream). DSL temporary shims in fa_builder.py (to be removed once pto-dsl ships the new dialect): - PTODSL_SKIP_VERIFY=1: the installed Python MLIR verifier predates the gm_slot_tensor address-based init shape; ptoas accepts it. - canonicalize_ptoas_ir() rewrites the generic-form ops the older Python binding emits into the custom assembly form ptoas parses. Verified on 910B2 / 24 cubes, head_dim=128, S0=128, causal=False (block_dim=1, single Q block per kernel): DSL kernel-only (ptodsl.do_bench, warmup=10, iter=100): | S1 | DSL us | DSL TFLOPS | torch_npu fused us | max diff | |-----:|-------:|-----------:|-------------------:|---------:| | 1024 | 17.65 | 3.80 | 28.72 | 4.43e-05 | | 2048 | 27.78 | 4.83 | 29.74 | 2.72e-05 | | 8192 | 83.15 | 6.46 | 48.29 | 1.48e-05 | Manual C++ baseline kernel-only (build/report.csv duration_us, same shape, SOC_VERSION=Ascend910B1): | S1 | Manual us | Manual TFLOPS | DSL / Manual | |-----:|----------:|--------------:|-------------:| | 1024 | 37.46 | 1.79 | 0.47x | | 2048 | 59.84 | 2.24 | 0.46x | | 8192 | 91.38 | 5.88 | 0.91x | DSL outperforms the manual C++ baseline at all three lengths under matched shape and same kernel-only timing window; speedup decays with S1 because manual amortizes its larger fixed pipe-barrier window over more work. Both still trail torch_npu.npu_fused_infer_attention_score at S1=8192 (multi-core); single-Q-block, single-cube path here is the direct manual-parity comparison.
882d1b2 to
75eec67
Compare
Summary
kernels/python/flash_atten/directory containing a PTO-DSL Flash Attention kernel (kernels/fa_builder.py) ported fromkernels/manual/common/flash_atten/fa_performance_kernel.cpp. Four-stage cross-core software pipeline (compute_qk / compute_p / compute_pv / compute_gu) withTILE_S1=256,CUBE_S1=128,QK_PRELOAD=4mirroring the manual kernel.gm_slot_tensorand pop/free carry the tensor_view entry, restoring real free-notification semantics for the GM FIFO at long S1 (was the S1>=4096 hang root cause).compile.sh(.pto -> .cpp -> .so via ptoas + bisheng) which post-processesLocalSlotNum=8 -> 2in the ptoas-generated C++ (temporary; ptoas's address-based pipe lowering currently ignoreslocal_slot_num).caller.cppshared-library launcher wrapper,run.pyTorch-NPU runner (correctness check at S1=1024 vs torch reference, bench sweep overFA_BENCH_LENGTHS), andrun_fa.shone-shot driver (PYTHONPATH / PTO_LIB_PATH / PTOAS plus thetask-submitwrap required for NPU access on this server).DSL temporary shims in
fa_builder.py(to be removed once pto-dsl ships the new dialect):PTODSL_SKIP_VERIFY=1: the installed Python MLIR verifier predates thegm_slot_tensoraddress-based init shape; ptoas accepts it.canonicalize_ptoas_ir()rewrites the generic-form ops the older Python binding emits into the custom assembly form ptoas parses.Performance
910B1, 24 cubes, head_dim=128, S0=128, causal=False, single Q block (block_dim=1, single cube path).
do_benchwarmup=10, iter=100.refabove istorch_npu.npu_fused_infer_attention_score(multi-core fused kernel, included for upper-bound reference only).vs manual C++ baseline (matched shape, kernel-only)
Manual C++ from
kernels/manual/common/flash_atten/build/report.csv(HEAD=128, S0=128, TILE_S1=256, QK_PRELOAD=4, SOC_VERSION=Ascend910B1, single-core):DSL outperforms the manual C++ baseline at all three lengths under matched shape and same kernel-only timing window; speedup is 2.12x / 2.15x / 1.10x at S1=1024 / 2048 / 8192. The gap narrows at long S1 because the manual amortizes its larger fixed pipe-barrier window over more work.
Testing
bash run_fa.shpasses at S1=1024 (atol/rtol 1e-3) vsfa_reference(torch fp32 SDPA).task-submit(bash run_fa.sh), all three lengths pass the in-bench correctness probe (max err <= 4.43e-05).bash run.sh -r npu -v Ascend910B1 -a "128,128,1024,128,256;128,128,2048,128,256;128,128,8192,128,256"); all three cases pass the manual-sideo_outcompare.