diff --git a/.gitignore b/.gitignore index 55d5fdd67..460b50f32 100644 --- a/.gitignore +++ b/.gitignore @@ -13,8 +13,9 @@ site/ site_zh/ .venv-mkdocs/ *.o -output/ -__pycache__/ +output/ +__pycache__/ +kernels/python/flash_atten/build_artifacts/ *.log *.swp *.pdf diff --git a/kernels/python/flash_atten/README.md b/kernels/python/flash_atten/README.md new file mode 100644 index 000000000..48efcc8b2 --- /dev/null +++ b/kernels/python/flash_atten/README.md @@ -0,0 +1,154 @@ +# Python DSL Flash Attention Example + +## Overview + +This example demonstrates a high-performance Flash Attention implementation written with the PTO Python DSL (`ptodsl`). It is a Python-DSL port and parity experiment for the manual kernel in `kernels/manual/common/flash_atten`, and it follows the same four-stage software pipeline: + +```text +compute_qk (Cube) -> compute_p (Vector) -> compute_pv (Cube) -> compute_gu (Vector) +``` + +The implementation also references the Huawei CSL PTO DSL AOT Flash Attention 140 TFLOPS example: + +```text +https://github.com/huawei-csl/pto-dsl/tree/main/examples/aot/flash_attention/140tflops +``` + +The case is useful for validating that the Python DSL can express a production-style Flash Attention pipeline, including Cube/Vector cooperation, runtime S1 looping, software FIFO staging through global memory, correctness checks, and performance comparison against `torch_npu.npu_fused_infer_attention_score`. 
+ +## Supported Platform + +- Ascend A3-class target (`--pto-arch=a3`, `--npu-arch=dav-2201` in `compile.sh`) +- CANN environment with `bisheng` +- PTO assembler `ptoas` +- Python environment with `ptodsl`, `torch`, and `torch_npu` + +## Directory Layout + +```text +kernels/python/flash_atten/ +├── caller.cpp # Host shim exported as call_kernel for ctypes +├── compile.sh # Generates MLIR/C++ and builds build_artifacts/fa.so +├── kernels/ +│ └── fa_builder.py # PTO Python DSL Flash Attention kernel builder +└── run.py # Build, run, verify, and benchmark entry point +``` + +Generated files are placed under `build_artifacts/`: + +```text +build_artifacts/fa.mlir # MLIR emitted by fa_builder.py +build_artifacts/fa.cpp # C++ emitted by ptoas +build_artifacts/fa.so # Shared library loaded by run.py +build_artifacts/fa_summary_*.tsv +``` + +## Kernel Scope + +Current shape and feature constraints are intentionally aligned with the manual parity target: + +- `HEAD = 128` +- `S0 = 128` per Q block +- `TILE_S1 = 256` +- `CUBE_S1 = 128` +- `QK_PRELOAD = 4` +- non-causal attention only +- total Q rows are configured by `FA_Q_ROWS` and must be a multiple of `128` +- total KV rows are supplied at runtime by `run.py`; each S1 length must satisfy the divisibility constraints imposed by `TILE_S1 = 256` and `QK_PRELOAD = 4` (checked in `run.py`) + +The generated shared library is specialized for the current `FA_Q_ROWS`, while S1 is handled at runtime. + +## Build and Run + +1. Configure the Ascend CANN environment. + +```bash +source ${ASCEND_INSTALL_PATH}/bin/setenv.bash +``` + +2. Enter the example directory and set the PTO include path. + +```bash +cd ${git_clone_path}/kernels/python/flash_atten +export PTO_LIB_PATH=${git_clone_path} +``` + +If `ptoas` or `bisheng` are not in `PATH`, set them explicitly: + +```bash +export PTOAS=/path/to/ptoas +export BISHENG=/path/to/bisheng +``` + +3. Run one default benchmark case. + +```bash +python3 run.py --case case1 +``` + +4. Run the full default benchmark suite. 
+ +```bash +python3 run.py +``` + +The default suite runs `case1` to `case8` and recompiles the kernel for each `FA_Q_ROWS` value: + +| Case | Q rows (S0 total) | KV rows (S1) | +| --- | ---: | ---: | +| `case1` | 1024 | 1024 | +| `case2` | 2048 | 2048 | +| `case3` | 4096 | 4096 | +| `case4` | 8192 | 8192 | +| `case5` | 16384 | 16384 | +| `case6` | 32768 | 32768 | +| `case7` | 65536 | 65536 | +| `case8` | 131072 | 131072 | + +## Custom Cases + +Run a custom shape by setting `FA_Q_ROWS` and `FA_BENCH_LENGTHS`: + +```bash +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 python3 run.py +``` + +Run several S1 lengths for one compiled Q shape: + +```bash +FA_Q_ROWS=2048 FA_BENCH_LENGTHS=1024,2048,4096 python3 run.py +``` + +Control benchmark iterations: + +```bash +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 FA_BENCH_WARMUP=20 FA_BENCH_ITERS=200 python3 run.py +``` + +Reuse an existing `build_artifacts/fa.so` when it was already compiled for the same `FA_Q_ROWS`: + +```bash +FA_Q_ROWS=1024 bash compile.sh +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 python3 run.py --no-build +``` + +## Output and Correctness + +`run.py` prints latency, throughput, speedup, and max error for each shape. It compares the DSL kernel with: + +- a host FP32 PyTorch reference when `Q_ROWS * S1` is small enough +- `torch_npu.npu_fused_infer_attention_score` for all benchmark sizes + +Throughput is reported as TFLOP/s using matmul, scale, and softmax operation counts, following the 140 TFLOPS reference script convention. + +A summary TSV is generated automatically for the default suite. You can choose the output path with `FA_SUMMARY_TSV`: + +```bash +FA_SUMMARY_TSV=/tmp/fa_summary.tsv python3 run.py --case case1 +``` + +## Notes + +- `compile.sh` defaults `PTO_LIB_PATH` to `/sources/pto-isa`; set `PTO_LIB_PATH=${git_clone_path}` when working from this repository. +- `--no-build` is only suitable for one selected case because `fa.so` is rebuilt per `FA_Q_ROWS`. 
+- Large sequence lengths can skip the host FP32 reference to avoid allocating a very large QK matrix; correctness is then checked against the NPU fused reference with a looser tolerance. diff --git a/kernels/python/flash_atten/README_zh.md b/kernels/python/flash_atten/README_zh.md new file mode 100644 index 000000000..3404328b9 --- /dev/null +++ b/kernels/python/flash_atten/README_zh.md @@ -0,0 +1,154 @@ +# Python DSL Flash Attention 用例 + +## 概览 + +本用例展示如何使用 PTO Python DSL(`ptodsl`)实现高性能 Flash Attention。该实现是 `kernels/manual/common/flash_atten` 手写 kernel 的 Python-DSL 迁移与对齐实验,保留了手写版本中的四阶段软流水: + +```text +compute_qk(Cube) -> compute_p(Vector) -> compute_pv(Cube) -> compute_gu(Vector) +``` + +实现过程中也参考了 Huawei CSL PTO DSL AOT Flash Attention 140 TFLOPS 示例: + +```text +https://github.com/huawei-csl/pto-dsl/tree/main/examples/aot/flash_attention/140tflops +``` + +该用例的意义在于验证 Python DSL 对高性能 Flash Attention 这类复杂算子的表达能力,包括 Cube/Vector 协同、运行时 S1 循环、通过全局内存进行软件 FIFO 暂存、结果正确性校验,以及与 `torch_npu.npu_fused_infer_attention_score` 的性能对比。 + +## 支持平台 + +- Ascend A3 类目标(`compile.sh` 中使用 `--pto-arch=a3`、`--npu-arch=dav-2201`) +- 已配置 `bisheng` 的 CANN 环境 +- PTO 汇编器 `ptoas` +- 包含 `ptodsl`、`torch`、`torch_npu` 的 Python 环境 + +## 目录结构 + +```text +kernels/python/flash_atten/ +├── caller.cpp # Host 侧 shim,导出供 ctypes 调用的 call_kernel +├── compile.sh # 生成 MLIR/C++,并构建 build_artifacts/fa.so +├── kernels/ +│ └── fa_builder.py # PTO Python DSL Flash Attention kernel 构造器 +└── run.py # 构建、运行、校验和性能测试入口 +``` + +生成产物位于 `build_artifacts/`: + +```text +build_artifacts/fa.mlir # fa_builder.py 生成的 MLIR +build_artifacts/fa.cpp # ptoas 生成的 C++ +build_artifacts/fa.so # run.py 加载的动态库 +build_artifacts/fa_summary_*.tsv +``` + +## Kernel 范围 + +当前形状和功能约束有意与手写版本对齐: + +- `HEAD = 128` +- 每个 Q block 的 `S0 = 128` +- `TILE_S1 = 256` +- `CUBE_S1 = 128` +- `QK_PRELOAD = 4` +- 仅支持非 causal attention +- Q 总行数通过 `FA_Q_ROWS` 配置,并且必须是 `128` 的整数倍 +- KV 总行数由 `run.py` 在运行时传入;每个 S1 长度需要满足 `S1_TILE=256` 和 `QK_PRELOAD=4` 的整除约束 + 
+生成的动态库会针对当前 `FA_Q_ROWS` 特化,S1 长度则在运行时处理。 + +## 构建与运行 + +1. 配置 Ascend CANN 环境。 + +```bash +source ${ASCEND_INSTALL_PATH}/bin/setenv.bash +``` + +2. 进入用例目录并设置 PTO 头文件路径。 + +```bash +cd ${git_clone_path}/kernels/python/flash_atten +export PTO_LIB_PATH=${git_clone_path} +``` + +如果 `ptoas` 或 `bisheng` 不在 `PATH` 中,可以显式设置: + +```bash +export PTOAS=/path/to/ptoas +export BISHENG=/path/to/bisheng +``` + +3. 运行一个默认 benchmark case。 + +```bash +python3 run.py --case case1 +``` + +4. 运行完整默认 benchmark 集合。 + +```bash +python3 run.py +``` + +默认集合会运行 `case1` 到 `case8`,并针对每个 `FA_Q_ROWS` 重新编译 kernel: + +| Case | Q rows(S0 total) | KV rows(S1) | +| --- | ---: | ---: | +| `case1` | 1024 | 1024 | +| `case2` | 2048 | 2048 | +| `case3` | 4096 | 4096 | +| `case4` | 8192 | 8192 | +| `case5` | 16384 | 16384 | +| `case6` | 32768 | 32768 | +| `case7` | 65536 | 65536 | +| `case8` | 131072 | 131072 | + +## 自定义 Case + +通过 `FA_Q_ROWS` 和 `FA_BENCH_LENGTHS` 运行自定义形状: + +```bash +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 python3 run.py +``` + +对同一个 Q 形状测试多个 S1 长度: + +```bash +FA_Q_ROWS=2048 FA_BENCH_LENGTHS=1024,2048,4096 python3 run.py +``` + +控制 benchmark 迭代次数: + +```bash +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 FA_BENCH_WARMUP=20 FA_BENCH_ITERS=200 python3 run.py +``` + +当 `build_artifacts/fa.so` 已经按相同 `FA_Q_ROWS` 编译过时,可以复用已有动态库: + +```bash +FA_Q_ROWS=1024 bash compile.sh +FA_Q_ROWS=1024 FA_BENCH_LENGTHS=1024 python3 run.py --no-build +``` + +## 输出与正确性 + +`run.py` 会输出每个形状的时延、吞吐、相对 `torch_npu` 融合 attention 的加速比,以及最大误差。正确性对比包括: + +- 当 `Q_ROWS * S1` 不太大时,使用 host 侧 FP32 PyTorch reference +- 所有 benchmark 尺寸都会对比 `torch_npu.npu_fused_infer_attention_score` + +TFLOP/s 统计包含 matmul、scale 和 softmax 操作量,保持与 140 TFLOPS 参考脚本一致的计数口径。 + +默认集合会自动生成 summary TSV。也可以通过 `FA_SUMMARY_TSV` 指定输出路径: + +```bash +FA_SUMMARY_TSV=/tmp/fa_summary.tsv python3 run.py --case case1 +``` + +## 注意事项 + +- `compile.sh` 默认将 `PTO_LIB_PATH` 设为 `/sources/pto-isa`;在本仓工作时建议显式设置 `PTO_LIB_PATH=${git_clone_path}`。 +- `--no-build` 只适合单个已选 case,因为 
`fa.so` 会按 `FA_Q_ROWS` 重新构建。 +- 长序列可能跳过 host 侧 FP32 reference,以避免分配过大的 QK 矩阵;此时会用更宽松阈值对比 NPU fused reference。 diff --git a/kernels/python/flash_atten/caller.cpp b/kernels/python/flash_atten/caller.cpp new file mode 100644 index 000000000..2e47d62cd --- /dev/null +++ b/kernels/python/flash_atten/caller.cpp @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * + * Host shim for the pto-dsl flash-attention kernel. compile.sh injects + * -DKERNEL_CPP="\"build_artifacts/fa.cpp\"" + * which makes this TU include the ptoas-generated kernel that defines + * `call_both`. The single exported symbol `call_kernel` is what run.py + * calls via ctypes. + */ + +#ifndef KERNEL_CPP +#error "KERNEL_CPP must be defined at compile time (see compile.sh)." +#endif + +#include + +extern "C" int rtGetC2cCtrlAddr(uint64_t *ctrlAddr, uint32_t *ctrlLen); + +#include KERNEL_CPP + +extern "C" void call_kernel(uint32_t blockDim, void *stream, uint8_t *gmSlotBuffer, uint8_t *q, uint8_t *k, uint8_t *v, + uint8_t *o, int64_t s0, int64_t s1) +{ + void *fftsAddr = nullptr; + uint32_t fftsLen = 0; + (void)rtGetC2cCtrlAddr(reinterpret_cast(&fftsAddr), &fftsLen); + (void)fftsLen; + + call_both<<>>((__gm__ uint64_t *)fftsAddr, (__gm__ float *)gmSlotBuffer, + (__gm__ half *)gmSlotBuffer, (__gm__ half *)q, (__gm__ half *)k, + (__gm__ half *)v, (__gm__ float *)o, s0, s1); +} diff --git a/kernels/python/flash_atten/compile.sh b/kernels/python/flash_atten/compile.sh new file mode 100755 index 000000000..915d09ce4 --- /dev/null +++ b/kernels/python/flash_atten/compile.sh @@ -0,0 +1,50 @@ +#!/usr/bin/env bash +# Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+# CANN Open Software License Agreement Version 2.0 +# +# Build the pto-dsl flash-attention runtime-S1 .so. The generated kernel loops +# over s1 / S1_TILE at runtime, so one fa.{mlir,cpp,so} covers all supported +# benchmark lengths. +# +# Usage: +# bash compile.sh # build build_artifacts/fa.so +# PTO_LIB_PATH=/abs/pto-isa bash compile.sh # override include path + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +ARTIFACT_DIR="${SCRIPT_DIR}/build_artifacts" +PTO_LIB_PATH="${PTO_LIB_PATH:-/sources/pto-isa}" +PTOAS="${PTOAS:-ptoas}" +BISHENG="${BISHENG:-bisheng}" + +mkdir -p "${ARTIFACT_DIR}" + +MLIR_PATH="${ARTIFACT_DIR}/fa.mlir" +GENERATED_CPP="${ARTIFACT_DIR}/fa.cpp" +LIB_PATH="${ARTIFACT_DIR}/fa.so" + +echo "==> Building runtime-S1 fa -> ${LIB_PATH}" +rm -f "${MLIR_PATH}" "${GENERATED_CPP}" "${LIB_PATH}" + +python "${SCRIPT_DIR}/kernels/fa_builder.py" > "${MLIR_PATH}" +"${PTOAS}" --pto-arch=a3 --enable-insert-sync "${MLIR_PATH}" > "${GENERATED_CPP}" + +"${BISHENG}" \ + -I"${PTO_LIB_PATH}/include" \ + -fPIC -shared -D_FORTIFY_SOURCE=2 -O2 \ + -Wno-macro-redefined -Wno-ignored-attributes -fstack-protector-strong \ + -xcce -Xhost-start -Xhost-end \ + -mllvm -cce-aicore-stack-size=0x8000 \ + -mllvm -cce-aicore-function-stack-size=0x8000 \ + -mllvm -cce-aicore-record-overflow=true \ + -mllvm -cce-aicore-addr-transform \ + -mllvm -cce-aicore-dcci-insert-for-scalar=false \ + -cce-enable-mix \ + --npu-arch=dav-2201 -DMEMORY_BASE \ + -std=gnu++17 \ + -DKERNEL_CPP="\"${GENERATED_CPP}\"" \ + "${SCRIPT_DIR}/caller.cpp" \ + -o "${LIB_PATH}" + +echo "Done. Built ${LIB_PATH}" diff --git a/kernels/python/flash_atten/kernels/fa_builder.py b/kernels/python/flash_atten/kernels/fa_builder.py new file mode 100644 index 000000000..fe7aaa9b7 --- /dev/null +++ b/kernels/python/flash_atten/kernels/fa_builder.py @@ -0,0 +1,783 @@ +""" +Copyright (c) 2026 Huawei Technologies Co., Ltd. 
+This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. You may not use this +file except in compliance with the License. + +pto-dsl translation of `kernels/manual/common/flash_atten/fa_performance_kernel.cpp`. + +Host-visible shape: + * HEAD is fixed at 128 (current kernel locks head dim to manual case). + * `S0` here is the per-AIC-core Q-block size and stays at 128. + * `Q_ROWS` is the total Q sequence length; configurable per case via the + FA_Q_ROWS env var (defaults to 128). It must be a multiple of S0=128 + because the kernel iterates Q in S0-row blocks across cores. Common + Prefill cases set FA_Q_ROWS to 1024..131072. + * S1 (total KV rows) is taken at runtime via the kernel argument, so the + same .so handles any S1 that satisfies the S1_TILE / QK_PRELOAD + multiplicity check in run.py::_num_tiles. + +Internal S1 tiling is intentionally set to the manual `TILE_S1=256` for +parity experiments. + +The reference C++ kernel is a 4-stage cross-core software-pipelined Flash +Attention: + + compute_qk (Cube): TLOAD Q/K -> matmul -> TPUSH on QKPipe [C2V fp32] + compute_p (Vec ): TPOP QK -> streaming softmax -> TPUSH P [V2C fp16] + compute_pv (Cube): TPOP P -> TLOAD V -> matmul -> TPUSH PV [C2V fp32] + compute_gu (Vec ): TPOP PV -> rescale-and-add into running O + +DSL constraints relative to the C++ source: + + * `tile.triu` is not exposed -> CAUSAL_MASK=False only. + * MAT/RIGHT subview verifier rejects partial-column subviews -> we cannot + fuse "load wide K once, sub-tile matmul into ACC subview"; this parity + experiment therefore uses one DSL tile with TILE_S1=256. + * `--enable-insert-sync` requires careful event-slot reuse at QK_PRELOAD=4; + the 8-slot `exp_max` ring keeps softmax(t+4) from clobbering gu(t)'s + rescale factor. 
+ * TILE_S1=256 is expected to stress VEC UB allocation; keep this variant + as the direct manual-parity experiment even when ptoas reports local-memory + allocation failures. + * `pto.alloc_tile` is single-output -> Python aliases (`[buf, buf]`) + preserve the `[buf]` indexing pattern from the C++ source without + pretending we have ping-pong storage. + +The generated kernel takes S1 at runtime and loops over s1 / S1_TILE, so +one .so covers the benchmark lengths instead of emitting one fully unrolled +variant per sequence length. +""" + +import math +import os + +from ptodsl import pto, tile, to_ir_module +from ptodsl import scalar as s + + +const = s.const + + +# ============================================================================= +# Static shapes -- aligned with manual case `case_float_H_128_S0_128_S1_1024`. +# With the default FA_Q_ROWS=128: one Q block on one cube core (NUM_Q_BLOCKS=1, +# block_dim=1), mirroring the manual case (S0/CUBE_S0 = 1 in generated_cases.h), +# i.e. Q[128,128] / K[128,1024] / V[1024,128] / O[128,128]; larger FA_Q_ROWS add Q blocks. +# +# TILE_S1=256 mirrors the manual kernel. This intentionally ignores the +# previous smaller-tile workaround so compile/on-board behavior can expose the +# exact DSL-vs-manual resource gap. +# ============================================================================= +MANUAL_S0 = 128 +MANUAL_HEAD = 128 +MANUAL_CUBE_S0 = 128 +MANUAL_CUBE_S1 = 128 +MANUAL_TILE_S1 = 256 +MANUAL_QK_PRELOAD = 4 +MANUAL_CAUSAL_MASK = False + +S0 = 128 +S0_HALF = S0 // 2 +HEAD = 128 +VEC_CORES = 2 +# Manual alignment: TILE_S1 / CUBE_S1 / kTileFactor mirror the C++ values. +# kernels/manual/common/flash_atten/fa_performance_kernel.h: kFaTileS1=256, +# kFaCubeS1=128. Vec_S0 = Cube_S0/VEC_CORES/kTileFactor adds an inner row_slice +# loop in vec to keep the [Vec_S0, S1_TILE] working tile at 32 KiB at S1_TILE +# =256, which lets three fp32 working tiles co-exist with pv/o tiles in 192 +# KiB UB. 
VecGuRows = S0_HALF (full subblock row count) is used by GU/PV which +# do not row-split. +CUBE_S1 = 128 +S1_TILE = 256 +TILE_FACTOR = S1_TILE // CUBE_S1 +Vec_S0 = S0 // VEC_CORES // TILE_FACTOR # = 32 (per row_slice) +VecGuRows = S0 // VEC_CORES # = 64 (full subblock = S0_HALF) + +Q_ROWS = int(os.environ.get("FA_Q_ROWS", "128")) +if Q_ROWS % S0 != 0: + raise ValueError( + "FA_Q_ROWS={} must be a multiple of the per-core Q-block size S0={}. " + "Choose an S0 that is divisible by the baked matmul/softmax tile " + "shape S0=128.".format(Q_ROWS, S0) + ) +NUM_Q_BLOCKS = Q_ROWS // S0 + +# QK preload depth. Manual defaults to 4; the exp_max ring below has one slot +# per preloaded logical S1 tile so GU never reads a slot after softmax overwrote +# it for tile+QK_PRELOAD. +QK_PRELOAD = 4 + +# Per-pipe slot sizes (bytes). +SLOT_SIZE_QK = S0 * S1_TILE * 4 # fp32 QK accumulator +SLOT_SIZE_PV = S0 * HEAD * 4 # fp32 PV accumulator +SLOT_SIZE_P = S0 * S1_TILE * 2 # fp16 softmax(QK) sent vec -> cube + +# `dir_mask=1`/`dir_mask=2` always map to slot_num=8 on a3. +SLOT_NUM = 8 +# GM-staged FIFO bytes / fp32 elements per AIC block. +GM_BYTES_PER_BLOCK = (SLOT_SIZE_QK + SLOT_SIZE_PV + SLOT_SIZE_P) * SLOT_NUM +GM_ELEMS_PER_BLOCK = GM_BYTES_PER_BLOCK // 4 +GM_QK_OFF_F32 = 0 +GM_PV_OFF_F32 = (SLOT_SIZE_QK * SLOT_NUM) // 4 +GM_P_OFF_F32 = GM_PV_OFF_F32 + (SLOT_SIZE_PV * SLOT_NUM) // 4 + +SPLIT_UP_DOWN = 1 # TileSplitAxis::TILE_UP_DOWN + + +# ============================================================================= +# Type definitions exposed to the DSL via meta_data(). 
+# ============================================================================= +def meta_data(): + fp16 = pto.float16 + fp32 = pto.float32 + ffts_ty = pto.ffts_type + ptr_fp16 = pto.PtrType(fp16) + ptr_fp32 = pto.PtrType(fp32) + i64 = pto.int64 + + qkv_tensor_ty = pto.TensorType(rank=2, dtype=fp16) + o_tensor_ty = pto.TensorType(rank=2, dtype=fp32) + + q_sub_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp16) + kt_sub_ty = pto.SubTensorType(shape=[HEAD, CUBE_S1], dtype=fp16) + v_sub_ty = pto.SubTensorType(shape=[CUBE_S1, HEAD], dtype=fp16) + o_sub_slice_ty = pto.SubTensorType(shape=[Vec_S0, HEAD], dtype=fp32) + + # ---- Address-based slot descriptors (PR #606). ---- + # The QK pipe slot tensor_view describes the GM region one slot covers; + # talloc/tpop_into bind the declared global entry to the current FIFO slot + # at runtime. partition_view carves a [S0, CUBE_S1] sub-region for each + # cube subtile within the slot. + qk_slot_ty = pto.TensorType(shape=[S0, S1_TILE], dtype=fp32) + qk_slot_part_ty = pto.SubTensorType(shape=[S0, CUBE_S1], dtype=fp32) + qk_vec_slot_ty = pto.TensorType(shape=[VecGuRows, S1_TILE], dtype=fp32) + # Vec consumes the QK slot in TILE_FACTOR row_slices of [Vec_S0, S1_TILE] + # each, mirroring the manual `compute_p` row_slice loop and shrinking the + # per-iteration UB working tile from 64 KiB to 32 KiB at S1_TILE=256. + qk_vec_slot_part_ty = pto.SubTensorType( + shape=[Vec_S0, S1_TILE], dtype=fp32 + ) + # PV slot (cube -> vec, fp32 [S0, HEAD]); width does not scale with S1_TILE + # because one PV is produced per logical TILE_S1 by accumulating sub-PV + # matmuls into the same accumulator (manual C++ semantic). 
+ pv_slot_ty = pto.TensorType(shape=[S0, HEAD], dtype=fp32) + pv_slot_part_ty = pto.SubTensorType(shape=[S0, HEAD], dtype=fp32) + pv_vec_slot_ty = pto.TensorType(shape=[VecGuRows, HEAD], dtype=fp32) + # GU also runs per-row_slice so each pop returns the full subblock view + # but the actual TLOAD targets a [Vec_S0, HEAD] row-slice partition. + pv_vec_slot_part_ty = pto.SubTensorType(shape=[Vec_S0, HEAD], dtype=fp32) + # P slot (vec -> cube, fp16 [S0, S1_TILE]). Vec produces the FULL S1_TILE- + # wide softmax tile across TILE_FACTOR row_slices: one [Vec_S0, S1_TILE] + # store per slice. Cube consumes via TILE_FACTOR sub-loads of [S0, CUBE_S1] + # halves so that each PV matmul matches its CUBE_S1 wide accumulator slot. + p_slot_ty = pto.TensorType(shape=[VecGuRows, S1_TILE], dtype=fp16) + p_slot_part_ty = pto.SubTensorType(shape=[Vec_S0, S1_TILE], dtype=fp16) + p_cube_slot_ty = pto.TensorType(shape=[S0, S1_TILE], dtype=fp16) + p_cube_slot_part_ty = pto.SubTensorType(shape=[S0, CUBE_S1], dtype=fp16) + + # ---- Cube tiles (L1 / L0A / L0B / L0C). ---- + # Cube tiles size to CUBE_S1 (the matmul subtile width); the wider TILE_S1 + # only shows up in the GM-staged FIFO slot, where TILE_FACTOR sub-tiles are + # stored into one slot before TPUSH. 
+ q_mat_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="MAT") + q_left_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp16, memory_space="LEFT") + k_mat_ty = pto.TileBufType( + shape=[HEAD, CUBE_S1], + dtype=fp16, + memory_space="MAT", + config=pto.TileBufConfig(blayout="RowMajor", slayout="ColMajor"), + ) + k_right_ty = pto.TileBufType( + shape=[HEAD, CUBE_S1], dtype=fp16, memory_space="RIGHT" + ) + qk_acc_ty = pto.TileBufType(shape=[S0, CUBE_S1], dtype=fp32, memory_space="ACC") + p_recv_ty = pto.TileBufType(shape=[S0, CUBE_S1], dtype=fp16, memory_space="MAT") + p_left_ty = pto.TileBufType(shape=[S0, CUBE_S1], dtype=fp16, memory_space="LEFT") + v_mat_ty = pto.TileBufType(shape=[CUBE_S1, HEAD], dtype=fp16, memory_space="MAT") + v_right_ty = pto.TileBufType( + shape=[CUBE_S1, HEAD], dtype=fp16, memory_space="RIGHT" + ) + pv_acc_ty = pto.TileBufType(shape=[S0, HEAD], dtype=fp32, memory_space="ACC") + + # ---- Vec tiles (UB). The QK softmax stage uses Vec_S0 rows per row_slice + # iteration (manual alignment), while PV/GU stages use the full VecGuRows + # row count of the subblock. + qk_vec_ty = pto.TileBufType( + shape=[Vec_S0, S1_TILE], dtype=fp32, memory_space="VEC" + ) + p_fp32_ty = pto.TileBufType( + shape=[Vec_S0, S1_TILE], dtype=fp32, memory_space="VEC" + ) + p_fp16_ty = pto.TileBufType( + shape=[Vec_S0, S1_TILE], dtype=fp16, memory_space="VEC" + ) + pv_vec_ty = pto.TileBufType( + shape=[Vec_S0, HEAD], dtype=fp32, memory_space="VEC" + ) + o_vec_ty = pto.TileBufType(shape=[Vec_S0, HEAD], dtype=fp32, memory_space="VEC") + + # Reduction tile (per-row scalar). Per-slice = [Vec_S0, 1]; running state + # is held as a list of TILE_FACTOR red tiles, one per row_slice. 
+ red_ty = pto.TileBufType( + shape=[Vec_S0, 1], + dtype=fp32, + memory_space="VEC", + config=pto.TileBufConfig(blayout="ColMajor", slayout="NoneBox"), + ) + red_row_ty = pto.TileBufType( + shape=[1, Vec_S0], dtype=fp32, memory_space="VEC" + ) + + return locals() + + +# ============================================================================= +# Module +# ============================================================================= +@to_ir_module(meta_data=meta_data, module=True) +def module(): + + # ------------------------------------------------------------------------- + # Helper: even share of NUM_Q_BLOCKS across this core grid. + # The C++ kernel uses one Q-row block per AIC core (block_idx -> Q rows); + # in DSL we let the launcher choose blockDim and split inside. + # ------------------------------------------------------------------------- + def compute_qb_range(c1): + cNUM_Q_BLOCKS = const(NUM_Q_BLOCKS) + num_blocks = s.index_cast(pto.get_block_num()) + bid = s.index_cast(pto.get_block_idx()) + floor_div = cNUM_Q_BLOCKS // num_blocks + extra = cNUM_Q_BLOCKS % num_blocks + fat_start = bid * (floor_div + c1) + thin_start = extra * (floor_div + c1) + (bid - extra) * floor_div + qb_start = s.select(bid < extra, fat_start, thin_start) + per_core = s.select(bid < extra, floor_div + c1, floor_div) + return bid, qb_start, qb_start + per_core + + # ========================================================================= + # Cube kernel + # ========================================================================= + @pto.func(kernel="cube") + def cube_kernel( + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_fp16: "ptr_fp16", + gm_q: "ptr_fp16", + gm_k: "ptr_fp16", + gm_v: "ptr_fp16", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + c0 = const(0) + c1 = const(1) + c2 = const(2) + cS0 = const(S0) + cHEAD = const(HEAD) + cCUBE_S1 = const(CUBE_S1) + cS1_TILE = const(S1_TILE) + cPRELOAD = const(QK_PRELOAD) + s0 = s.index_cast(s0_i64) + s1 = s.index_cast(s1_i64) + 
num_tiles_s1 = s1 // cS1_TILE + steady_tiles = num_tiles_s1 - cPRELOAD + + bid, qb_start, qb_end = compute_qb_range(c1) + + gm_blk = pto.add_ptr(gm_slot_buffer, bid * const(GM_ELEMS_PER_BLOCK)) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + # The P slot is fp16-typed, so address it via the fp16-cast slot buffer. + # GM_P_OFF_F32 is in fp32 elements; double for fp16 element stride. + gm_blk_fp16 = pto.add_ptr(gm_slot_buffer_fp16, bid * const(2 * GM_ELEMS_PER_BLOCK)) + gm_p = pto.add_ptr(gm_blk_fp16, const(2 * GM_P_OFF_F32)) + + # ---- QK pipe (cube producer): l2g2l GM-staged slot ---- + qk_slot_view = pto.as_tensor( + qk_slot_ty, + ptr=gm_qk, + shape=[cS0, cS1_TILE], + strides=[cS1_TILE, c1], + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + gm_addr=qk_slot_view, + flag_base=0, + ) + + # ---- PV pipe (cube producer): l2g2l GM-staged slot ---- + pv_slot_view = pto.as_tensor( + pv_slot_ty, + ptr=gm_pv, + shape=[cS0, cHEAD], + strides=[cHEAD, c1], + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + gm_addr=pv_slot_view, + flag_base=4, + ) + + # ---- P pipe (cube consumer of vec output): l2g2l GM-staged slot ---- + p_slot_view_cube = pto.as_tensor( + p_cube_slot_ty, + ptr=gm_p, + shape=[cS0, cS1_TILE], + strides=[cS1_TILE, c1], + ) + p_pipe = pto.initialize_l2g2l_pipe( + dir_mask=2, + slot_size=SLOT_SIZE_P, + slot_num=SLOT_NUM, + gm_addr=p_slot_view_cube, + flag_base=2, + ) + + # ---- Allocate cube tiles. Match the manual kernel's ping-pong for + # K/P/V MAT tiles where L1 capacity allows it. RIGHT is single-buffered + # because two 128x128 RIGHT tiles for both QK and PV overflow L0B. 
+ q_mat = pto.alloc_tile(q_mat_ty) + q_left = pto.alloc_tile(q_left_ty) + k_mat_a = pto.alloc_tile(k_mat_ty) + k_mat_b = pto.alloc_tile(k_mat_ty) + k_right_a = pto.alloc_tile(k_right_ty) + qk_acc_a = pto.alloc_tile(qk_acc_ty) + p_recv_a = pto.alloc_tile(p_recv_ty) + p_left_a = pto.alloc_tile(p_left_ty) + v_mat_a = pto.alloc_tile(v_mat_ty) + v_right_a = pto.alloc_tile(v_right_ty) + pv_acc_a = pto.alloc_tile(pv_acc_ty) + k_mat = [k_mat_a, k_mat_b] + k_right = [k_right_a, k_right_a] + qk_acc = [qk_acc_a, qk_acc_a] + p_recv = [p_recv_a, p_recv_a] + p_left = [p_left_a, p_left_a] + v_mat = [v_mat_a, v_mat_a] + v_right = [v_right_a, v_right_a] + pv_acc = [pv_acc_a, pv_acc_a] + + tv_q = pto.as_tensor( + qkv_tensor_ty, ptr=gm_q, shape=[s0, cHEAD], strides=[cHEAD, c1] + ) + tv_k = pto.as_tensor( + qkv_tensor_ty, + ptr=gm_k, + shape=[cHEAD, s1], + strides=[c1, cHEAD], + layout="DN", + ) + tv_v = pto.as_tensor( + qkv_tensor_ty, ptr=gm_v, shape=[s1, cHEAD], strides=[cHEAD, c1] + ) + + qk_entry = pto.declare_global(qk_slot_ty) + p_entry = pto.declare_global(p_cube_slot_ty) + pv_entry = pto.declare_global(pv_slot_ty) + + # Closures over the shared tile state. The steady state overlaps PV for + # the current S1 tile with QK for the next S1 tile at CUBE_S1 granularity. 
+ def emit_qk_sub(s1_tile_idx, sub, b): + kt_view = pto.slice_view( + kt_sub_ty, + source=tv_k, + offsets=[c0, s1_tile_idx * cS1_TILE + const(sub * CUBE_S1)], + sizes=[cHEAD, cCUBE_S1], + ) + pto.load(kt_view, k_mat[b]) + tile.mov(k_mat[b], k_right[b]) + tile.matmul(q_left, k_right[b], qk_acc[b]) + slot_part = pto.slice_view( + qk_slot_part_ty, + source=qk_entry, + offsets=[c0, const(sub * CUBE_S1)], + sizes=[cS0, cCUBE_S1], + ) + pto.store(qk_acc[b], slot_part) + + def emit_qk(s1_tile_idx, b): + pto.talloc(qk_entry, qk_pipe, SPLIT_UP_DOWN) + for sub in range(TILE_FACTOR): + emit_qk_sub(s1_tile_idx, sub, b) + pto.tpush(qk_entry, qk_pipe, SPLIT_UP_DOWN) + + def emit_pv_sub(t_idx, sub, b): + p_part = pto.slice_view( + p_cube_slot_part_ty, + source=p_entry, + offsets=[c0, const(sub * CUBE_S1)], + sizes=[cS0, cCUBE_S1], + ) + pto.load(p_part, p_recv[b]) + tile.mov(p_recv[b], p_left[b]) + v_view = pto.slice_view( + v_sub_ty, + source=tv_v, + offsets=[t_idx * cS1_TILE + const(sub * CUBE_S1), c0], + sizes=[cCUBE_S1, cHEAD], + ) + pto.load(v_view, v_mat[b]) + tile.mov(v_mat[b], v_right[b]) + if sub == 0: + tile.matmul(p_left[b], v_right[b], pv_acc[b]) + else: + tile.matmul_acc(pv_acc[b], p_left[b], v_right[b], pv_acc[b]) + + def push_pv(b): + pto.tfree(p_pipe, SPLIT_UP_DOWN, entry=p_entry) + pto.talloc(pv_entry, pv_pipe, SPLIT_UP_DOWN) + pv_part = pto.slice_view( + pv_slot_part_ty, + source=pv_entry, + offsets=[c0, c0], + sizes=[cS0, cHEAD], + ) + pto.store(pv_acc[b], pv_part) + pto.tpush(pv_entry, pv_pipe, SPLIT_UP_DOWN) + + def emit_pv(t_idx, b): + pto.tpop_into(p_entry, p_pipe, SPLIT_UP_DOWN) + for sub in range(TILE_FACTOR): + emit_pv_sub(t_idx, sub, b) + push_pv(b) + + def emit_qk_pv_interleaved(next_idx, current_idx, b): + pto.tpop_into(p_entry, p_pipe, SPLIT_UP_DOWN) + for sub in range(TILE_FACTOR): + emit_pv_sub(current_idx, sub, b) + if sub == 0: + pto.talloc(qk_entry, qk_pipe, SPLIT_UP_DOWN) + if sub == TILE_FACTOR - 1: + push_pv(b) + emit_qk_sub(next_idx, sub, b) 
+ if sub == TILE_FACTOR - 1: + pto.tpush(qk_entry, qk_pipe, SPLIT_UP_DOWN) + + # ---- Q-block loop ---- + for qb in pto.range(qb_start, qb_end, c1): + q_view = pto.slice_view( + q_sub_ty, + source=tv_q, + offsets=[qb * cS0, c0], + sizes=[cS0, cHEAD], + ) + pto.load(q_view, q_mat) + tile.mov(q_mat, q_left) + + # ---- prologue: emit QK[0..QK_PRELOAD-1] ------------------------- + # V loading is now inline in emit_pv (per-sub-tile), so no preload. + for kp in range(QK_PRELOAD): + emit_qk(const(kp), kp % 2) + + # ---- steady state ------------------------------------------------ + # Match the 140tflops schedule: consume current P/PV and emit the + # next QK slot at CUBE_S1 sub-tile granularity. + for base in pto.range(c0, steady_tiles, cPRELOAD): + emit_qk_pv_interleaved(base + cPRELOAD + const(0), base + const(0), 0) + emit_qk_pv_interleaved(base + cPRELOAD + const(1), base + const(1), 1) + emit_qk_pv_interleaved(base + cPRELOAD + const(2), base + const(2), 0) + emit_qk_pv_interleaved(base + cPRELOAD + const(3), base + const(3), 1) + + # ---- epilogue: drain the last QK_PRELOAD PVs ------------------- + for k in range(QK_PRELOAD): + b = k % 2 + t_idx = steady_tiles + const(k) + emit_pv(t_idx, b) + + # ========================================================================= + # Vector kernel + # ========================================================================= + @pto.func(kernel="vector") + def vector_kernel( + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_fp16: "ptr_fp16", + gm_o: "ptr_fp32", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + c0 = const(0) + c1 = const(1) + cS0 = const(S0) + cS0_HALF = const(S0_HALF) + cVecGuRows = const(VecGuRows) + cVec_S0 = const(Vec_S0) + cHEAD = const(HEAD) + cS1_TILE = const(S1_TILE) + cSLOT_NUM = const(SLOT_NUM) + cPRELOAD = const(QK_PRELOAD) + s0 = s.index_cast(s0_i64) + s1 = s.index_cast(s1_i64) + num_tiles_s1 = s1 // cS1_TILE + steady_tiles = num_tiles_s1 - cPRELOAD + + bid, qb_start, qb_end = compute_qb_range(c1) + + 
gm_blk = pto.add_ptr(gm_slot_buffer, bid * const(GM_ELEMS_PER_BLOCK)) + gm_qk = pto.add_ptr(gm_blk, const(GM_QK_OFF_F32)) + gm_pv = pto.add_ptr(gm_blk, const(GM_PV_OFF_F32)) + gm_blk_fp16 = pto.add_ptr(gm_slot_buffer_fp16, bid * const(2 * GM_ELEMS_PER_BLOCK)) + gm_p = pto.add_ptr(gm_blk_fp16, const(2 * GM_P_OFF_F32)) + + # ---- QK pipe (vec consumer): l2g2l GM-staged slot ---- + # Vec sees one slot as [VecGuRows, S1_TILE] -- SPLIT_UP_DOWN halves + # the row count when crossing into the subblock; per row_slice we + # tload a [Vec_S0, S1_TILE] partition. + qk_slot_view = pto.as_tensor( + qk_vec_slot_ty, + ptr=gm_qk, + shape=[cVecGuRows, cS1_TILE], + strides=[cS1_TILE, c1], + ) + qk_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_QK, + slot_num=SLOT_NUM, + gm_addr=qk_slot_view, + flag_base=0, + ) + # ---- PV pipe (vec consumer): l2g2l GM-staged slot ---- + pv_slot_view = pto.as_tensor( + pv_vec_slot_ty, + ptr=gm_pv, + shape=[cVecGuRows, cHEAD], + strides=[cHEAD, c1], + ) + pv_pipe = pto.initialize_l2g2l_pipe( + dir_mask=1, + slot_size=SLOT_SIZE_PV, + slot_num=SLOT_NUM, + gm_addr=pv_slot_view, + flag_base=4, + ) + + # ---- P pipe (vec producer): l2g2l GM-staged slot ---- + p_slot_view = pto.as_tensor( + p_slot_ty, + ptr=gm_p, + shape=[cVecGuRows, cS1_TILE], + strides=[cS1_TILE, c1], + ) + p_pipe = pto.initialize_l2g2l_pipe( + dir_mask=2, + slot_size=SLOT_SIZE_P, + slot_num=SLOT_NUM, + gm_addr=p_slot_view, + flag_base=2, + ) + + # ---- Vec tile allocations. + # Per-slice working tiles are reused across the row_slice loop (each + # iter overwrites the previous), so a single allocation per type is + # enough. Reduce/state tiles are per-row_slice arrays because each + # row_slice tracks its own running_max/running_sum independently. 
+ qk_vec = pto.alloc_tile(qk_vec_ty) # [Vec_S0, S1_TILE] working + tmp = pto.alloc_tile(qk_vec_ty) # [Vec_S0, S1_TILE] row-reduce scratch + p_fp32 = pto.alloc_tile(p_fp32_ty) + p_fp16 = pto.alloc_tile(p_fp16_ty) + pv_vec = [pto.alloc_tile(pv_vec_ty) for _ in range(TILE_FACTOR)] + o_tile = [pto.alloc_tile(o_vec_ty) for _ in range(TILE_FACTOR)] + running_max = [pto.alloc_tile(red_ty) for _ in range(TILE_FACTOR)] + running_sum = [pto.alloc_tile(red_ty) for _ in range(TILE_FACTOR)] + local_max = [pto.alloc_tile(red_ty) for _ in range(TILE_FACTOR)] + local_sum = [pto.alloc_tile(red_ty) for _ in range(TILE_FACTOR)] + # Manual uses an 8-slot exp_max FIFO. With QK_PRELOAD=4, softmax(t+4) + # and gu(t) hit different slots for every steady tile. + exp_max_ring = [ + [pto.alloc_tile(red_ty) for _ in range(TILE_FACTOR)] + for _ in range(SLOT_NUM) + ] + + scale = const(1.0 / math.sqrt(HEAD), s.float32) + + sb_idx = s.index_cast(pto.get_subblock_idx()) + row_off_sb = sb_idx * cS0_HALF + + tv_o = pto.as_tensor( + o_tensor_ty, ptr=gm_o, shape=[s0, cHEAD], strides=[cHEAD, c1] + ) + + qk_entry = pto.declare_global(qk_vec_slot_ty) + p_entry = pto.declare_global(p_slot_ty) + pv_entry = pto.declare_global(pv_vec_slot_ty) + + # ---- emit_softmax(exp_max_slot, is_init): one streaming softmax ------ + # Translates pto_macro_fa_softmax: row_max on unscaled QK -> row diff + # -> scale -> stream-update (running_max, running_sum) -> exp -> cvt + # -> push P. Keeping running_max unscaled matches the manual macro. + def emit_softmax(exp_max_slots, is_init): + # Pop the wide QK slot (full subblock) and talloc one wide P slot; + # iterate TILE_FACTOR row_slices, doing per-slice softmax math on + # [Vec_S0, S1_TILE] tiles and per-slice reduce state. After all + # row_slices, push the wide P slot. 
+ pto.tpop_into(qk_entry, qk_pipe, SPLIT_UP_DOWN) + pto.talloc(p_entry, p_pipe, SPLIT_UP_DOWN) + for row_slice in range(TILE_FACTOR): + slot_part = pto.slice_view( + qk_vec_slot_part_ty, + source=qk_entry, + offsets=[const(row_slice * Vec_S0), c0], + sizes=[cVec_S0, cS1_TILE], + ) + pto.load(slot_part, qk_vec) + qk = qk_vec + lmax = local_max[row_slice] + lsum = local_sum[row_slice] + rmax = running_max[row_slice] + rsum = running_sum[row_slice] + exp_slot = exp_max_slots[row_slice] + tile.row_max(qk, tmp, lmax) + + # Reshape reductions to row-major so scalar broadcast helpers work. + local_max_r = tile.reshape(red_row_ty, lmax) + running_max_r = tile.reshape(red_row_ty, rmax) + running_sum_r = tile.reshape(red_row_ty, rsum) + local_sum_r = tile.reshape(red_row_ty, lsum) + exp_max_r = tile.reshape(red_row_ty, exp_slot) + + if is_init: + tile.row_expand_sub(qk, lmax, p_fp32) + tile.mov(local_max_r, running_max_r) + tile.muls(p_fp32, scale, p_fp32) + tile.exp(p_fp32, p_fp32) + tile.row_sum(p_fp32, tmp, rsum) + else: + tile.max(local_max_r, running_max_r, local_max_r) + tile.sub(running_max_r, local_max_r, exp_max_r) + tile.mov(local_max_r, running_max_r) + tile.row_expand_sub(qk, lmax, p_fp32) + tile.muls(exp_max_r, scale, exp_max_r) + tile.muls(p_fp32, scale, p_fp32) + tile.exp(exp_max_r, exp_max_r) + tile.exp(p_fp32, p_fp32) + tile.mul(running_sum_r, exp_max_r, running_sum_r) + tile.row_sum(p_fp32, tmp, lsum) + tile.add(running_sum_r, local_sum_r, running_sum_r) + + tile.cvt(p_fp32, p_fp16) + p_part = pto.slice_view( + p_slot_part_ty, + source=p_entry, + offsets=[const(row_slice * Vec_S0), c0], + sizes=[cVec_S0, cS1_TILE], + ) + pto.store(p_fp16, p_part) + pto.tpush(p_entry, p_pipe, SPLIT_UP_DOWN) + pto.tfree(qk_pipe, SPLIT_UP_DOWN, entry=qk_entry) + + # ---- emit_gu(exp_max_slots, is_init): rescale + add running O ------ + # GU also runs per-row_slice: each row_slice owns its own o_tile and + # pv_vec, indexed by the same exp_max_slots used during softmax. 
+ def emit_gu(exp_max_slots, is_init): + pto.tpop_into(pv_entry, pv_pipe, SPLIT_UP_DOWN) + for row_slice in range(TILE_FACTOR): + pv_part = pto.slice_view( + pv_vec_slot_part_ty, + source=pv_entry, + offsets=[const(row_slice * Vec_S0), c0], + sizes=[cVec_S0, cHEAD], + ) + pto.load(pv_part, pv_vec[row_slice]) + if is_init: + tile.mov(pv_vec[row_slice], o_tile[row_slice]) + else: + tile.row_expand_mul( + o_tile[row_slice], + exp_max_slots[row_slice], + o_tile[row_slice], + ) + tile.add(o_tile[row_slice], pv_vec[row_slice], o_tile[row_slice]) + pto.tfree(pv_pipe, SPLIT_UP_DOWN, entry=pv_entry) + + for qb in pto.range(qb_start, qb_end, c1): + # ---- vec prologue: softmax(0..QK_PRELOAD-1) -------------------- + for kp in range(QK_PRELOAD): + emit_softmax(exp_max_ring[kp], is_init=(kp == 0)) + + # ---- vec steady state. Match the 140tflops order: drain the + # current PV/GU tile before producing the future P tile. + with pto.if_context(steady_tiles > c0): + emit_gu(exp_max_ring[0], is_init=True) + emit_softmax(exp_max_ring[4], is_init=False) + emit_gu(exp_max_ring[1], is_init=False) + emit_softmax(exp_max_ring[5], is_init=False) + emit_gu(exp_max_ring[2], is_init=False) + emit_softmax(exp_max_ring[6], is_init=False) + emit_gu(exp_max_ring[3], is_init=False) + emit_softmax(exp_max_ring[7], is_init=False) + + for base in pto.range(cPRELOAD, steady_tiles, cPRELOAD): + with pto.if_context((base % cSLOT_NUM) == c0, has_else=True) as ring_branch: + emit_gu(exp_max_ring[0], is_init=False) + emit_softmax(exp_max_ring[4], is_init=False) + emit_gu(exp_max_ring[1], is_init=False) + emit_softmax(exp_max_ring[5], is_init=False) + emit_gu(exp_max_ring[2], is_init=False) + emit_softmax(exp_max_ring[6], is_init=False) + emit_gu(exp_max_ring[3], is_init=False) + emit_softmax(exp_max_ring[7], is_init=False) + with ring_branch.else_context(): + emit_gu(exp_max_ring[4], is_init=False) + emit_softmax(exp_max_ring[0], is_init=False) + emit_gu(exp_max_ring[5], is_init=False) + 
emit_softmax(exp_max_ring[1], is_init=False) + emit_gu(exp_max_ring[6], is_init=False) + emit_softmax(exp_max_ring[2], is_init=False) + emit_gu(exp_max_ring[7], is_init=False) + emit_softmax(exp_max_ring[3], is_init=False) + + # ---- vec epilogue: drain last QK_PRELOAD gus ------------------- + with pto.if_context(steady_tiles == c0, has_else=True) as branch: + for k in range(QK_PRELOAD): + slot = exp_max_ring[k] + emit_gu(slot, is_init=(k == 0)) + with branch.else_context(): + with pto.if_context((steady_tiles % cSLOT_NUM) == c0, has_else=True) as drain_branch: + for k in range(QK_PRELOAD): + emit_gu(exp_max_ring[k], is_init=False) + with drain_branch.else_context(): + for k in range(QK_PRELOAD): + emit_gu(exp_max_ring[QK_PRELOAD + k], is_init=False) + + # Final divide + GM store, one row_slice at a time. + for row_slice in range(TILE_FACTOR): + tile.row_expand_div( + o_tile[row_slice], + running_sum[row_slice], + o_tile[row_slice], + ) + o_view = pto.slice_view( + o_sub_slice_ty, + source=tv_o, + offsets=[ + qb * cS0 + row_off_sb + const(row_slice * Vec_S0), + c0, + ], + sizes=[cVec_S0, cHEAD], + ) + pto.store(o_tile[row_slice], o_view) + + # ========================================================================= + # Entry point invoked by the host caller via <<<>>> + # ========================================================================= + @pto.func + def call_both( + ffts_addr: "ffts_ty", + gm_slot_buffer: "ptr_fp32", + gm_slot_buffer_fp16: "ptr_fp16", + gm_q: "ptr_fp16", + gm_k: "ptr_fp16", + gm_v: "ptr_fp16", + gm_o: "ptr_fp32", + s0_i64: "i64", + s1_i64: "i64", + ) -> None: + pto.set_ffts(ffts_addr) + pto.call(cube_kernel, gm_slot_buffer, gm_slot_buffer_fp16, gm_q, gm_k, gm_v, s0_i64, s1_i64) + pto.call(vector_kernel, gm_slot_buffer, gm_slot_buffer_fp16, gm_o, s0_i64, s1_i64) + + +if __name__ == "__main__": + print(module.operation.get_asm(print_generic_op_form=True)) diff --git a/kernels/python/flash_atten/run.py b/kernels/python/flash_atten/run.py 
new file mode 100644 index 000000000..68ea9beb6 --- /dev/null +++ b/kernels/python/flash_atten/run.py @@ -0,0 +1,461 @@ +#!/usr/bin/env python3 +""" +Copyright (c) 2026 Huawei Technologies Co., Ltd. +This program is free software, you can redistribute it and/or modify it under +the terms and conditions of CANN Open Software License Agreement Version 2.0 +(the "License"). Please refer to the License for details. You may not use this +file except in compliance with the License. + +Runner for the pto-dsl flash-attention port. By default this invokes +`bash compile.sh` for the current `FA_Q_ROWS` (total Q seqlen), loads the +resulting .so, then benchmarks each `FA_BENCH_LENGTHS` entry as the KV seqlen +S1. Pass `--no-build` to reuse an existing `build_artifacts/fa.so`. + +When `FA_Q_ROWS` and `FA_BENCH_LENGTHS` are not set, the default benchmark +matrix matches `fa_benchmark_shapes.md` and is named case1..case8. +""" + +import argparse +import csv +import ctypes +import datetime +import math +import os +import subprocess +import sys + +import torch +import torch_npu # noqa: F401 -- registers the npu backend + +THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +sys.path.insert(0, os.path.join(THIS_DIR, "kernels")) +import fa_builder # noqa: E402 + +from ptodsl import do_bench # noqa: E402 +from ptodsl.utils.npu_info import get_num_cube_cores, get_test_device # noqa: E402 + +ARTIFACT_DIR = os.path.join(THIS_DIR, "build_artifacts") +COMPILE_SCRIPT = os.path.join(THIS_DIR, "compile.sh") +SUMMARY_TSV = os.environ.get("FA_SUMMARY_TSV", "") +CASE_ID = os.environ.get("FA_CASE_ID", "direct") + +# Default single-case length. The top-level default path runs DEFAULT_BENCH_CASES +# unless FA_Q_ROWS or FA_BENCH_LENGTHS is explicitly set. 
DEFAULT_BENCH_LENGTHS = (1024,)
# Default suite: (case_id, total Q rows S0, total KV rows S1).
DEFAULT_BENCH_CASES = (
    ("case1", 1024, 1024),
    ("case2", 2048, 2048),
    ("case3", 4096, 4096),
    ("case4", 8192, 8192),
    ("case5", 16384, 16384),
    ("case6", 32768, 32768),
    ("case7", 65536, 65536),
    ("case8", 131072, 131072),
)

ATOL = 1e-3
RTOL = 1e-3
# Skip the host fp32 reference once the QK matrix would exceed ~256M fp32
# elements (1 GiB host RAM). Above this we still assert against the NPU fused
# reference (torch_npu.npu_fused_infer_attention_score), with looser tolerance.
HOST_REF_MAX_QK_ELEMS = 256 * 1024 * 1024
ATOL_FUSED_ONLY = 5e-3
RTOL_FUSED_ONLY = 5e-3


def _manual_target_summary():
    """Return a one-line description of the manual kernel's MANUAL_* constants."""
    return (
        "manual target: S0={s0} HEAD={head} CUBE_S0={cube_s0} "
        "CUBE_S1={cube_s1} TILE_S1={tile_s1} QK_PRELOAD={preload} "
        "causal={causal}"
    ).format(
        s0=fa_builder.MANUAL_S0,
        head=fa_builder.MANUAL_HEAD,
        cube_s0=fa_builder.MANUAL_CUBE_S0,
        cube_s1=fa_builder.MANUAL_CUBE_S1,
        tile_s1=fa_builder.MANUAL_TILE_S1,
        preload=fa_builder.MANUAL_QK_PRELOAD,
        causal=fa_builder.MANUAL_CAUSAL_MASK,
    )


