feat(flash-attn): add Python DSL Flash Attention example under kernels/python/flash_atten#117
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements a pto-dsl version of Flash Attention, including the kernel builder, host shim, and benchmarking infrastructure. The review identified several high-priority issues: missing error handling for rtGetC2cCtrlAddr which could lead to null pointer dereferences, and buffer aliasing that undermines the software pipeline's performance and safety. Additionally, there is a discrepancy between the implemented tile size and the documentation, and several instances where tile.muls should be replaced with tile.mov for better efficiency.
| (void)rtGetC2cCtrlAddr(reinterpret_cast<uint64_t *>(&fftsAddr), &fftsLen); | ||
| (void)fftsLen; |
There was a problem hiding this comment.
b4545d9 to
f42198c
Compare
f42198c to
c461b41
Compare
|
Triage review (2026-05-08): this PR is mergeable from a repository-state perspective: GitHub reports it as clean against Before merging, please resolve the readiness signal: the title still says No blocking conflict found in this triage pass; this still deserves normal owner review of the PTO-DSL kernel semantics and the NPU benchmark claims before merge. |
|
Added |
35b35de to
871cceb
Compare
6980ecf to
882d1b2
Compare
882d1b2 to
75eec67
Compare
75eec67 to
d5e76ff
Compare
…s/python/flash_atten Port the manual Flash Attention kernel under kernels/manual/common/flash_atten to the PTO Python DSL (ptodsl) and add a build/run/benchmark entry point. - kernels/python/flash_atten/kernels/fa_builder.py: PTO Python DSL builder for a four-stage Cube/Vector software pipeline (compute_qk -> compute_p -> compute_pv -> compute_gu) with TILE_S1=256, CUBE_S1=128, QK_PRELOAD=4, matching the manual kernel shape. - kernels/python/flash_atten/caller.cpp: host shim exported as call_kernel for ctypes. - kernels/python/flash_atten/compile.sh: ptoas + bisheng build pipeline, emits build_artifacts/fa.mlir, fa.cpp, fa.so. - kernels/python/flash_atten/run.py: Torch-NPU driver with correctness check vs FP32 torch reference / npu_fused_infer_attention_score and a sweep over case1..case8 (S0=S1 from 1024 up to 131072), TFLOP/s report and TSV summary. - kernels/python/flash_atten/README.md, README_zh.md: usage, supported platform, build/run, custom cases, output format. - .gitignore: ignore kernels/python/flash_atten/build_artifacts/. Design references: - kernels/manual/common/flash_atten (in-tree manual kernel; pipeline, TILE_S1/CUBE_S1/QK_PRELOAD shape and FIFO layout) - https://github.com/huawei-csl/pto-dsl/tree/main/examples/aot/flash_attention/140tflops (AOT 140 TFLOPS reference; benchmark conventions, TFLOP/s accounting)
d5e76ff to
ee133d0
Compare
Summary
Add a Python DSL Flash Attention example under
kernels/python/flash_atten/.kernels/manual/common/flash_atten(same four-stage Cube/Vector pipelinecompute_qk -> compute_p -> compute_pv -> compute_gu, sameTILE_S1=256 / CUBE_S1=128 / QK_PRELOAD=4shape and FIFO layout).Files
kernels/python/flash_atten/kernels/fa_builder.py— PTO Python DSL kernel builder (ptodsl).kernels/python/flash_atten/caller.cpp— host shim exported ascall_kernelfor ctypes.kernels/python/flash_atten/compile.sh—ptoas+bishengbuild pipeline (build_artifacts/fa.mlir,fa.cpp,fa.so).kernels/python/flash_atten/run.py— Torch-NPU driver: build, correctness check vs FP32 reference /torch_npu.npu_fused_infer_attention_score, sweepcase1..case8(S0=S1 from 1024 to 131072), TFLOP/s report, TSV summary.kernels/python/flash_atten/README.md,README_zh.md— usage, supported platform, build/run, custom cases..gitignore— ignorekernels/python/flash_atten/build_artifacts/.Kernel scope
HEAD = 128,S0 = 128per Q block,TILE_S1 = 256,CUBE_S1 = 128,QK_PRELOAD = 4, non-causal only.FA_Q_ROWS(multiple of 128); total KV rows supplied at runtime; each S1 must be compatible withS1_TILE=256andQK_PRELOAD=4.Build & run
Performance
Ascend 910B1,
HEAD=128,S0=S1, non-causal.do_benchwarmup=10, iter=100.fa_usis kernel-only host-side stream-event window;TF/suses matmul + scale + softmax FLOP counts (same accounting as the 140 TFLOPS reference).Python DSL vs Huawei CSL 140 TFLOPS reference
Also vs
torch_npu.npu_fused_infer_attention_score(multi-core fused kernel, upper-bound reference only)Headline:
Testing
compile.shproducesbuild_artifacts/fa.soperFA_Q_ROWS.torch_npu.npu_fused_infer_attention_score).python3 run.pyrunscase1..case8end-to-end and emits the per-case TFLOP/s and a TSV summary.