[WIP] feat: Add PTO-DSL Flash Attention performance kernel with Python validation runner #117

Open
chenshengxin2026 wants to merge 1 commit into hw-native-sys:main from chenshengxin2026:feat/flash-attn-v1

@chenshengxin2026 chenshengxin2026 commented May 7, 2026

Summary

  • Add kernels/python/flash_atten/ directory containing a PTO-DSL Flash Attention kernel (kernels/fa_builder.py) ported from kernels/manual/common/flash_atten/fa_performance_kernel.cpp. Four-stage cross-core software pipeline (compute_qk / compute_p / compute_pv / compute_gu) with TILE_S1=256, CUBE_S1=128, QK_PRELOAD=4 mirroring the manual kernel.
  • Address-based pipes (PR #606): cube/vec init use gm_slot_tensor and pop/free carry the tensor_view entry, restoring real free-notification semantics for the GM FIFO at long S1 (the missing notifications were the root cause of the S1>=4096 hang).
  • Per-row_slice softmax/GU loops keep the working tile at 32 KiB so three fp32 working tiles co-exist with pv/o tiles in 192 KiB UB at TILE_S1=256.
  • Add compile.sh (.pto -> .cpp -> .so via ptoas + bisheng) which post-processes LocalSlotNum=8 -> 2 in the ptoas-generated C++ (temporary; ptoas's address-based pipe lowering currently ignores local_slot_num).
  • Add caller.cpp shared-library launcher wrapper, run.py Torch-NPU runner (correctness check at S1=1024 vs torch reference, bench sweep over FA_BENCH_LENGTHS), and run_fa.sh one-shot driver (PYTHONPATH / PTO_LIB_PATH / PTOAS plus the task-submit wrap required for NPU access on this server).
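
The UB budget in the third bullet can be checked with a few lines of arithmetic. This is a minimal sketch using only the sizes quoted above; the interpretation of a working tile as a (rows x TILE_S1) block of fp32 values is an assumption, not something taken from fa_builder.py.

```python
# Sizes quoted in the PR summary.
UB_BYTES = 192 * 1024            # unified buffer capacity
TILE_S1 = 256                    # columns per S1 tile
FP32_BYTES = 4
WORKING_TILE_BYTES = 32 * 1024   # per-row_slice working tile
NUM_WORKING_TILES = 3

# Assumption: one working tile holds rows_per_slice x TILE_S1 fp32 values.
rows_per_slice = WORKING_TILE_BYTES // (TILE_S1 * FP32_BYTES)

working_bytes = NUM_WORKING_TILES * WORKING_TILE_BYTES
remaining_bytes = UB_BYTES - working_bytes

print(f"rows per slice:      {rows_per_slice}")
print(f"working tiles:       {working_bytes // 1024} KiB")
print(f"left for pv/o tiles: {remaining_bytes // 1024} KiB")
```

Under that assumption, three working tiles take half of the UB and the other half stays available for the pv/o tiles.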

DSL temporary shims in fa_builder.py (to be removed once pto-dsl ships the new dialect):

  • PTODSL_SKIP_VERIFY=1: the installed Python MLIR verifier predates the gm_slot_tensor address-based init shape; ptoas accepts it.
  • canonicalize_ptoas_ir() rewrites the generic-form ops the older Python binding emits into the custom assembly form ptoas parses.
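
For the first shim, a setdefault-style assignment only supplies a value when the variable is unset, so a caller can still override it from the shell. The sketch below uses a plain dict in place of the real environment; the helper name enable_skip_verify is hypothetical, and how ptodsl interprets values other than "1" is not stated in this PR.

```python
import os

def enable_skip_verify(env=os.environ):
    """Set PTODSL_SKIP_VERIFY=1 unless the caller already chose a value."""
    return env.setdefault("PTODSL_SKIP_VERIFY", "1")

# Plain dicts stand in for os.environ so the demo leaves the real environment alone.
unset_env = {}
preset_env = {"PTODSL_SKIP_VERIFY": "0"}

print(enable_skip_verify(unset_env))   # 1
print(enable_skip_verify(preset_env))  # 0
```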

Performance

910B1, 24 cubes, head_dim=128, S0=128, causal=False, single Q block (block_dim=1, single cube path). do_bench warmup=10, iter=100.

==================================== Benchmark (pto-dsl fa) ====================================
  manual target: S0=128 HEAD=128 CUBE_S0=128 CUBE_S1=128 TILE_S1=256 QK_PRELOAD=4 causal=False
  dsl effective: Q_ROWS=128 HEAD=128 CUBE_S0=128 S1_TILE=256 QK_PRELOAD=4 NUM_Q_BLOCKS=1
  same host-visible shape and QK_PRELOAD as the manual non-causal S0=128 path
  cores=24
  lengths: [1024, 2048, 8192]
------------------------------------------------------------------------------------------------
  s1=  1024  tiles=  4  fa=   17.65us ( 3802.5 GF/s)  ref=   28.72us ( 2337.0 GF/s)  speedup=1.63x  err: ours=4.43e-05 ref=7.57e-05
  s1=  2048  tiles=  8  fa=   27.78us ( 4830.8 GF/s)  ref=   29.74us ( 4513.4 GF/s)  speedup=1.07x  err: ours=2.72e-05 ref=7.88e-05
  s1=  8192  tiles= 32  fa=   83.15us ( 6456.4 GF/s)  ref=   48.29us (11118.7 GF/s)  speedup=0.58x  err: ours=1.48e-05 ref=2.57e-05
================================================================================================

ref above is torch_npu.npu_fused_infer_attention_score (multi-core fused kernel, included for upper-bound reference only).
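
The GF/s and speedup columns can be reproduced from the timings. The FLOP count 4 * S0 * S1 * HEAD (two matmuls at 2 * S0 * S1 * HEAD each) is inferred from the printed numbers rather than read from run.py, so treat it as an assumption; small last-digit differences come from the timings being rounded to two decimals.

```python
S0, HEAD = 128, 128

def attention_flops(s1, s0=S0, head=HEAD):
    # Assumed cost model: QK^T and PV, each 2 * S0 * S1 * HEAD flops.
    return 4 * s0 * s1 * head

def gflops_per_s(s1, time_us):
    return attention_flops(s1) / (time_us * 1e-6) / 1e9

# (S1, fa time in us, ref time in us) copied from the benchmark output.
rows = [(1024, 17.65, 28.72), (2048, 27.78, 29.74), (8192, 83.15, 48.29)]

for s1, fa_us, ref_us in rows:
    print(f"s1={s1:5d}  fa={gflops_per_s(s1, fa_us):8.1f} GF/s  "
          f"ref={gflops_per_s(s1, ref_us):8.1f} GF/s  "
          f"speedup={ref_us / fa_us:.2f}x")
```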

vs manual C++ baseline (matched shape, kernel-only)

Manual C++ from kernels/manual/common/flash_atten/build/report.csv (HEAD=128, S0=128, TILE_S1=256, QK_PRELOAD=4, SOC_VERSION=Ascend910B1, single-core):

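| S1   | Manual C++ (us) | Manual TFLOPS | DSL (us) | DSL TFLOPS | DSL / Manual |
|-----:|----------------:|--------------:|---------:|-----------:|-------------:|
| 1024 |           37.46 |          1.79 |    17.65 |       3.80 |        0.47x |
| 2048 |           59.84 |          2.24 |    27.78 |       4.83 |        0.46x |
| 8192 |           91.38 |          5.88 |    83.15 |       6.46 |        0.91x |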

DSL outperforms the manual C++ baseline at all three lengths under matched shape and kernel-only timing on both sides; the speedup (manual time / DSL time) is 2.12x / 2.15x / 1.10x at S1=1024 / 2048 / 8192. The gap narrows at long S1 because the manual kernel amortizes its larger fixed pipe-barrier window over more work.
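
The "DSL / Manual" column is a time ratio and the speedup is its reciprocal. A short check against the timings in the table, with nothing assumed beyond those values:

```python
manual_us = {1024: 37.46, 2048: 59.84, 8192: 91.38}
dsl_us = {1024: 17.65, 2048: 27.78, 8192: 83.15}

def time_ratio(s1):
    """DSL time divided by manual time (lower means DSL is faster)."""
    return dsl_us[s1] / manual_us[s1]

def speedup(s1):
    """Manual time divided by DSL time (higher means DSL is faster)."""
    return manual_us[s1] / dsl_us[s1]

for s1 in manual_us:
    print(f"S1={s1:5d}  DSL/Manual={time_ratio(s1):.2f}x  speedup={speedup(s1):.2f}x")
```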

Measurement note: DSL timing is ptodsl.do_bench (host-side stream-event window across warmup=10/iter=100). Manual timing is duration_us from build/report.csv, captured by get_sys_cnt() at kernel entry/exit including a pipe_barrier(PIPE_ALL) (see kernels/manual/common/flash_atten/PERFORMANCE_MEASUREMENT.md). Both are kernel-only but the manual window includes the trailing full-pipe barrier; numbers above are NOT averaged across multiple iterations on the manual side.

Testing

  • Correctness: bash run_fa.sh passes at S1=1024 (atol/rtol 1e-3) vs fa_reference (torch fp32 SDPA).
  • Bench sweep over S1 in {1024, 2048, 8192} runs end-to-end under task-submit (bash run_fa.sh); all three lengths pass the in-bench correctness probe (max err <= 4.43e-05).
  • Manual C++ baseline rebuilt and re-run with the same H=128 / S0=128 / TILE_S1=256 / QK_PRELOAD=4 case set (bash run.sh -r npu -v Ascend910B1 -a "128,128,1024,128,256;128,128,2048,128,256;128,128,8192,128,256"); all three cases pass the manual-side o_out compare.
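
The atol/rtol criterion in the first bullet is the usual element-wise closeness test. The sketch below spells it out in plain Python with the same rule torch.allclose uses; the sample values are made up for illustration and are not kernel outputs.

```python
def max_abs_err(actual, expected):
    """Largest element-wise absolute difference."""
    return max(abs(a - e) for a, e in zip(actual, expected))

def allclose(actual, expected, atol=1e-3, rtol=1e-3):
    # Same element-wise rule as torch.allclose: |a - e| <= atol + rtol * |e|.
    return all(abs(a - e) <= atol + rtol * abs(e)
               for a, e in zip(actual, expected))

# Illustrative values only.
expected = [0.25, -1.5, 3.0]
actual = [0.25004, -1.5002, 3.001]

print(allclose(actual, expected))
print(max_abs_err(actual, expected))
```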

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a pto-dsl version of Flash Attention, including the kernel builder, host shim, and benchmarking infrastructure. The review identified several high-priority issues: missing error handling for rtGetC2cCtrlAddr which could lead to null pointer dereferences, and buffer aliasing that undermines the software pipeline's performance and safety. Additionally, there is a discrepancy between the implemented tile size and the documentation, and several instances where tile.muls should be replaced with tile.mov for better efficiency.

Comment on lines +29 to +30
(void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen);
(void)fftsLen;

high

The return value of rtGetC2cCtrlAddr is ignored. If this function fails, fftsAddr will remain nullptr, which will cause the kernel to crash or exhibit undefined behavior when pto.set_ffts is called inside the kernel. You should check the return code and handle potential errors appropriately.

Comment on lines +286 to +293
k_mat = [k_mat_a, k_mat_a]
k_right = [k_right_a, k_right_a]
qk_acc = [qk_acc_a, qk_acc_a]
p_recv = [p_recv_a, p_recv_a]
p_left = [p_left_a, p_left_a]
v_mat = [v_mat_a, v_mat_a]
v_right = [v_right_a, v_right_a]
pv_acc = [pv_acc_a, pv_acc_a]

high

Aliasing multiple logical buffers (e.g., k_mat[0] and k_mat[1]) to the same physical tile (k_mat_a) effectively disables software pipelining and double-buffering. In a kernel designed for performance with QK_PRELOAD=2, this creates a race condition where subsequent loads overwrite data still in flight, or forces serialization that defeats the purpose of the pipeline. Given that S1_TILE=32 uses very little L1 memory (~8KB per tile), you should allocate separate tiles to enable true ping-pong buffering.
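
The aliasing concern can be illustrated in plain Python, independent of the DSL. In this sketch bytearrays stand in for tiles and the helper names are hypothetical; it only shows that a two-entry list built from one object gives both slots the same storage.

```python
def make_tile(nbytes=4):
    """Stand-in for a physical tile allocation."""
    return bytearray(nbytes)

def load(slots, index, value):
    """Fill the tile at slots[index] with a marker value."""
    slots[index][:] = bytes([value]) * len(slots[index])

tile_a = make_tile()
aliased = [tile_a, tile_a]               # both slots share one buffer
pingpong = [make_tile(), make_tile()]    # each slot has its own buffer

for slots in (aliased, pingpong):
    load(slots, 0, 1)   # data for iteration i
    load(slots, 1, 2)   # prefetch for iteration i + 1

print(aliased[0][0])    # 2: slot 0 was overwritten by the second load
print(pingpong[0][0])   # 1: slot 0 still holds iteration i
```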

S0 = 128
S0_HALF = S0 // 2
HEAD = 128
S1_TILE = 32

medium

There is a significant inconsistency between the code (S1_TILE = 32) and the PR description/comments (e.g., lines 13, 28, 32, 58) which claim TILE_S1=256. While a comment at line 97 suggests 32 was chosen to avoid VEC UB overflow, this deviation means the kernel is not achieving the stated goal of parity with the manual performance kernel. If 256 is the target, this constant should be updated; otherwise, the documentation and PR summary should be corrected to reflect the actual implementation.
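
To make the size of that gap concrete, the number of S1 tiles per kernel launch follows directly from the tile width. The sketch assumes S1 must be a multiple of the tile size, which is not stated in the PR; the helper name is hypothetical.

```python
def num_tiles(s1, tile_s1):
    # Assumption: S1 is required to be a whole number of tiles.
    if s1 % tile_s1:
        raise ValueError(f"S1={s1} is not a multiple of tile size {tile_s1}")
    return s1 // tile_s1

for s1 in (1024, 2048, 8192):
    print(f"S1={s1:5d}  tiles@32={num_tiles(s1, 32):4d}  tiles@256={num_tiles(s1, 256):3d}")
```

With a tile width of 256 the counts are 4, 8 and 32, which matches the tiles column in the benchmark output of the PR description.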


if is_init:
    tile.row_expand_sub(qk, local_max, p_fp32)
    tile.muls(local_max_r, f32_one, running_max_r)

medium

Using tile.muls with a multiplier of 1.0 to perform a tile copy is less efficient and less readable than a direct move. Since tile.mov is used elsewhere in this file (e.g., line 503), it should be used here as well.

Suggested change
- tile.muls(local_max_r, f32_one, running_max_r)
+ tile.mov(local_max_r, running_max_r)

# exp_max = m_prev - m_new (rescale factor for running_sum / O)
tile.sub(running_max_r, local_max_r, exp_max_r)
# running_max <- m_new
tile.muls(local_max_r, f32_one, running_max_r)

medium

Using tile.muls with a multiplier of 1.0 to perform a tile copy is less efficient and less readable than a direct move. Use tile.mov instead.

Suggested change
- tile.muls(local_max_r, f32_one, running_max_r)
+ tile.mov(local_max_r, running_max_r)
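
For context, the running-max update being reviewed is part of the standard online-softmax recurrence used by Flash Attention. The following is a scalar, single-row sketch in plain Python, not a transcription of fa_builder.py; all names are illustrative. It keeps a running max and running sum, rescales both the sum and the output accumulator by exp(m_prev - m_new), and matches a one-pass reference.

```python
import math

def softmax_weighted_sum(scores, values):
    """Reference: softmax(scores) dotted with values over the full row."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return sum(e * v for e, v in zip(exps, values)) / sum(exps)

def online_softmax_weighted_sum(scores, values, tile):
    """Same result, consuming `tile` scores at a time."""
    running_max = -math.inf
    running_sum = 0.0
    acc = 0.0
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        new_max = max(running_max, max(s_tile))
        # Rescale factor exp(m_prev - m_new); it is 0.0 on the first tile.
        rescale = math.exp(running_max - new_max)
        p = [math.exp(s - new_max) for s in s_tile]
        running_sum = running_sum * rescale + sum(p)
        acc = acc * rescale + sum(pi * vi for pi, vi in zip(p, v_tile))
        running_max = new_max   # running_max <- m_new (a plain copy)
    return acc / running_sum

scores = [0.5, 2.0, -1.0, 3.0, 0.0, 1.5]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(softmax_weighted_sum(scores, values))
print(online_softmax_weighted_sum(scores, values, tile=2))
```

The final assignment to running_max is a pure copy, which is why a move reads more naturally there than a multiply by 1.0.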

@chenshengxin2026 changed the title from "Add pto-dsl Python port of Flash Attention v1 perf kernel" to "feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" on May 7, 2026
@chenshengxin2026 changed the title from "feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" to "[WIP] feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" on May 7, 2026
@zhoubot
Collaborator

zhoubot commented May 8, 2026

Triage review (2026-05-08): this PR is mergeable from a repository-state perspective: GitHub reports it as clean against main, and all four CI checks pass (Pre-commit, Docs build, CPU SIM smoke, CPU SIM full ST). I reviewed the changed file set and the runner/build flow at a high level.

Before merging, please resolve the readiness signal: the title still says [WIP], while the PR is not draft and CI is green. Either retitle it as ready or mark it draft if the known long-sequence performance gap is still a blocker. I would also prefer a small README/index entry for kernels/python/flash_atten-v1/, because the body contains important constraints (S1_TILE=32, QK_PRELOAD=2, causal unsupported, long-S1 gap) that will be hard to discover after merge.

No blocking conflict found in this triage pass; this still deserves normal owner review of the PTO-DSL kernel semantics and the NPU benchmark claims before merge.

@chenshengxin2026
Author

Added kernels/python/flash_atten-v2/ (commit 35b35de): a manual-aligned variant at TILE_S1=256, CUBE_S1=128, kTileFactor=2, with all three pipes on the address-based slot model from PTOAS PR #606. Single-call correctness passed at S1=1024 (max_err 4.43e-05) and S1=2048 (max_err 2.72e-05). S1>=4096 currently hits an aicore timeout because ptoas emits TPipe<...,SlotNum=8,LocalSlotNum=8,...> for globaltensor pipe init instead of the manual's ...,8,2,.... Details and a suggested fix are filed as #118.

@chenshengxin2026
Author

Upstream sync — issues / PRs filed for the long-term design

The local fixes that make kernels/python/flash_atten/ build and run end-to-end split into three layers. The kernel itself is in this PR; the supporting frontend / lowering changes are tracked upstream so the workarounds in compile.sh and fa_builder.py (regex post-process, PTODSL_SKIP_VERIFY=1, canonicalize_ptoas_ir) can be retired:

pto-dsl (Python frontend, address-based pipe API)

PTOAS (lowering, gm_slot_tensor + LocalSlotNum)

Cleanups that become possible once the above land

| Workaround | Removable when |
|---|---|
| compile.sh regex s/TPipe<..., 8, 8, false>/TPipe<..., 8, 2, false>/ | hw-native-sys/PTOAS#650 lands and fa_builder.py writes local_slot_num=2 on the IR. |
| os.environ.setdefault("PTODSL_SKIP_VERIFY", "1") in fa_builder.py | The MLIR Python binding/verifier shipped with pto-dsl catches up to the address-based init shape (the additive API in PTO-ISA/pto-dsl#8 is the prerequisite, but the binding refresh is a follow-up). |
| canonicalize_ptoas_ir(text) text rewrite | pto-dsl's printer emits the custom assembly form for the new ops natively (also a binding follow-up after PTO-ISA/pto-dsl#8). |
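
The first workaround is a textual substitution on the generated C++. A rough Python equivalent is sketched below; the exact spelling of the emitted TPipe<...> type is elided in this thread, so the pattern, the sample string and the function name are assumptions for illustration rather than the contents of compile.sh.

```python
import re

# Assumption: the generated C++ ends the template argument list with
# "SlotNum, LocalSlotNum, false>", with both slot counts emitted as 8.
_TPIPE_8_8 = re.compile(r"(TPipe<[^>]*?,\s*8,\s*)8(\s*,\s*false>)")

def patch_local_slot_num(cpp_text):
    """Rewrite LocalSlotNum 8 -> 2 in every matching TPipe<...> type."""
    return _TPIPE_8_8.sub(r"\g<1>2\g<2>", cpp_text)

# Hypothetical line of generated code; Dir and T are placeholder arguments.
sample = "TPipe<Dir, T, 8, 8, false> pipe;"
print(patch_local_slot_num(sample))  # TPipe<Dir, T, 8, 2, false> pipe;
```

Instantiations that already use LocalSlotNum=2 are left untouched, so the patch is safe to run more than once.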

The kernel logic itself — entry-carrying tfree_from_aic/aiv, address-based pipes, SlotNum=8 / LocalSlotNum=2, TILE_S1=256 / QK_PRELOAD=4 — is the long-term design and stays.
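
Why free notifications matter for a fixed-slot FIFO can be pictured with a toy model. This is not the PTO pipe implementation: the class, the slot count of 8 and the tile count of 16 are illustrative assumptions, chosen only to show a producer stalling once every slot is taken and none is returned.

```python
from collections import deque

class SlotFifo:
    """Toy fixed-slot FIFO: a slot is reusable only after the consumer frees it."""

    def __init__(self, num_slots):
        self.free_slots = deque(range(num_slots))
        self.ready = deque()

    def try_push(self, item):
        if not self.free_slots:
            return False                  # a real producer would block here
        self.ready.append((self.free_slots.popleft(), item))
        return True

    def pop(self):
        return self.ready.popleft()       # (slot, item); slot is still owned

    def free(self, slot):
        self.free_slots.append(slot)      # free notification to the producer

def run(num_tiles, num_slots, send_free):
    """Return how many tiles the producer managed to push."""
    fifo, pushed = SlotFifo(num_slots), 0
    for tile in range(num_tiles):
        if not fifo.try_push(tile):
            break
        pushed += 1
        slot, _ = fifo.pop()
        if send_free:
            fifo.free(slot)
    return pushed

print(run(16, 8, send_free=True))    # 16
print(run(16, 8, send_free=False))   # 8
```

In this model the run without free notifications stops as soon as the tile count exceeds the slot count, while the run that returns each slot after use completes.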

@chenshengxin2026 changed the title from "[WIP] feat: Add PTO-DSL Flash Attention v1 performance kernel with Python validation runner" to "[WIP] feat: Add PTO-DSL Flash Attention performance kernel with Python validation runner" on May 9, 2026
…ith TILE_S1=256/QK_PRELOAD=4

Replace the earlier flash_atten-v1/ scaffold and the parity flash_atten-v2/
experiment with a single PTO-DSL flash-attention kernel that reaches manual
parity end-to-end, drops the bench launcher cleanly through ptoas + bisheng,
and runs the full 1024/2048/8192 sweep on 910B2 / 24 cubes.

Kernel (kernels/fa_builder.py):
- TILE_S1=256, CUBE_S1=128, QK_PRELOAD=4 mirroring
  kernels/manual/common/flash_atten/fa_performance_kernel.cpp.
- Per-row_slice softmax/GU loops keep the working tile at 32 KiB, so three
  fp32 working tiles co-exist with pv/o tiles in 192 KiB UB at TILE_S1=256.
- 4-slot QK preload with a paired exp_max ring; steady-state group of 4
  matches the manual schedule.
- Address-based pipes (PR #606): cube/vec init use gm_slot_tensor and
  pop/free carry the tensor_view entry, restoring real free-notification
  semantics for the GM FIFO at long S1 (was the S1>=4096 hang root cause).

compile.sh:
- Build N=4/8/32 variants by default (S1=1024/2048/8192).
- Post-process LocalSlotNum=8 -> 2 in the ptoas-generated C++; required
  while ptoas's address-based pipe lowering ignores local_slot_num on the
  IR. This is a temporary patch and will be removed once PTOAS lowers
  gm_slot_tensor + local_slot_num correctly (filed upstream).

DSL temporary shims in fa_builder.py (to be removed once pto-dsl ships
the new dialect):
- PTODSL_SKIP_VERIFY=1: the installed Python MLIR verifier predates the
  gm_slot_tensor address-based init shape; ptoas accepts it.
- canonicalize_ptoas_ir() rewrites the generic-form ops the older Python
  binding emits into the custom assembly form ptoas parses.

Verified on 910B2 / 24 cubes, head_dim=128, S0=128, causal=False
(block_dim=1, single Q block per kernel):

DSL kernel-only (ptodsl.do_bench, warmup=10, iter=100):
  | S1   | DSL us | DSL TFLOPS | torch_npu fused us | max diff |
  |-----:|-------:|-----------:|-------------------:|---------:|
  | 1024 |  17.65 |       3.80 |              28.72 | 4.43e-05 |
  | 2048 |  27.78 |       4.83 |              29.74 | 2.72e-05 |
  | 8192 |  83.15 |       6.46 |              48.29 | 1.48e-05 |

Manual C++ baseline kernel-only (build/report.csv duration_us, same shape,
SOC_VERSION=Ascend910B1):
  | S1   | Manual us | Manual TFLOPS | DSL / Manual |
  |-----:|----------:|--------------:|-------------:|
  | 1024 |     37.46 |          1.79 |        0.47x |
  | 2048 |     59.84 |          2.24 |        0.46x |
  | 8192 |     91.38 |          5.88 |        0.91x |

DSL outperforms the manual C++ baseline at all three lengths under
matched shape and same kernel-only timing window; speedup decays with S1
because manual amortizes its larger fixed pipe-barrier window over more
work. Both still trail torch_npu.npu_fused_infer_attention_score at
S1=8192 (multi-core); single-Q-block, single-cube path here is the
direct manual-parity comparison.