def _dsl_effective_summary():
    """Return a one-line description of the constants the DSL builder is using."""
    return (
        "dsl effective: Q_ROWS={q_rows} HEAD={head} CUBE_S0={cube_s0} "
        "S1_TILE={s1_tile} QK_PRELOAD={preload} NUM_Q_BLOCKS={q_blocks}"
    ).format(
        q_rows=fa_builder.Q_ROWS,
        head=fa_builder.HEAD,
        cube_s0=fa_builder.S0,
        s1_tile=fa_builder.S1_TILE,
        preload=fa_builder.QK_PRELOAD,
        q_blocks=fa_builder.NUM_Q_BLOCKS,
    )


def _bench_lengths():
    """Return the KV sequence lengths to benchmark.

    Reads the comma-separated ``FA_BENCH_LENGTHS`` environment variable.
    Blank entries are ignored. When the variable is unset, empty, or holds
    only separators/whitespace, ``DEFAULT_BENCH_LENGTHS`` is returned so the
    runner never silently benchmarks nothing.

    Raises:
        ValueError: if an entry is not an integer (raised by ``int``).
    """
    raw = os.environ.get("FA_BENCH_LENGTHS", "")
    lengths = tuple(int(x) for x in raw.split(",") if x.strip())
    return lengths or DEFAULT_BENCH_LENGTHS


def _bench_iters():
    """Return the timed iteration count from ``FA_BENCH_ITERS`` (default 100)."""
    return int(os.environ.get("FA_BENCH_ITERS", "100"))


def _bench_warmup():
    """Return the warmup iteration count from ``FA_BENCH_WARMUP`` (default 10)."""
    return int(os.environ.get("FA_BENCH_WARMUP", "10"))


def _default_suite_requested(args):
    """Return True when the multi-case default suite should run.

    That is the case when ``--case`` was given, or when neither ``FA_Q_ROWS``
    nor ``FA_BENCH_LENGTHS`` is present in the environment.
    """
    return args.case is not None or ("FA_Q_ROWS" not in os.environ and "FA_BENCH_LENGTHS" not in os.environ)


def _selected_default_cases(case_arg):
    """Resolve a comma-separated ``--case`` value into DEFAULT_BENCH_CASES entries.

    A falsy ``case_arg`` selects every default case. Blank names are skipped;
    the order given on the command line is preserved.

    Raises:
        ValueError: if a name is not one of the default case ids.
    """
    if not case_arg:
        return DEFAULT_BENCH_CASES

    lookup = {case_id: (case_id, s0, s1) for case_id, s0, s1 in DEFAULT_BENCH_CASES}
    selected = []
    for name in (part.strip() for part in case_arg.split(",")):
        if not name:
            continue
        if name not in lookup:
            valid = ", ".join(case_id for case_id, _, _ in DEFAULT_BENCH_CASES)
            raise ValueError("unknown case '{}'; valid cases: {}".format(name, valid))
        selected.append(lookup[name])
    return tuple(selected)


def _default_summary_tsv():
    """Return a timestamp- and pid-stamped summary TSV path under ARTIFACT_DIR."""
    stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(ARTIFACT_DIR, "fa_summary_{}_{}.tsv".format(stamp, os.getpid()))


def _summary_fields():
    """Return the ordered column names of the summary TSV."""
    return [
        "case_id",
        "q_rows",
        "seq_len",
        "tiles",
        "status",
        "fa_us",
        "fa_tflops",
        "fused_us",
        "fused_tflops",
        "speedup",
        "err_kernel",
        "err_fused",
        "ref",
        "note",
    ]


def _write_summary_row(path, row):
    """Append ``row`` to the TSV at ``path``, creating directory and header as needed.

    The header is written when the file is missing *or* empty; checking only
    for existence would leave a pre-created empty file without column names.
    Keys of ``row`` outside ``_summary_fields()`` are dropped and missing keys
    are written as empty strings.
    """
    summary_dir = os.path.dirname(path)
    if summary_dir:
        os.makedirs(summary_dir, exist_ok=True)
    needs_header = not os.path.exists(path) or os.path.getsize(path) == 0
    fields = _summary_fields()
    with open(path, "a", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fields, delimiter="\t")
        if needs_header:
            writer.writeheader()
        writer.writerow({key: row.get(key, "") for key in fields})


def _append_summary_row(row):
    """Append ``row`` to the TSV named by SUMMARY_TSV; no-op when it is unset."""
    if not SUMMARY_TSV:
        return
    _write_summary_row(SUMMARY_TSV, row)


def _num_tiles(seq_len):
    """Return the number of S1 tiles for ``seq_len`` after validating it.

    Raises:
        ValueError: if ``seq_len`` is not a multiple of ``fa_builder.S1_TILE``,
            or the tile count is below ``QK_PRELOAD`` or not congruent to it
            modulo ``QK_PRELOAD``.
    """
    if seq_len % fa_builder.S1_TILE != 0:
        raise ValueError("seq_len {} is not a multiple of S1_TILE={}".format(seq_len, fa_builder.S1_TILE))
    num_tiles = seq_len // fa_builder.S1_TILE
    if num_tiles < fa_builder.QK_PRELOAD or (num_tiles - fa_builder.QK_PRELOAD) % fa_builder.QK_PRELOAD != 0:
        raise ValueError(
            "seq_len {} maps to {} tiles; the current QK_PRELOAD={} runtime loop requires "
            "num_tiles >= QK_PRELOAD and (num_tiles - QK_PRELOAD) % QK_PRELOAD == 0".format(
                seq_len, num_tiles, fa_builder.QK_PRELOAD
            )
        )
    return num_tiles


def _lib_path():
    """Return the expected path of the compiled shared library."""
    return os.path.join(ARTIFACT_DIR, "fa.so")


def _require_lib():
    """Return the shared-library path, raising FileNotFoundError if it is absent."""
    p = _lib_path()
    if not os.path.exists(p):
        raise FileNotFoundError(
            "Missing prebuilt .so: {}\n"
            " Run `bash compile.sh` (or `python3 run.py` without --no-build) for the current "
            "FA_Q_ROWS={} to build the runtime-S1 kernel.".format(p, fa_builder.Q_ROWS)
        )
    return p


def _build_lib():
    """Run compile.sh in THIS_DIR and return the path of the resulting library.

    Raises:
        subprocess.CalledProcessError: if the compile script exits non-zero.
        FileNotFoundError: if the script succeeded but produced no fa.so.
    """
    print("[fa] compiling PTODSL flash kernel...")
    subprocess.run(["bash", COMPILE_SCRIPT], cwd=THIS_DIR, check=True)
    return _require_lib()


def _load_lib(path):
    """Load the shared library and declare the ``call_kernel`` ctypes signature.

    The signature is one uint32, six void pointers and two int64 values,
    matching the argument order used by ``_invoke``.
    """
    lib = ctypes.CDLL(path)
    lib.call_kernel.argtypes = [ctypes.c_uint32] + [ctypes.c_void_p] * 6 + [ctypes.c_int64, ctypes.c_int64]
    lib.call_kernel.restype = None
    return lib


def _to_void_p(t):
    """Wrap a tensor's data pointer as a ``ctypes.c_void_p``."""
    return ctypes.c_void_p(t.data_ptr())


def _block_dim():
    """Return the launch block count: Q blocks capped by the cube-core count."""
    return min(fa_builder.NUM_Q_BLOCKS, get_num_cube_cores())


def _slot_elems(block_dim):
    """Return the total fp32 element count of the GM slot buffer for ``block_dim`` blocks."""
    return fa_builder.GM_ELEMS_PER_BLOCK * block_dim


def attn_flops_matmul_softmax_scale(
    batch_size, s_q, s_k, h, include_scale=True, count_exp_as_flop=True, count_max_as_flop=True
):
    """Return an operation count for one attention forward pass.

    Counts the two matmuls (``4 * batch * s_q * s_k * h``), optionally the
    score scaling, and the softmax steps per row: optional max reduction
    (``s_k - 1``), subtraction (``s_k``), optional exp (``s_k``), sum
    reduction (``s_k - 1``) and normalisation (``s_k``).
    """
    flops_matmul = 4 * batch_size * s_q * s_k * h
    flops_scale = batch_size * s_q * s_k if include_scale else 0

    rows = batch_size * s_q
    softmax_ops = 0
    if count_max_as_flop:
        softmax_ops += rows * (s_k - 1)
    softmax_ops += rows * s_k
    if count_exp_as_flop:
        softmax_ops += rows * s_k
    softmax_ops += rows * (s_k - 1)
    softmax_ops += rows * s_k

    return flops_matmul + flops_scale + softmax_ops


def tflops(flops, ms):
    """Convert an operation count and a duration in milliseconds to TFLOP/s."""
    return flops / (ms * 1e-3) / 1e12


def fa_reference(q, k, v):
    """Plain torch fp32 reference: O = softmax(QK^T * scale) @ V.

    The computation runs on whatever device ``q``/``k``/``v`` live on; the
    caller moves the result to the CPU before comparing.
    """
    scale = 1.0 / math.sqrt(q.shape[1])
    scores = q.float() @ k.float().T * scale
    return (torch.softmax(scores, dim=-1) @ v.float()).float()


def fused_attention(q, k, v):
    """torch_npu fused reference for benchmarking the speedup ratio."""
    scale = 1.0 / math.sqrt(q.shape[1])
    # Inputs are 2-D [rows, HEAD]; add a batch dimension for the "BSH" layout.
    out, _ = torch_npu.npu_fused_infer_attention_score(
        q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), num_heads=1, input_layout="BSH", scale=scale, next_tokens=65535
    )
    return out.squeeze(0)


def _alloc_io(seq_len, device):
    """Allocate random fp16 Q/K/V, a zeroed fp32 slot buffer and a zeroed fp32 output.

    Returns:
        ``(q, k, v, gm_slot, o, block_dim)``.
    """
    Q_ROWS = fa_builder.Q_ROWS
    HEAD = fa_builder.HEAD
    block_dim = _block_dim()

    q = torch.randn((Q_ROWS, HEAD), dtype=torch.float16, device=device)
    k = torch.randn((seq_len, HEAD), dtype=torch.float16, device=device)
    v = torch.randn((seq_len, HEAD), dtype=torch.float16, device=device)
    gm_slot = torch.zeros((_slot_elems(block_dim),), dtype=torch.float32, device=device)
    o = torch.zeros((Q_ROWS, HEAD), dtype=torch.float32, device=device)
    return q, k, v, gm_slot, o, block_dim


def _invoke(lib, block_dim, gm_slot, q, k, v, o, stream_ptr):
    """Launch the kernel through the host shim; Q and KV row counts come from the tensors."""
    lib.call_kernel(
        block_dim,
        stream_ptr,
        _to_void_p(gm_slot),
        _to_void_p(q),
        _to_void_p(k),
        _to_void_p(v),
        _to_void_p(o),
        q.shape[0],
        k.shape[0],
    )


def benchmark(lib, device, num_tiles, warmup, iters):
    """Time the kernel and the fused reference for one length, then verify the output.

    Returns:
        A dict with timing, throughput, speedup, max-abs-error and reference
        label entries, consumed by ``_print_row`` and
        ``_append_benchmark_summary``.

    Raises:
        AssertionError: from ``torch.testing.assert_close`` when the kernel
            output is outside tolerance.
    """
    torch.manual_seed(0)
    seq_len = fa_builder.S1_TILE * num_tiles
    q, k, v, gm_slot, o, block_dim = _alloc_io(seq_len, device)
    stream_ptr = torch.npu.current_stream()._as_parameter_

    def run_kernel():
        _invoke(lib, block_dim, gm_slot, q, k, v, o, stream_ptr)

    def run_fused():
        fused_attention(q, k, v)

    kernel_us = do_bench(run_kernel, warmup_iters=warmup, benchmark_iters=iters, unit="us")
    fused_us = do_bench(run_fused, warmup_iters=warmup, benchmark_iters=iters, unit="us")

    # One untimed correctness probe per length so silent miscompiles don't
    # hide behind a passing benchmark.
    run_kernel()
    torch.npu.synchronize()
    o_kernel = o.clone()
    o_fused = fused_attention(q, k, v)
    torch.npu.synchronize()

    qk_elems = q.shape[0] * k.shape[0]
    use_host_ref = qk_elems <= HOST_REF_MAX_QK_ELEMS
    if use_host_ref:
        o_golden = fa_reference(q, k, v)
        err_kernel = (o_kernel.cpu().float() - o_golden.cpu()).abs().max().item()
        err_fused = (o_fused.cpu().float() - o_golden.cpu()).abs().max().item()
        torch.testing.assert_close(o_kernel.cpu().float(), o_golden.cpu(), rtol=RTOL, atol=ATOL)
        ref_label = "host_fp32"
    else:
        # The fp32 reference would need about qk_elems * 4 bytes for the QK^T
        # matrix; fall back to the NPU fused reference as the correctness
        # baseline with a slightly looser tolerance.
        err_kernel = (o_kernel.cpu().float() - o_fused.cpu().float()).abs().max().item()
        # The fused output is the baseline here, so its own error is reported as 0.
        err_fused = 0.0
        torch.testing.assert_close(
            o_kernel.cpu().float(), o_fused.cpu().float(), rtol=RTOL_FUSED_ONLY, atol=ATOL_FUSED_ONLY
        )
        ref_label = "npu_fused (host fp32 ref skipped, qk_elems={:.1e})".format(qk_elems)

    matmul_flops = 4 * fa_builder.Q_ROWS * fa_builder.HEAD * seq_len
    attention_flops = attn_flops_matmul_softmax_scale(1, fa_builder.Q_ROWS, seq_len, fa_builder.HEAD)
    kernel_ms = kernel_us / 1000.0
    fused_ms = fused_us / 1000.0
    return {
        "seq_len": seq_len,
        "num_tiles": num_tiles,
        "kernel_us": kernel_us,
        "fused_us": fused_us,
        "kernel_gflops": matmul_flops / (kernel_us * 1e-6) / 1e9,
        "fused_gflops": matmul_flops / (fused_us * 1e-6) / 1e9,
        "kernel_tflops": tflops(attention_flops, kernel_ms),
        "fused_tflops": tflops(attention_flops, fused_ms),
        "speedup": fused_us / kernel_us,
        "err_kernel": err_kernel,
        "err_fused": err_fused,
        "ref": ref_label,
    }


def _print_row(r):
    """Print one benchmark result dict as a single formatted console line."""
    print(
        " s0={s0:>6} s1={seq_len:>6} tiles={num_tiles:>3} "
        "fa={kernel_us:8.2f}us ({kernel_gflops:7.1f} GF/s, {kernel_tflops:6.2f} TFLOP/s) "
        "npu_fused_infer_attention={fused_us:8.2f}us "
        "({fused_gflops:7.1f} GF/s, {fused_tflops:6.2f} TFLOP/s) "
        "speedup={speedup:.2f}x "
        "err: ours={err_kernel:.2e} npu_fused_infer_attention={err_fused:.2e} "
        "ref={ref}".format(s0=fa_builder.Q_ROWS, **r)
    )


def _append_benchmark_summary(result, status="OK", note=""):
    """Format a ``benchmark`` result dict and append it to the summary TSV."""
    _append_summary_row(
        {
            "case_id": CASE_ID,
            "q_rows": fa_builder.Q_ROWS,
            "seq_len": result["seq_len"],
            "tiles": result["num_tiles"],
            "status": status,
            "fa_us": "{:.2f}".format(result["kernel_us"]),
            "fa_tflops": "{:.2f}".format(result["kernel_tflops"]),
            "fused_us": "{:.2f}".format(result["fused_us"]),
            "fused_tflops": "{:.2f}".format(result["fused_tflops"]),
            "speedup": "{:.2f}".format(result["speedup"]),
            "err_kernel": "{:.2e}".format(result["err_kernel"]),
            "err_fused": "{:.2e}".format(result["err_fused"]),
            "ref": result["ref"],
            "note": note,
        }
    )


def _parse_args():
    """Parse the command line (``--no-build`` and ``--case``)."""
    parser = argparse.ArgumentParser(description="Build/run the pto-dsl flash-attention benchmark.")
    parser.add_argument(
        "--no-build", action="store_true", help="skip compile.sh and load the existing build_artifacts/fa.so"
    )
    parser.add_argument(
        "--case", help="comma-separated default benchmark cases to run (case1..case8); defaults to all cases"
    )
    return parser.parse_args()


def _run_default_suite(args):
    """Run each selected default case in a child process of this script.

    Each child receives FA_CASE_ID, FA_Q_ROWS, FA_BENCH_LENGTHS and
    FA_SUMMARY_TSV through its environment, which makes it take the
    single-configuration path in ``main`` and append to one shared TSV.

    Raises:
        ValueError: if no case is selected, or ``--no-build`` is combined
            with more than one case.
        subprocess.CalledProcessError: if a child run fails; later cases are
            not started.
    """
    cases = _selected_default_cases(args.case)
    if not cases:
        raise ValueError("no benchmark cases selected")
    if args.no_build and len(cases) != 1:
        raise ValueError(
            "--no-build is only valid for a single selected case because build_artifacts/fa.so is rebuilt per FA_Q_ROWS"
        )

    summary_tsv = SUMMARY_TSV or _default_summary_tsv()
    print("[fa] running default benchmark suite from fa_benchmark_shapes.md")
    print("[fa] selected cases: {}".format(", ".join(case_id for case_id, _, _ in cases)))
    print("[fa] summary TSV: {}".format(summary_tsv))

    for case_id, s0, s1 in cases:
        print("\n{:=^110}".format(" {} S0={} S1={} ".format(case_id, s0, s1)))
        env = os.environ.copy()
        env.update(
            {"FA_CASE_ID": case_id, "FA_Q_ROWS": str(s0), "FA_BENCH_LENGTHS": str(s1), "FA_SUMMARY_TSV": summary_tsv}
        )
        cmd = [sys.executable, os.path.abspath(__file__)]
        if args.no_build:
            cmd.append("--no-build")
        subprocess.run(cmd, cwd=THIS_DIR, env=env, check=True)

    print("\n[fa] default benchmark suite done.")
    print("[fa] summary TSV: {}".format(summary_tsv))


def main():
    """Entry point: run the default suite, or benchmark the env-configured shape."""
    args = _parse_args()
    if _default_suite_requested(args):
        _run_default_suite(args)
        return

    device = get_test_device()
    torch.npu.set_device(device)

    lengths = _bench_lengths()
    # Validate every length before paying for the build.
    targets = [(L, _num_tiles(L)) for L in lengths]
    lib_path = _require_lib() if args.no_build else _build_lib()
    lib = _load_lib(lib_path)

    warmup = _bench_warmup()
    iters = _bench_iters()

    print("\n{:=^110}".format(" Benchmark (pto-dsl fa) "))
    print(" " + _manual_target_summary())
    print(" " + _dsl_effective_summary())
    print(" same host-visible shape and QK_PRELOAD as the manual non-causal S0=128 path")
    print(" TFLOP/s counts matmul + scale + softmax operations, matching 140tflops/run.py")
    print(" reference kernel: torch_npu.npu_fused_infer_attention_score")
    print(" host fp32 reference is skipped when Q_ROWS*S1 > {}".format(HOST_REF_MAX_QK_ELEMS))
    print(" cores={} warmup={} iters={}".format(get_num_cube_cores(), warmup, iters))
    print(" S0(Q_ROWS)={} S1 lengths: {}".format(fa_builder.Q_ROWS, list(lengths)))
    print("-" * 110)

    for _, nt in targets:
        result = benchmark(lib, device, num_tiles=nt, warmup=warmup, iters=iters)
        _print_row(result)
        _append_benchmark_summary(result)

    print("=" * 110)


if __name__ == "__main__":
    main()